mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-23 10:57:44 +08:00
refactor: centralize assistant turn semantics and stream accumulation into new assistantturn and completionruntime packages
This commit is contained in:
64
internal/assistantturn/stream.go
Normal file
64
internal/assistantturn/stream.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package assistantturn
|
||||
|
||||
import (
|
||||
"ds2api/internal/httpapi/openai/shared"
|
||||
"ds2api/internal/sse"
|
||||
)
|
||||
|
||||
type StreamEventType string
|
||||
|
||||
const (
|
||||
StreamEventTextDelta StreamEventType = "text_delta"
|
||||
StreamEventThinkingDelta StreamEventType = "thinking_delta"
|
||||
StreamEventToolCall StreamEventType = "tool_call"
|
||||
StreamEventDone StreamEventType = "done"
|
||||
StreamEventError StreamEventType = "error"
|
||||
StreamEventPing StreamEventType = "ping"
|
||||
)
|
||||
|
||||
type StreamEvent struct {
|
||||
Type StreamEventType
|
||||
Text string
|
||||
Thinking string
|
||||
ToolCall any
|
||||
Error *OutputError
|
||||
Usage *Usage
|
||||
}
|
||||
|
||||
type Accumulator struct {
|
||||
inner shared.StreamAccumulator
|
||||
}
|
||||
|
||||
type AccumulatorOptions struct {
|
||||
ThinkingEnabled bool
|
||||
SearchEnabled bool
|
||||
StripReferenceMarkers bool
|
||||
}
|
||||
|
||||
func NewAccumulator(opts AccumulatorOptions) *Accumulator {
|
||||
return &Accumulator{
|
||||
inner: shared.StreamAccumulator{
|
||||
ThinkingEnabled: opts.ThinkingEnabled,
|
||||
SearchEnabled: opts.SearchEnabled,
|
||||
StripReferenceMarkers: opts.StripReferenceMarkers,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Accumulator) Apply(parsed sse.LineResult) shared.StreamAccumulatorResult {
|
||||
if a == nil {
|
||||
return shared.StreamAccumulatorResult{}
|
||||
}
|
||||
return a.inner.Apply(parsed)
|
||||
}
|
||||
|
||||
func (a *Accumulator) Snapshot() (rawText, text, rawThinking, thinking, detectionThinking string) {
|
||||
if a == nil {
|
||||
return "", "", "", "", ""
|
||||
}
|
||||
return a.inner.RawText.String(),
|
||||
a.inner.Text.String(),
|
||||
a.inner.RawThinking.String(),
|
||||
a.inner.Thinking.String(),
|
||||
a.inner.ToolDetectionThinking.String()
|
||||
}
|
||||
227
internal/assistantturn/turn.go
Normal file
227
internal/assistantturn/turn.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package assistantturn
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/httpapi/openai/shared"
|
||||
"ds2api/internal/promptcompat"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/toolcall"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type StopReason string
|
||||
|
||||
const (
|
||||
StopReasonStop StopReason = "stop"
|
||||
StopReasonToolCalls StopReason = "tool_calls"
|
||||
StopReasonContentFilter StopReason = "content_filter"
|
||||
StopReasonError StopReason = "error"
|
||||
)
|
||||
|
||||
type Usage struct {
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
ReasoningTokens int
|
||||
TotalTokens int
|
||||
}
|
||||
|
||||
type OutputError struct {
|
||||
Status int
|
||||
Message string
|
||||
Code string
|
||||
}
|
||||
|
||||
type Turn struct {
|
||||
Model string
|
||||
Prompt string
|
||||
RawText string
|
||||
RawThinking string
|
||||
DetectionThinking string
|
||||
Text string
|
||||
Thinking string
|
||||
ToolCalls []toolcall.ParsedToolCall
|
||||
ParsedToolCalls toolcall.ToolCallParseResult
|
||||
CitationLinks map[int]string
|
||||
ContentFilter bool
|
||||
ResponseMessageID int
|
||||
StopReason StopReason
|
||||
Usage Usage
|
||||
Error *OutputError
|
||||
}
|
||||
|
||||
type BuildOptions struct {
|
||||
Model string
|
||||
Prompt string
|
||||
RefFileTokens int
|
||||
SearchEnabled bool
|
||||
StripReferenceMarkers bool
|
||||
ToolNames []string
|
||||
ToolsRaw any
|
||||
ToolChoice promptcompat.ToolChoicePolicy
|
||||
}
|
||||
|
||||
type StreamSnapshot struct {
|
||||
RawText string
|
||||
VisibleText string
|
||||
RawThinking string
|
||||
VisibleThinking string
|
||||
DetectionThinking string
|
||||
ContentFilter bool
|
||||
CitationLinks map[int]string
|
||||
ResponseMessageID int
|
||||
AlreadyEmittedCalls bool
|
||||
AdditionalToolCalls []toolcall.ParsedToolCall
|
||||
AlreadyEmittedToolRaw bool
|
||||
}
|
||||
|
||||
func BuildTurnFromCollected(result sse.CollectResult, opts BuildOptions) Turn {
|
||||
thinking := shared.CleanVisibleOutput(result.Thinking, opts.StripReferenceMarkers)
|
||||
text := shared.CleanVisibleOutput(result.Text, opts.StripReferenceMarkers)
|
||||
if opts.SearchEnabled {
|
||||
text = shared.ReplaceCitationMarkersWithLinks(text, result.CitationLinks)
|
||||
}
|
||||
|
||||
parsed := shared.DetectAssistantToolCalls(result.Text, text, result.Thinking, result.ToolDetectionThinking, opts.ToolNames)
|
||||
calls := toolcall.NormalizeParsedToolCallsForSchemas(parsed.Calls, opts.ToolsRaw)
|
||||
parsed.Calls = calls
|
||||
|
||||
stopReason := StopReasonStop
|
||||
if result.ContentFilter {
|
||||
stopReason = StopReasonContentFilter
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
stopReason = StopReasonToolCalls
|
||||
}
|
||||
|
||||
turn := Turn{
|
||||
Model: opts.Model,
|
||||
Prompt: opts.Prompt,
|
||||
RawText: result.Text,
|
||||
RawThinking: result.Thinking,
|
||||
DetectionThinking: result.ToolDetectionThinking,
|
||||
Text: text,
|
||||
Thinking: thinking,
|
||||
ToolCalls: calls,
|
||||
ParsedToolCalls: parsed,
|
||||
CitationLinks: result.CitationLinks,
|
||||
ContentFilter: result.ContentFilter,
|
||||
ResponseMessageID: result.ResponseMessageID,
|
||||
StopReason: stopReason,
|
||||
}
|
||||
turn.Usage = BuildUsage(opts.Model, opts.Prompt, thinking, text, opts.RefFileTokens)
|
||||
turn.Error = ValidateTurn(turn, opts.ToolChoice)
|
||||
if turn.Error != nil {
|
||||
turn.StopReason = StopReasonError
|
||||
}
|
||||
return turn
|
||||
}
|
||||
|
||||
func BuildTurnFromStreamSnapshot(snapshot StreamSnapshot, opts BuildOptions) Turn {
|
||||
thinking := shared.CleanVisibleOutput(snapshot.VisibleThinking, opts.StripReferenceMarkers)
|
||||
text := shared.CleanVisibleOutput(snapshot.VisibleText, opts.StripReferenceMarkers)
|
||||
if opts.SearchEnabled {
|
||||
text = shared.ReplaceCitationMarkersWithLinks(text, snapshot.CitationLinks)
|
||||
}
|
||||
|
||||
parsed := shared.DetectAssistantToolCalls(snapshot.RawText, text, snapshot.RawThinking, snapshot.DetectionThinking, opts.ToolNames)
|
||||
calls := parsed.Calls
|
||||
if len(calls) == 0 && len(snapshot.AdditionalToolCalls) > 0 {
|
||||
calls = snapshot.AdditionalToolCalls
|
||||
}
|
||||
calls = toolcall.NormalizeParsedToolCallsForSchemas(calls, opts.ToolsRaw)
|
||||
parsed.Calls = calls
|
||||
|
||||
stopReason := StopReasonStop
|
||||
if snapshot.ContentFilter {
|
||||
stopReason = StopReasonContentFilter
|
||||
}
|
||||
if len(calls) > 0 || snapshot.AlreadyEmittedCalls || snapshot.AlreadyEmittedToolRaw {
|
||||
stopReason = StopReasonToolCalls
|
||||
}
|
||||
|
||||
turn := Turn{
|
||||
Model: opts.Model,
|
||||
Prompt: opts.Prompt,
|
||||
RawText: snapshot.RawText,
|
||||
RawThinking: snapshot.RawThinking,
|
||||
DetectionThinking: snapshot.DetectionThinking,
|
||||
Text: text,
|
||||
Thinking: thinking,
|
||||
ToolCalls: calls,
|
||||
ParsedToolCalls: parsed,
|
||||
CitationLinks: snapshot.CitationLinks,
|
||||
ContentFilter: snapshot.ContentFilter,
|
||||
ResponseMessageID: snapshot.ResponseMessageID,
|
||||
StopReason: stopReason,
|
||||
}
|
||||
turn.Usage = BuildUsage(opts.Model, opts.Prompt, thinking, text, opts.RefFileTokens)
|
||||
if !snapshot.AlreadyEmittedCalls && !snapshot.AlreadyEmittedToolRaw {
|
||||
turn.Error = ValidateTurn(turn, opts.ToolChoice)
|
||||
}
|
||||
if turn.Error != nil && len(calls) == 0 {
|
||||
turn.StopReason = StopReasonError
|
||||
}
|
||||
return turn
|
||||
}
|
||||
|
||||
func BuildUsage(model, prompt, thinking, text string, refFileTokens int) Usage {
|
||||
inputTokens := util.CountPromptTokens(prompt, model) + refFileTokens
|
||||
reasoningTokens := util.CountOutputTokens(thinking, model)
|
||||
outputTokens := reasoningTokens + util.CountOutputTokens(text, model)
|
||||
return Usage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
ReasoningTokens: reasoningTokens,
|
||||
TotalTokens: inputTokens + outputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateTurn(turn Turn, policy promptcompat.ToolChoicePolicy) *OutputError {
|
||||
if policy.IsRequired() && len(turn.ToolCalls) == 0 {
|
||||
return &OutputError{
|
||||
Status: http.StatusUnprocessableEntity,
|
||||
Message: "tool_choice requires at least one valid tool call.",
|
||||
Code: "tool_choice_violation",
|
||||
}
|
||||
}
|
||||
if len(turn.ToolCalls) > 0 {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(turn.Text) != "" {
|
||||
return nil
|
||||
}
|
||||
status, message, code := UpstreamEmptyOutputDetail(turn.ContentFilter, turn.Text, turn.Thinking)
|
||||
return &OutputError{Status: status, Message: message, Code: code}
|
||||
}
|
||||
|
||||
func UpstreamEmptyOutputDetail(contentFilter bool, text, thinking string) (int, string, string) {
|
||||
_ = text
|
||||
if contentFilter {
|
||||
return http.StatusBadRequest, "Upstream content filtered the response and returned no output.", "content_filter"
|
||||
}
|
||||
if strings.TrimSpace(thinking) != "" {
|
||||
return http.StatusTooManyRequests, "Upstream account hit a rate limit and returned reasoning without visible output.", "upstream_empty_output"
|
||||
}
|
||||
return http.StatusTooManyRequests, "Upstream account hit a rate limit and returned empty output.", "upstream_empty_output"
|
||||
}
|
||||
|
||||
func ShouldRetryEmptyOutput(turn Turn, attempts, maxAttempts int) bool {
|
||||
return attempts < maxAttempts &&
|
||||
!turn.ContentFilter &&
|
||||
len(turn.ToolCalls) == 0 &&
|
||||
strings.TrimSpace(turn.Text) == "" &&
|
||||
strings.TrimSpace(turn.Thinking) == ""
|
||||
}
|
||||
|
||||
func FinishReason(turn Turn) string {
|
||||
switch turn.StopReason {
|
||||
case StopReasonToolCalls:
|
||||
return "tool_calls"
|
||||
case StopReasonContentFilter:
|
||||
return "content_filter"
|
||||
default:
|
||||
return "stop"
|
||||
}
|
||||
}
|
||||
100
internal/assistantturn/turn_test.go
Normal file
100
internal/assistantturn/turn_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package assistantturn
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/promptcompat"
|
||||
"ds2api/internal/sse"
|
||||
)
|
||||
|
||||
func TestBuildTurnFromCollectedTextCitation(t *testing.T) {
|
||||
turn := BuildTurnFromCollected(sse.CollectResult{
|
||||
Text: "See [citation:1]",
|
||||
CitationLinks: map[int]string{1: "https://example.com"},
|
||||
}, BuildOptions{Model: "deepseek-v4-flash", Prompt: "prompt", SearchEnabled: true, StripReferenceMarkers: true})
|
||||
if turn.Text != "See [1](https://example.com)" {
|
||||
t.Fatalf("text mismatch: %q", turn.Text)
|
||||
}
|
||||
if turn.StopReason != StopReasonStop {
|
||||
t.Fatalf("stop reason mismatch: %q", turn.StopReason)
|
||||
}
|
||||
if turn.Error != nil {
|
||||
t.Fatalf("unexpected error: %#v", turn.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildTurnFromCollectedToolCall(t *testing.T) {
|
||||
turn := BuildTurnFromCollected(sse.CollectResult{
|
||||
Text: `<tool_calls><invoke name="Write"><parameter name="content">{"x":1}</parameter></invoke></tool_calls>`,
|
||||
}, BuildOptions{
|
||||
ToolNames: []string{"Write"},
|
||||
ToolsRaw: []any{map[string]any{
|
||||
"name": "Write",
|
||||
"input_schema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
if len(turn.ToolCalls) != 1 {
|
||||
t.Fatalf("expected one tool call, got %d", len(turn.ToolCalls))
|
||||
}
|
||||
if turn.StopReason != StopReasonToolCalls {
|
||||
t.Fatalf("stop reason mismatch: %q", turn.StopReason)
|
||||
}
|
||||
if _, ok := turn.ToolCalls[0].Input["content"].(string); !ok {
|
||||
t.Fatalf("expected content coerced to string, got %#v", turn.ToolCalls[0].Input["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildTurnFromCollectedThinkingOnlyIsEmptyOutput(t *testing.T) {
|
||||
turn := BuildTurnFromCollected(sse.CollectResult{Thinking: "hidden"}, BuildOptions{})
|
||||
if turn.Error == nil || turn.Error.Code != "upstream_empty_output" {
|
||||
t.Fatalf("expected empty output error, got %#v", turn.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildTurnFromCollectedToolChoiceRequired(t *testing.T) {
|
||||
turn := BuildTurnFromCollected(sse.CollectResult{Text: "hello"}, BuildOptions{
|
||||
ToolChoice: promptcompat.ToolChoicePolicy{Mode: promptcompat.ToolChoiceRequired},
|
||||
})
|
||||
if turn.Error == nil || turn.Error.Code != "tool_choice_violation" {
|
||||
t.Fatalf("expected tool choice violation, got %#v", turn.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildTurnFromStreamSnapshotUsesVisibleTextAndRawToolDetection(t *testing.T) {
|
||||
turn := BuildTurnFromStreamSnapshot(StreamSnapshot{
|
||||
RawText: `<tool_calls><invoke name="Write"><parameter name="content">{"x":1}</parameter></invoke></tool_calls>`,
|
||||
VisibleText: "",
|
||||
}, BuildOptions{
|
||||
ToolNames: []string{"Write"},
|
||||
ToolsRaw: []any{map[string]any{
|
||||
"name": "Write",
|
||||
"schema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
if len(turn.ToolCalls) != 1 {
|
||||
t.Fatalf("expected stream snapshot tool call, got %d", len(turn.ToolCalls))
|
||||
}
|
||||
if _, ok := turn.ToolCalls[0].Input["content"].(string); !ok {
|
||||
t.Fatalf("expected stream snapshot schema coercion, got %#v", turn.ToolCalls[0].Input["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildTurnFromStreamSnapshotAlreadyEmittedToolAvoidsEmptyError(t *testing.T) {
|
||||
turn := BuildTurnFromStreamSnapshot(StreamSnapshot{AlreadyEmittedCalls: true}, BuildOptions{})
|
||||
if turn.Error != nil {
|
||||
t.Fatalf("unexpected empty-output error after emitted tool call: %#v", turn.Error)
|
||||
}
|
||||
if turn.StopReason != StopReasonToolCalls {
|
||||
t.Fatalf("stop reason mismatch: %q", turn.StopReason)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user