refactor: centralize assistant turn semantics and stream accumulation into new assistantturn and completionruntime packages

This commit is contained in:
CJACK
2026-05-02 23:28:43 +08:00
parent eccd8c957b
commit dc5bffdf89
24 changed files with 1215 additions and 254 deletions

View File

@@ -0,0 +1,64 @@
package assistantturn
import (
"ds2api/internal/httpapi/openai/shared"
"ds2api/internal/sse"
)
// StreamEventType names the kind of a StreamEvent.
type StreamEventType string
// Stream event kinds. The string value of each constant is its identifier.
const (
StreamEventTextDelta StreamEventType = "text_delta"
StreamEventThinkingDelta StreamEventType = "thinking_delta"
StreamEventToolCall StreamEventType = "tool_call"
StreamEventDone StreamEventType = "done"
StreamEventError StreamEventType = "error"
StreamEventPing StreamEventType = "ping"
)
// StreamEvent is one typed event. Type selects the kind; the remaining fields
// are optional payloads.
// NOTE(review): which payload fields are populated for which Type is not
// visible in this file — confirm against the producers of StreamEvent.
type StreamEvent struct {
Type StreamEventType
Text string
Thinking string
ToolCall any
Error *OutputError
Usage *Usage
}
// Accumulator wraps a shared.StreamAccumulator and exposes it through the
// Apply and Snapshot methods, both of which tolerate a nil receiver.
type Accumulator struct {
inner shared.StreamAccumulator
}
// AccumulatorOptions carries the flags that NewAccumulator copies verbatim
// onto the same-named fields of the wrapped shared.StreamAccumulator.
type AccumulatorOptions struct {
ThinkingEnabled bool
SearchEnabled bool
StripReferenceMarkers bool
}
// NewAccumulator returns an Accumulator whose wrapped shared.StreamAccumulator
// has its ThinkingEnabled, SearchEnabled and StripReferenceMarkers fields set
// from opts; every other field keeps its zero value.
func NewAccumulator(opts AccumulatorOptions) *Accumulator {
	acc := new(Accumulator)
	acc.inner.ThinkingEnabled = opts.ThinkingEnabled
	acc.inner.SearchEnabled = opts.SearchEnabled
	acc.inner.StripReferenceMarkers = opts.StripReferenceMarkers
	return acc
}
// Apply forwards one parsed SSE line to the wrapped accumulator and returns
// its result. A nil receiver yields the zero shared.StreamAccumulatorResult.
func (a *Accumulator) Apply(parsed sse.LineResult) shared.StreamAccumulatorResult {
	var empty shared.StreamAccumulatorResult
	if a == nil {
		return empty
	}
	return a.inner.Apply(parsed)
}
// Snapshot returns the current string contents of the wrapped accumulator's
// RawText, Text, RawThinking, Thinking and ToolDetectionThinking buffers, in
// that order. A nil receiver yields five empty strings.
func (a *Accumulator) Snapshot() (rawText, text, rawThinking, thinking, detectionThinking string) {
	if a != nil {
		src := &a.inner
		rawText = src.RawText.String()
		text = src.Text.String()
		rawThinking = src.RawThinking.String()
		thinking = src.Thinking.String()
		detectionThinking = src.ToolDetectionThinking.String()
	}
	return rawText, text, rawThinking, thinking, detectionThinking
}

View File

@@ -0,0 +1,227 @@
package assistantturn
import (
"net/http"
"strings"
"ds2api/internal/httpapi/openai/shared"
"ds2api/internal/promptcompat"
"ds2api/internal/sse"
"ds2api/internal/toolcall"
"ds2api/internal/util"
)
// StopReason records why an assistant turn ended.
type StopReason string
// Stop reasons assigned by the turn builders in this package. FinishReason
// maps StopReasonError (and any unknown value) to "stop".
const (
StopReasonStop StopReason = "stop"
StopReasonToolCalls StopReason = "tool_calls"
StopReasonContentFilter StopReason = "content_filter"
StopReasonError StopReason = "error"
)
// Usage holds token counts for a turn. As populated by BuildUsage,
// OutputTokens already includes ReasoningTokens and TotalTokens is
// InputTokens + OutputTokens.
type Usage struct {
InputTokens int
OutputTokens int
ReasoningTokens int
TotalTokens int
}
// OutputError describes a failed turn. In this package Status is filled with
// net/http status constants, Message is a human-readable description and Code
// is a short machine-readable identifier such as "tool_choice_violation".
type OutputError struct {
Status int
Message string
Code string
}
// Turn is the assembled result of one assistant response, as produced by
// BuildTurnFromCollected and BuildTurnFromStreamSnapshot.
type Turn struct {
// Model and Prompt are copied from BuildOptions.
Model string
Prompt string
// RawText, RawThinking and DetectionThinking are the inputs as received,
// before visible-output cleaning.
RawText string
RawThinking string
DetectionThinking string
// Text and Thinking are the cleaned, client-visible strings; Text has
// citation markers replaced by links when search is enabled.
Text string
Thinking string
// ToolCalls are the schema-normalized calls; ParsedToolCalls is the full
// detection result with its Calls replaced by the same normalized slice.
ToolCalls []toolcall.ParsedToolCall
ParsedToolCalls toolcall.ToolCallParseResult
CitationLinks map[int]string
ContentFilter bool
ResponseMessageID int
StopReason StopReason
Usage Usage
// Error is non-nil when ValidateTurn rejected the turn.
Error *OutputError
}
// BuildOptions parameterizes the turn builders.
type BuildOptions struct {
// Model and Prompt are stored on the Turn and used for token counting.
Model string
Prompt string
// RefFileTokens is added to the counted prompt tokens by BuildUsage.
RefFileTokens int
// SearchEnabled turns on citation-marker-to-link replacement in the text.
SearchEnabled bool
// StripReferenceMarkers is forwarded to shared.CleanVisibleOutput.
StripReferenceMarkers bool
// ToolNames is forwarded to shared.DetectAssistantToolCalls.
ToolNames []string
// ToolsRaw is forwarded to toolcall.NormalizeParsedToolCallsForSchemas.
ToolsRaw any
// ToolChoice is the policy checked by ValidateTurn.
ToolChoice promptcompat.ToolChoicePolicy
}
// StreamSnapshot is the input to BuildTurnFromStreamSnapshot: the raw and
// visible buffers of a stream plus flags describing what was already emitted.
type StreamSnapshot struct {
RawText string
VisibleText string
RawThinking string
VisibleThinking string
DetectionThinking string
ContentFilter bool
CitationLinks map[int]string
ResponseMessageID int
// AlreadyEmittedCalls or AlreadyEmittedToolRaw each force the stop reason to
// StopReasonToolCalls and cause turn validation to be skipped.
AlreadyEmittedCalls bool
// AdditionalToolCalls is used only when detection finds no calls itself.
AdditionalToolCalls []toolcall.ParsedToolCall
AlreadyEmittedToolRaw bool
}
// BuildTurnFromCollected assembles a Turn from a fully collected response.
//
// The visible text and thinking are cleaned with shared.CleanVisibleOutput;
// when opts.SearchEnabled is set, citation markers in the text are replaced by
// links. Tool calls are detected and normalized against opts.ToolsRaw. The
// stop reason prefers tool calls over content filtering over a plain stop, and
// becomes StopReasonError when ValidateTurn rejects the turn.
func BuildTurnFromCollected(result sse.CollectResult, opts BuildOptions) Turn {
	visibleThinking := shared.CleanVisibleOutput(result.Thinking, opts.StripReferenceMarkers)
	visibleText := shared.CleanVisibleOutput(result.Text, opts.StripReferenceMarkers)
	if opts.SearchEnabled {
		visibleText = shared.ReplaceCitationMarkersWithLinks(visibleText, result.CitationLinks)
	}

	detection := shared.DetectAssistantToolCalls(result.Text, visibleText, result.Thinking, result.ToolDetectionThinking, opts.ToolNames)
	detection.Calls = toolcall.NormalizeParsedToolCallsForSchemas(detection.Calls, opts.ToolsRaw)

	var reason StopReason
	switch {
	case len(detection.Calls) > 0:
		reason = StopReasonToolCalls
	case result.ContentFilter:
		reason = StopReasonContentFilter
	default:
		reason = StopReasonStop
	}

	turn := Turn{
		Model:             opts.Model,
		Prompt:            opts.Prompt,
		RawText:           result.Text,
		RawThinking:       result.Thinking,
		DetectionThinking: result.ToolDetectionThinking,
		Text:              visibleText,
		Thinking:          visibleThinking,
		ToolCalls:         detection.Calls,
		ParsedToolCalls:   detection,
		CitationLinks:     result.CitationLinks,
		ContentFilter:     result.ContentFilter,
		ResponseMessageID: result.ResponseMessageID,
		StopReason:        reason,
	}
	turn.Usage = BuildUsage(opts.Model, opts.Prompt, visibleThinking, visibleText, opts.RefFileTokens)
	if turn.Error = ValidateTurn(turn, opts.ToolChoice); turn.Error != nil {
		turn.StopReason = StopReasonError
	}
	return turn
}
// BuildTurnFromStreamSnapshot assembles a Turn from the buffers of a stream.
//
// Visible text and thinking come from the snapshot's Visible* fields, while
// tool detection also receives the Raw* fields. When detection finds nothing,
// snapshot.AdditionalToolCalls are used instead. If tool output was already
// emitted (AlreadyEmittedCalls or AlreadyEmittedToolRaw), the stop reason is
// StopReasonToolCalls and validation is skipped.
func BuildTurnFromStreamSnapshot(snapshot StreamSnapshot, opts BuildOptions) Turn {
	visibleThinking := shared.CleanVisibleOutput(snapshot.VisibleThinking, opts.StripReferenceMarkers)
	visibleText := shared.CleanVisibleOutput(snapshot.VisibleText, opts.StripReferenceMarkers)
	if opts.SearchEnabled {
		visibleText = shared.ReplaceCitationMarkersWithLinks(visibleText, snapshot.CitationLinks)
	}

	detection := shared.DetectAssistantToolCalls(snapshot.RawText, visibleText, snapshot.RawThinking, snapshot.DetectionThinking, opts.ToolNames)
	candidates := detection.Calls
	if len(candidates) == 0 && len(snapshot.AdditionalToolCalls) > 0 {
		candidates = snapshot.AdditionalToolCalls
	}
	candidates = toolcall.NormalizeParsedToolCallsForSchemas(candidates, opts.ToolsRaw)
	detection.Calls = candidates

	alreadyEmitted := snapshot.AlreadyEmittedCalls || snapshot.AlreadyEmittedToolRaw

	var reason StopReason
	switch {
	case len(candidates) > 0 || alreadyEmitted:
		reason = StopReasonToolCalls
	case snapshot.ContentFilter:
		reason = StopReasonContentFilter
	default:
		reason = StopReasonStop
	}

	turn := Turn{
		Model:             opts.Model,
		Prompt:            opts.Prompt,
		RawText:           snapshot.RawText,
		RawThinking:       snapshot.RawThinking,
		DetectionThinking: snapshot.DetectionThinking,
		Text:              visibleText,
		Thinking:          visibleThinking,
		ToolCalls:         candidates,
		ParsedToolCalls:   detection,
		CitationLinks:     snapshot.CitationLinks,
		ContentFilter:     snapshot.ContentFilter,
		ResponseMessageID: snapshot.ResponseMessageID,
		StopReason:        reason,
	}
	turn.Usage = BuildUsage(opts.Model, opts.Prompt, visibleThinking, visibleText, opts.RefFileTokens)
	if !alreadyEmitted {
		turn.Error = ValidateTurn(turn, opts.ToolChoice)
	}
	if turn.Error != nil && len(candidates) == 0 {
		turn.StopReason = StopReasonError
	}
	return turn
}
// BuildUsage counts tokens for a turn. Input tokens are the counted prompt
// tokens plus refFileTokens; output tokens are reasoning tokens (from
// thinking) plus the tokens of text; the total is input plus output.
func BuildUsage(model, prompt, thinking, text string, refFileTokens int) Usage {
	var usage Usage
	usage.InputTokens = util.CountPromptTokens(prompt, model) + refFileTokens
	usage.ReasoningTokens = util.CountOutputTokens(thinking, model)
	usage.OutputTokens = usage.ReasoningTokens + util.CountOutputTokens(text, model)
	usage.TotalTokens = usage.InputTokens + usage.OutputTokens
	return usage
}
// ValidateTurn returns nil for an acceptable turn and an *OutputError
// otherwise. A turn with no tool calls is rejected when the policy requires a
// tool call (422, "tool_choice_violation"). Otherwise a turn with tool calls
// or non-blank text is accepted, and anything else is reported using
// UpstreamEmptyOutputDetail.
func ValidateTurn(turn Turn, policy promptcompat.ToolChoicePolicy) *OutputError {
	hasCalls := len(turn.ToolCalls) > 0
	switch {
	case policy.IsRequired() && !hasCalls:
		return &OutputError{
			Status:  http.StatusUnprocessableEntity,
			Message: "tool_choice requires at least one valid tool call.",
			Code:    "tool_choice_violation",
		}
	case hasCalls:
		return nil
	case strings.TrimSpace(turn.Text) != "":
		return nil
	}
	status, message, code := UpstreamEmptyOutputDetail(turn.ContentFilter, turn.Text, turn.Thinking)
	return &OutputError{Status: status, Message: message, Code: code}
}
// UpstreamEmptyOutputDetail returns the HTTP status, message and error code
// used to report a turn with no visible output. A content-filtered turn maps
// to 400 "content_filter"; otherwise the result is 429 "upstream_empty_output"
// with a message that depends on whether thinking is non-blank.
func UpstreamEmptyOutputDetail(contentFilter bool, text, thinking string) (int, string, string) {
	_ = text // accepted for signature compatibility; not consulted
	switch {
	case contentFilter:
		return http.StatusBadRequest, "Upstream content filtered the response and returned no output.", "content_filter"
	case strings.TrimSpace(thinking) != "":
		return http.StatusTooManyRequests, "Upstream account hit a rate limit and returned reasoning without visible output.", "upstream_empty_output"
	default:
		return http.StatusTooManyRequests, "Upstream account hit a rate limit and returned empty output.", "upstream_empty_output"
	}
}
// ShouldRetryEmptyOutput reports whether another attempt is allowed for a turn
// that produced nothing: attempts must be below maxAttempts, the turn must not
// be content-filtered, and it must have no tool calls, no non-blank text and
// no non-blank thinking.
func ShouldRetryEmptyOutput(turn Turn, attempts, maxAttempts int) bool {
	if attempts >= maxAttempts {
		return false
	}
	if turn.ContentFilter {
		return false
	}
	if len(turn.ToolCalls) != 0 {
		return false
	}
	if strings.TrimSpace(turn.Text) != "" {
		return false
	}
	return strings.TrimSpace(turn.Thinking) == ""
}
// FinishReason maps a turn's StopReason to a finish-reason string:
// "tool_calls" and "content_filter" map to themselves, and every other value
// (including StopReasonError) maps to "stop".
func FinishReason(turn Turn) string {
	if turn.StopReason == StopReasonToolCalls {
		return "tool_calls"
	}
	if turn.StopReason == StopReasonContentFilter {
		return "content_filter"
	}
	return "stop"
}

View File

@@ -0,0 +1,100 @@
package assistantturn
import (
"testing"
"ds2api/internal/promptcompat"
"ds2api/internal/sse"
)
func TestBuildTurnFromCollectedTextCitation(t *testing.T) {
turn := BuildTurnFromCollected(sse.CollectResult{
Text: "See [citation:1]",
CitationLinks: map[int]string{1: "https://example.com"},
}, BuildOptions{Model: "deepseek-v4-flash", Prompt: "prompt", SearchEnabled: true, StripReferenceMarkers: true})
if turn.Text != "See [1](https://example.com)" {
t.Fatalf("text mismatch: %q", turn.Text)
}
if turn.StopReason != StopReasonStop {
t.Fatalf("stop reason mismatch: %q", turn.StopReason)
}
if turn.Error != nil {
t.Fatalf("unexpected error: %#v", turn.Error)
}
}
func TestBuildTurnFromCollectedToolCall(t *testing.T) {
turn := BuildTurnFromCollected(sse.CollectResult{
Text: `<tool_calls><invoke name="Write"><parameter name="content">{"x":1}</parameter></invoke></tool_calls>`,
}, BuildOptions{
ToolNames: []string{"Write"},
ToolsRaw: []any{map[string]any{
"name": "Write",
"input_schema": map[string]any{
"type": "object",
"properties": map[string]any{
"content": map[string]any{"type": "string"},
},
},
}},
})
if len(turn.ToolCalls) != 1 {
t.Fatalf("expected one tool call, got %d", len(turn.ToolCalls))
}
if turn.StopReason != StopReasonToolCalls {
t.Fatalf("stop reason mismatch: %q", turn.StopReason)
}
if _, ok := turn.ToolCalls[0].Input["content"].(string); !ok {
t.Fatalf("expected content coerced to string, got %#v", turn.ToolCalls[0].Input["content"])
}
}
func TestBuildTurnFromCollectedThinkingOnlyIsEmptyOutput(t *testing.T) {
turn := BuildTurnFromCollected(sse.CollectResult{Thinking: "hidden"}, BuildOptions{})
if turn.Error == nil || turn.Error.Code != "upstream_empty_output" {
t.Fatalf("expected empty output error, got %#v", turn.Error)
}
}
func TestBuildTurnFromCollectedToolChoiceRequired(t *testing.T) {
turn := BuildTurnFromCollected(sse.CollectResult{Text: "hello"}, BuildOptions{
ToolChoice: promptcompat.ToolChoicePolicy{Mode: promptcompat.ToolChoiceRequired},
})
if turn.Error == nil || turn.Error.Code != "tool_choice_violation" {
t.Fatalf("expected tool choice violation, got %#v", turn.Error)
}
}
func TestBuildTurnFromStreamSnapshotUsesVisibleTextAndRawToolDetection(t *testing.T) {
turn := BuildTurnFromStreamSnapshot(StreamSnapshot{
RawText: `<tool_calls><invoke name="Write"><parameter name="content">{"x":1}</parameter></invoke></tool_calls>`,
VisibleText: "",
}, BuildOptions{
ToolNames: []string{"Write"},
ToolsRaw: []any{map[string]any{
"name": "Write",
"schema": map[string]any{
"type": "object",
"properties": map[string]any{
"content": map[string]any{"type": "string"},
},
},
}},
})
if len(turn.ToolCalls) != 1 {
t.Fatalf("expected stream snapshot tool call, got %d", len(turn.ToolCalls))
}
if _, ok := turn.ToolCalls[0].Input["content"].(string); !ok {
t.Fatalf("expected stream snapshot schema coercion, got %#v", turn.ToolCalls[0].Input["content"])
}
}
func TestBuildTurnFromStreamSnapshotAlreadyEmittedToolAvoidsEmptyError(t *testing.T) {
turn := BuildTurnFromStreamSnapshot(StreamSnapshot{AlreadyEmittedCalls: true}, BuildOptions{})
if turn.Error != nil {
t.Fatalf("unexpected empty-output error after emitted tool call: %#v", turn.Error)
}
if turn.StopReason != StopReasonToolCalls {
t.Fatalf("stop reason mismatch: %q", turn.StopReason)
}
}

View File

@@ -0,0 +1,170 @@
package completionruntime
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"ds2api/internal/assistantturn"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/httpapi/openai/shared"
"ds2api/internal/promptcompat"
"ds2api/internal/sse"
)
// DeepSeekCaller is the upstream client surface this package depends on:
// creating a session, obtaining a PoW response, and issuing a completion call
// with a payload and that PoW response. Each method takes a maxAttempts value.
// NOTE(review): the retry semantics of maxAttempts are defined by the
// implementation and are not visible here — confirm.
type DeepSeekCaller interface {
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
}
// Options tunes StartCompletion and ExecuteNonStreamWithRetry.
type Options struct {
// StripReferenceMarkers is forwarded into the turn build options.
StripReferenceMarkers bool
// MaxAttempts is passed to the DeepSeekCaller methods; values <= 0 mean 3.
MaxAttempts int
// RetryEnabled allows ExecuteNonStreamWithRetry to retry empty output.
RetryEnabled bool
// RetryMaxAttempts bounds empty-output retries; values <= 0 fall back to
// shared.EmptyOutputRetryMaxAttempts().
RetryMaxAttempts int
}
// NonStreamResult is the outcome of ExecuteNonStreamWithRetry. Payload is the
// payload of the initial completion call, and Attempts is the number of
// empty-output retries that were started.
type NonStreamResult struct {
SessionID string
Payload map[string]any
Turn assistantturn.Turn
Attempts int
}
// StartResult is the outcome of StartCompletion. On failure only the fields
// obtained before the failing step are populated; Response is set only on
// success.
type StartResult struct {
SessionID string
Payload map[string]any
Pow string
Response *http.Response
}
// StartCompletion creates a session, obtains a PoW response and issues the
// completion call for stdReq, in that order.
//
// On success the returned StartResult carries the session ID, payload, PoW and
// the upstream response (which the caller must consume and close). On failure
// it returns the fields gathered so far together with an OutputError: an auth
// error for session creation, 401 for PoW, and 500 for the completion call.
// opts.MaxAttempts <= 0 is treated as 3.
func StartCompletion(ctx context.Context, ds DeepSeekCaller, a *auth.RequestAuth, stdReq promptcompat.StandardRequest, opts Options) (StartResult, *assistantturn.OutputError) {
	attemptLimit := opts.MaxAttempts
	if attemptLimit <= 0 {
		attemptLimit = 3
	}

	sid, err := ds.CreateSession(ctx, a, attemptLimit)
	if err != nil {
		return StartResult{}, authOutputError(a)
	}
	out := StartResult{SessionID: sid}

	powResp, err := ds.GetPow(ctx, a, attemptLimit)
	if err != nil {
		return out, &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: "Failed to get PoW (invalid token or unknown error).", Code: "error"}
	}
	out.Pow = powResp
	out.Payload = stdReq.CompletionPayload(sid)

	resp, err := ds.CallCompletion(ctx, a, out.Payload, out.Pow, attemptLimit)
	if err != nil {
		return out, &assistantturn.OutputError{Status: http.StatusInternalServerError, Message: "Failed to get completion.", Code: "error"}
	}
	out.Response = resp
	return out, nil
}
// ExecuteNonStreamWithRetry runs a non-streaming completion and, when
// opts.RetryEnabled is set, retries while the turn is completely empty.
//
// It starts the completion with StartCompletion, then loops: collect the
// response into a turn, merge thinking with what earlier attempts produced,
// rebuild the turn, and either return it (with turn.Error as the error) or
// issue a retry call whose payload is cloned from the original with the
// turn's ResponseMessageID. Attempts in the result counts started retries.
func ExecuteNonStreamWithRetry(ctx context.Context, ds DeepSeekCaller, a *auth.RequestAuth, stdReq promptcompat.StandardRequest, opts Options) (NonStreamResult, *assistantturn.OutputError) {
start, startErr := StartCompletion(ctx, ds, a, stdReq, opts)
if startErr != nil {
return NonStreamResult{SessionID: start.SessionID, Payload: start.Payload}, startErr
}
// Same default as StartCompletion; used for the retry PoW/completion calls.
maxAttempts := opts.MaxAttempts
if maxAttempts <= 0 {
maxAttempts = 3
}
sessionID := start.SessionID
payload := start.Payload
pow := start.Pow
attempts := 0
currentResp := start.Response
usagePrompt := stdReq.PromptTokenText
accumulatedThinking := ""
accumulatedRawThinking := ""
accumulatedToolDetectionThinking := ""
for {
// collectAttempt closes currentResp's body before returning.
turn, outErr := collectAttempt(currentResp, stdReq, usagePrompt, opts)
if outErr != nil {
return NonStreamResult{SessionID: sessionID, Payload: payload, Attempts: attempts}, outErr
}
// Append only the part of this attempt's thinking that does not overlap
// with what previous attempts already contributed.
accumulatedThinking += sse.TrimContinuationOverlap(accumulatedThinking, turn.Thinking)
accumulatedRawThinking += sse.TrimContinuationOverlap(accumulatedRawThinking, turn.RawThinking)
accumulatedToolDetectionThinking += sse.TrimContinuationOverlap(accumulatedToolDetectionThinking, turn.DetectionThinking)
turn.Thinking = accumulatedThinking
turn.RawThinking = accumulatedRawThinking
turn.DetectionThinking = accumulatedToolDetectionThinking
// Rebuild the turn so cleaning, tool detection, usage and validation all see
// the accumulated raw thinking.
// NOTE(review): the rebuild reads RawThinking and DetectionThinking but not
// turn.Thinking, so the turn.Thinking assignment above is overwritten here.
turn = assistantturn.BuildTurnFromCollected(sse.CollectResult{
Text: turn.RawText,
Thinking: turn.RawThinking,
ToolDetectionThinking: turn.DetectionThinking,
ContentFilter: turn.ContentFilter,
CitationLinks: turn.CitationLinks,
ResponseMessageID: turn.ResponseMessageID,
}, buildOptions(stdReq, usagePrompt, opts))
retryMax := opts.RetryMaxAttempts
if retryMax <= 0 {
retryMax = shared.EmptyOutputRetryMaxAttempts()
}
if !opts.RetryEnabled || !assistantturn.ShouldRetryEmptyOutput(turn, attempts, retryMax) {
return NonStreamResult{SessionID: sessionID, Payload: payload, Turn: turn, Attempts: attempts}, turn.Error
}
attempts++
config.Logger.Info("[completion_runtime_empty_retry] attempting synthetic retry", "surface", stdReq.Surface, "stream", false, "retry_attempt", attempts, "parent_message_id", turn.ResponseMessageID)
// A fresh PoW is preferred for the retry; on failure the original is reused.
retryPow, powErr := ds.GetPow(ctx, a, maxAttempts)
if powErr != nil {
config.Logger.Warn("[completion_runtime_empty_retry] retry PoW fetch failed, falling back to original PoW", "surface", stdReq.Surface, "retry_attempt", attempts, "error", powErr)
retryPow = pow
}
retryPayload := shared.ClonePayloadForEmptyOutputRetry(payload, turn.ResponseMessageID)
nextResp, err := ds.CallCompletion(ctx, a, retryPayload, retryPow, maxAttempts)
if err != nil {
return NonStreamResult{SessionID: sessionID, Payload: payload, Turn: turn, Attempts: attempts}, &assistantturn.OutputError{Status: http.StatusInternalServerError, Message: "Failed to get completion.", Code: "error"}
}
// The prompt text used for usage accounting is updated for the retry.
usagePrompt = shared.UsagePromptWithEmptyOutputRetry(usagePrompt, attempts)
currentResp = nextResp
}
}
// collectAttempt consumes one upstream completion response and turns it into
// an assistantturn.Turn.
//
// The response body is always closed before returning. A nil response (or one
// without a body) is reported as a 500 OutputError instead of panicking. A
// non-200 response is reported as an OutputError carrying the upstream status
// and the trimmed response body, or the standard status text when the body is
// blank. Otherwise the stream is collected with sse.CollectStream and built
// via assistantturn.BuildTurnFromCollected using usagePrompt for usage.
func collectAttempt(resp *http.Response, stdReq promptcompat.StandardRequest, usagePrompt string, opts Options) (assistantturn.Turn, *assistantturn.OutputError) {
	// Guard before installing the deferred Close, which dereferences resp.Body.
	if resp == nil || resp.Body == nil {
		return assistantturn.Turn{}, &assistantturn.OutputError{Status: http.StatusInternalServerError, Message: "Failed to get completion.", Code: "error"}
	}
	defer func() {
		if err := resp.Body.Close(); err != nil {
			config.Logger.Warn("[completion_runtime] response body close failed", "surface", stdReq.Surface, "error", err)
		}
	}()
	if resp.StatusCode != http.StatusOK {
		body, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			// Best effort: keep whatever was read, but do not hide the failure.
			config.Logger.Warn("[completion_runtime] error response body read failed", "surface", stdReq.Surface, "error", readErr)
		}
		message := strings.TrimSpace(string(body))
		if message == "" {
			message = http.StatusText(resp.StatusCode)
		}
		return assistantturn.Turn{}, &assistantturn.OutputError{Status: resp.StatusCode, Message: message, Code: "error"}
	}
	result := sse.CollectStream(resp, stdReq.Thinking, false)
	return assistantturn.BuildTurnFromCollected(result, buildOptions(stdReq, usagePrompt, opts)), nil
}
// buildOptions translates a standard request, the prompt text used for usage
// accounting, and the runtime options into assistantturn.BuildOptions.
func buildOptions(stdReq promptcompat.StandardRequest, prompt string, opts Options) assistantturn.BuildOptions {
	var out assistantturn.BuildOptions
	out.Model = stdReq.ResponseModel
	out.Prompt = prompt
	out.RefFileTokens = stdReq.RefFileTokens
	out.SearchEnabled = stdReq.Search
	out.StripReferenceMarkers = opts.StripReferenceMarkers
	out.ToolNames = stdReq.ToolNames
	out.ToolsRaw = stdReq.ToolsRaw
	out.ToolChoice = stdReq.ToolChoice
	return out
}
// authOutputError returns the 401 OutputError for a failed session creation.
// The message differs depending on whether the request auth uses a configured
// account token (a.UseConfigToken) or not (including a nil a).
func authOutputError(a *auth.RequestAuth) *assistantturn.OutputError {
	message := "Invalid token. If this should be a DS2API key, add it to config.keys first."
	if a != nil && a.UseConfigToken {
		message = "Account token is invalid. Please re-login the account in admin."
	}
	return &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: message, Code: "error"}
}
// Errorf builds an OutputError with the given status, a message formatted
// from format and args as by fmt.Sprintf, and the code "error".
func Errorf(status int, format string, args ...any) *assistantturn.OutputError {
	out := new(assistantturn.OutputError)
	out.Status = status
	out.Message = fmt.Sprintf(format, args...)
	out.Code = "error"
	return out
}

View File

@@ -0,0 +1,120 @@
package completionruntime
import (
"context"
"io"
"net/http"
"strings"
"testing"
"ds2api/internal/auth"
"ds2api/internal/promptcompat"
)
// fakeDeepSeekCaller is a test double for DeepSeekCaller. It serves queued
// responses in order and records every completion payload it receives.
type fakeDeepSeekCaller struct {
	responses []*http.Response
	payloads  []map[string]any
}

// CreateSession always succeeds with a fixed session ID.
func (f *fakeDeepSeekCaller) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "session-1", nil
}

// GetPow always succeeds with a fixed PoW response.
func (f *fakeDeepSeekCaller) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "pow", nil
}

// CallCompletion records payload and returns the next queued response, or a
// canned "fallback" content response once the queue is exhausted.
func (f *fakeDeepSeekCaller) CallCompletion(_ context.Context, _ *auth.RequestAuth, payload map[string]any, _ string, _ int) (*http.Response, error) {
	f.payloads = append(f.payloads, payload)
	if len(f.responses) > 0 {
		next := f.responses[0]
		f.responses = f.responses[1:]
		return next, nil
	}
	return sseHTTPResponse(http.StatusOK, `data: {"p":"response/content","v":"fallback"}`), nil
}
func TestExecuteNonStreamWithRetryBuildsCanonicalTurn(t *testing.T) {
ds := &fakeDeepSeekCaller{responses: []*http.Response{sseHTTPResponse(
http.StatusOK,
`data: {"response_message_id":42,"p":"response/content","v":"<tool_calls><invoke name=\"Write\"><parameter name=\"content\">{\"x\":1}</parameter></invoke></tool_calls>"}`,
)}}
stdReq := promptcompat.StandardRequest{
Surface: "test",
ResponseModel: "deepseek-v4-flash",
PromptTokenText: "prompt",
FinalPrompt: "final prompt",
ToolNames: []string{"Write"},
ToolsRaw: []any{map[string]any{
"name": "Write",
"input_schema": map[string]any{
"type": "object",
"properties": map[string]any{
"content": map[string]any{"type": "string"},
},
},
}},
}
result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, &auth.RequestAuth{}, stdReq, Options{})
if outErr != nil {
t.Fatalf("unexpected output error: %#v", outErr)
}
if result.SessionID != "session-1" {
t.Fatalf("session mismatch: %q", result.SessionID)
}
if got := result.Turn.ResponseMessageID; got != 42 {
t.Fatalf("response message id mismatch: %d", got)
}
if len(result.Turn.ToolCalls) != 1 {
t.Fatalf("expected one tool call, got %d", len(result.Turn.ToolCalls))
}
if _, ok := result.Turn.ToolCalls[0].Input["content"].(string); !ok {
t.Fatalf("expected schema-normalized string argument, got %#v", result.Turn.ToolCalls[0].Input["content"])
}
if result.Turn.Usage.InputTokens == 0 || result.Turn.Usage.TotalTokens == 0 {
t.Fatalf("expected usage to be populated, got %#v", result.Turn.Usage)
}
}
func TestExecuteNonStreamWithRetryUsesParentMessageForEmptyRetry(t *testing.T) {
ds := &fakeDeepSeekCaller{responses: []*http.Response{
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":77,"p":"response/status","v":"FINISHED"}`),
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":78,"p":"response/content","v":"ok"}`),
}}
stdReq := promptcompat.StandardRequest{
Surface: "test",
ResponseModel: "deepseek-v4-flash",
PromptTokenText: "prompt",
FinalPrompt: "final prompt",
}
result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, &auth.RequestAuth{}, stdReq, Options{RetryEnabled: true})
if outErr != nil {
t.Fatalf("unexpected output error: %#v", outErr)
}
if result.Attempts != 1 {
t.Fatalf("expected one retry, got %d", result.Attempts)
}
if len(ds.payloads) != 2 {
t.Fatalf("expected two completion calls, got %d", len(ds.payloads))
}
if got := ds.payloads[1]["parent_message_id"]; got != 77 {
t.Fatalf("retry parent_message_id mismatch: %#v", got)
}
if result.Turn.Text != "ok" {
t.Fatalf("retry text mismatch: %q", result.Turn.Text)
}
}
// sseHTTPResponse builds an in-memory *http.Response with the given status,
// an empty non-nil header, and a body made of lines joined by "\n" and
// guaranteed to end with a newline.
func sseHTTPResponse(status int, lines ...string) *http.Response {
	// Dropping one trailing newline (if any) and re-adding it leaves a body that
	// already ends in "\n" untouched and terminates one that does not.
	payload := strings.TrimSuffix(strings.Join(lines, "\n"), "\n") + "\n"

	resp := new(http.Response)
	resp.StatusCode = status
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(payload))
	return resp
}

View File

@@ -1,6 +1,7 @@
package claude
import (
"ds2api/internal/assistantturn"
"ds2api/internal/toolcall"
"fmt"
"time"
@@ -9,6 +10,47 @@ import (
"ds2api/internal/util"
)
// BuildMessageResponseFromTurn renders a turn as a Claude-style message map.
//
// Content blocks are, in order: an optional "thinking" block (only when
// exposeThinking is set and the turn has thinking), then either one
// "tool_use" block per tool call (stop_reason "tool_use") or a single "text"
// block (stop_reason "end_turn"). The text block falls back to the thinking
// when text is empty and exposeThinking is set, and finally to a fixed
// apology string. Usage reports the turn's input and output token counts.
func BuildMessageResponseFromTurn(messageID, model string, turn assistantturn.Turn, exposeThinking bool) map[string]any {
	blocks := make([]map[string]any, 0, 4)
	if exposeThinking && turn.Thinking != "" {
		blocks = append(blocks, map[string]any{"type": "thinking", "thinking": turn.Thinking})
	}

	stop := "end_turn"
	if len(turn.ToolCalls) == 0 {
		body := turn.Text
		if body == "" && exposeThinking {
			body = turn.Thinking
		}
		if body == "" {
			body = "抱歉,没有生成有效的响应内容。"
		}
		blocks = append(blocks, map[string]any{"type": "text", "text": body})
	} else {
		stop = "tool_use"
		for i := range turn.ToolCalls {
			call := turn.ToolCalls[i]
			blocks = append(blocks, map[string]any{
				"type":  "tool_use",
				"id":    fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i),
				"name":  call.Name,
				"input": call.Input,
			})
		}
	}

	usage := map[string]any{
		"input_tokens":  turn.Usage.InputTokens,
		"output_tokens": turn.Usage.OutputTokens,
	}
	return map[string]any{
		"id":            messageID,
		"type":          "message",
		"role":          "assistant",
		"model":         model,
		"content":       blocks,
		"stop_reason":   stop,
		"stop_sequence": nil,
		"usage":         usage,
	}
}
func BuildMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any {
detected := toolcall.ParseToolCalls(finalText, toolNames)
if len(detected) == 0 && finalText == "" && finalThinking != "" {

View File

@@ -4,13 +4,19 @@ import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"time"
"ds2api/internal/auth"
"ds2api/internal/completionruntime"
"ds2api/internal/config"
claudefmt "ds2api/internal/format/claude"
"ds2api/internal/httpapi/requestbody"
"ds2api/internal/promptcompat"
streamengine "ds2api/internal/stream"
"ds2api/internal/translatorcliproxy"
"ds2api/internal/util"
@@ -22,14 +28,90 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
r.Header.Set("anthropic-version", "2023-06-01")
}
if h.OpenAI == nil {
writeClaudeError(w, http.StatusInternalServerError, "OpenAI proxy backend unavailable.")
if isClaudeVercelProxyRequest(r) && h.proxyViaOpenAI(w, r, h.Store) {
return
}
if h.proxyViaOpenAI(w, r, h.Store) {
if h.Auth == nil || h.DS == nil {
if h.OpenAI != nil && h.proxyViaOpenAI(w, r, h.Store) {
return
}
writeClaudeError(w, http.StatusInternalServerError, "Claude runtime backend unavailable.")
return
}
writeClaudeError(w, http.StatusBadGateway, "Failed to proxy Claude request.")
if h.handleClaudeDirect(w, r) {
return
}
writeClaudeError(w, http.StatusBadGateway, "Failed to handle Claude request.")
}
// isClaudeVercelProxyRequest reports whether the request URL carries either
// the __stream_prepare=1 or the __stream_release=1 query flag (surrounding
// whitespace in the value is ignored). A nil request or nil URL yields false.
func isClaudeVercelProxyRequest(r *http.Request) bool {
	if r == nil || r.URL == nil {
		return false
	}
	// URL.Query re-parses the raw query on every call, so parse it once.
	query := r.URL.Query()
	for _, flag := range [...]string{"__stream_prepare", "__stream_release"} {
		if strings.TrimSpace(query.Get(flag)) == "1" {
			return true
		}
	}
	return false
}
// handleClaudeDirect serves a Claude messages request directly against the
// DeepSeek runtime and reports whether it wrote a response (every path in
// this function does, so it always returns true).
//
// Flow: read and decode the JSON body, resolve the thinking override,
// normalize the request, determine (and later release) the request auth, then
// either hand off to the streaming path or run the non-stream completion with
// empty-output retry enabled and write the Claude-formatted message.
func (h *Handler) handleClaudeDirect(w http.ResponseWriter, r *http.Request) bool {
	raw, err := io.ReadAll(r.Body)
	if err != nil {
		if errors.Is(err, requestbody.ErrInvalidUTF8Body) {
			writeClaudeError(w, http.StatusBadRequest, "invalid json")
		} else {
			writeClaudeError(w, http.StatusBadRequest, "invalid body")
		}
		return true
	}
	var req map[string]any
	if err := json.Unmarshal(raw, &req); err != nil {
		writeClaudeError(w, http.StatusBadRequest, "invalid json")
		return true
	}
	// A JSON body of `null` decodes without error but leaves req nil; writing
	// req["thinking"] below would then panic. Treat it like an empty object so
	// normalization decides how to reject it.
	if req == nil {
		req = map[string]any{}
	}

	// Resolve the override once: thinking is exposed only when the client
	// explicitly enabled it; when the client said nothing and the request is
	// non-streaming, thinking is switched on upstream but not exposed.
	exposeThinking := false
	enabled, explicit := util.ResolveThinkingOverride(req)
	switch {
	case explicit && enabled:
		exposeThinking = true
	case !explicit && !util.ToBool(req["stream"]):
		req["thinking"] = map[string]any{"type": "enabled"}
	}

	norm, err := normalizeClaudeRequest(h.Store, req)
	if err != nil {
		writeClaudeError(w, http.StatusBadRequest, err.Error())
		return true
	}
	a, err := h.Auth.Determine(r)
	if err != nil {
		writeClaudeError(w, http.StatusUnauthorized, err.Error())
		return true
	}
	defer h.Auth.Release(a)

	if norm.Standard.Stream {
		h.handleClaudeDirectStream(w, r, a, norm.Standard)
		return true
	}
	result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, norm.Standard, completionruntime.Options{
		StripReferenceMarkers: h.compatStripReferenceMarkers(),
		RetryEnabled:          true,
	})
	if outErr != nil {
		writeClaudeError(w, outErr.Status, outErr.Message)
		return true
	}
	writeJSON(w, http.StatusOK, claudefmt.BuildMessageResponseFromTurn(
		fmt.Sprintf("msg_%d", time.Now().UnixNano()),
		norm.Standard.ResponseModel,
		result.Turn,
		exposeThinking,
	))
	return true
}
// handleClaudeDirectStream starts the upstream completion for a streaming
// Claude request and passes the upstream response to the realtime stream
// handler. A start failure is written to the client as a Claude error.
func (h *Handler) handleClaudeDirectStream(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, stdReq promptcompat.StandardRequest) {
	started, startErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
	if startErr != nil {
		writeClaudeError(w, startErr.Status, startErr.Message)
		return
	}
	upstream := started.Response
	h.handleClaudeStreamRealtime(w, r, upstream, stdReq.ResponseModel, stdReq.Messages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw)
}
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store ConfigReader) bool {

View File

@@ -1,6 +1,7 @@
package claude
import (
"ds2api/internal/assistantturn"
"ds2api/internal/sse"
"ds2api/internal/toolcall"
"ds2api/internal/toolstream"
@@ -9,7 +10,6 @@ import (
"time"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
)
func (s *claudeStreamRuntime) closeThinkingBlock() {
@@ -115,18 +115,28 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
s.closeTextBlock()
finalThinking := s.thinking.String()
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
RawText: s.rawText.String(),
VisibleText: s.text.String(),
RawThinking: s.rawThinking.String(),
VisibleThinking: s.thinking.String(),
DetectionThinking: s.toolDetectionThinking.String(),
AlreadyEmittedCalls: s.toolCallsDetected,
AlreadyEmittedToolRaw: s.toolCallsDetected,
}, assistantturn.BuildOptions{
Model: s.model,
Prompt: s.promptTokenText,
SearchEnabled: s.searchEnabled,
StripReferenceMarkers: s.stripReferenceMarkers,
ToolNames: s.toolNames,
ToolsRaw: s.toolsRaw,
})
finalText := turn.Text
if s.bufferToolContent && !s.toolCallsDetected {
detected := toolcall.ParseStandaloneToolCallsDetailed(s.rawText.String(), s.toolNames)
if len(detected.Calls) == 0 {
detected = toolcall.ParseStandaloneToolCallsDetailed(s.rawThinking.String(), s.toolNames)
}
if len(detected.Calls) > 0 {
normalized := toolcall.NormalizeParsedToolCallsForSchemas(detected.Calls, s.toolsRaw)
if len(turn.ToolCalls) > 0 {
stopReason = "tool_use"
for _, tc := range normalized {
for _, tc := range turn.ToolCalls {
idx := s.nextBlockIndex
s.nextBlockIndex++
s.sendToolUseBlock(idx, tc)
@@ -161,7 +171,6 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
stopReason = "tool_use"
}
outputTokens := util.CountOutputTokens(finalThinking, s.model) + util.CountOutputTokens(finalText, s.model)
s.send("message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{
@@ -169,7 +178,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
"stop_sequence": nil,
},
"usage": map[string]any{
"output_tokens": outputTokens,
"output_tokens": turn.Usage.OutputTokens,
},
})
s.send("message_stop", map[string]any{"type": "message_stop"})

View File

@@ -33,6 +33,9 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin
toolsRaw := convertGeminiTools(req["tools"])
finalPrompt, toolNames := promptcompat.BuildOpenAIPromptForAdapter(messagesRaw, toolsRaw, "", thinkingEnabled)
if len(toolNames) == 0 && len(toolsRaw) > 0 {
toolNames = []string{"__any_tool__"}
}
passThrough := collectGeminiPassThrough(req)
return promptcompat.StandardRequest{
@@ -42,6 +45,7 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin
ResponseModel: requestedModel,
Messages: messagesRaw,
PromptTokenText: finalPrompt,
ToolsRaw: toolsRaw,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: stream,

View File

@@ -11,7 +11,11 @@ import (
"github.com/go-chi/chi/v5"
"ds2api/internal/assistantturn"
"ds2api/internal/auth"
"ds2api/internal/completionruntime"
"ds2api/internal/httpapi/requestbody"
"ds2api/internal/promptcompat"
"ds2api/internal/sse"
"ds2api/internal/toolcall"
"ds2api/internal/translatorcliproxy"
@@ -21,14 +25,80 @@ import (
)
func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
if h.OpenAI == nil {
writeGeminiError(w, http.StatusInternalServerError, "OpenAI proxy backend unavailable.")
if isGeminiVercelProxyRequest(r) && h.proxyViaOpenAI(w, r, stream) {
return
}
if h.proxyViaOpenAI(w, r, stream) {
if h.Auth == nil || h.DS == nil {
if h.OpenAI != nil && h.proxyViaOpenAI(w, r, stream) {
return
}
writeGeminiError(w, http.StatusInternalServerError, "Gemini runtime backend unavailable.")
return
}
writeGeminiError(w, http.StatusBadGateway, "Failed to proxy Gemini request.")
if h.handleGeminiDirect(w, r, stream) {
return
}
writeGeminiError(w, http.StatusBadGateway, "Failed to handle Gemini request.")
}
// isGeminiVercelProxyRequest reports whether the request URL carries either
// the __stream_prepare=1 or the __stream_release=1 query flag (surrounding
// whitespace in the value is ignored). A nil request or nil URL yields false.
func isGeminiVercelProxyRequest(r *http.Request) bool {
	if r == nil || r.URL == nil {
		return false
	}
	// URL.Query re-parses the raw query on every call, so parse it once.
	query := r.URL.Query()
	for _, flag := range [...]string{"__stream_prepare", "__stream_release"} {
		if strings.TrimSpace(query.Get(flag)) == "1" {
			return true
		}
	}
	return false
}
// handleGeminiDirect serves a Gemini generateContent request directly against
// the DeepSeek runtime and reports whether it wrote a response (every path in
// this function does, so it always returns true).
//
// Flow: read and decode the JSON body, normalize it together with the route's
// model parameter, determine (and later release) the request auth, then either
// hand off to the streaming path or run the non-stream completion with
// empty-output retry enabled and write the Gemini-formatted response.
func (h *Handler) handleGeminiDirect(w http.ResponseWriter, r *http.Request, stream bool) bool {
	body, readErr := io.ReadAll(r.Body)
	if readErr != nil {
		reason := "invalid body"
		if errors.Is(readErr, requestbody.ErrInvalidUTF8Body) {
			reason = "invalid json"
		}
		writeGeminiError(w, http.StatusBadRequest, reason)
		return true
	}

	routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
	var req map[string]any
	if jsonErr := json.Unmarshal(body, &req); jsonErr != nil {
		writeGeminiError(w, http.StatusBadRequest, "invalid json")
		return true
	}

	stdReq, normErr := normalizeGeminiRequest(h.Store, routeModel, req, stream)
	if normErr != nil {
		writeGeminiError(w, http.StatusBadRequest, normErr.Error())
		return true
	}

	a, authErr := h.Auth.Determine(r)
	if authErr != nil {
		writeGeminiError(w, http.StatusUnauthorized, authErr.Error())
		return true
	}
	defer h.Auth.Release(a)

	if stream {
		h.handleGeminiDirectStream(w, r, a, stdReq)
		return true
	}

	runtimeOpts := completionruntime.Options{
		StripReferenceMarkers: h.compatStripReferenceMarkers(),
		RetryEnabled:          true,
	}
	result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, runtimeOpts)
	if outErr != nil {
		writeGeminiError(w, outErr.Status, outErr.Message)
		return true
	}
	writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponseFromTurn(result.Turn))
	return true
}
// handleGeminiDirectStream starts an upstream completion for stdReq and, when
// that succeeds, passes the upstream response to handleStreamGenerateContent
// together with the model, prompt text, and tool settings taken from stdReq.
// If the completion cannot be started, the error's status and message are
// written as a Gemini error and nothing is streamed.
func (h *Handler) handleGeminiDirectStream(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, stdReq promptcompat.StandardRequest) {
	started, startErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
	if startErr != nil {
		writeGeminiError(w, startErr.Status, startErr.Message)
		return
	}
	h.handleStreamGenerateContent(
		w,
		r,
		started.Response,
		stdReq.ResponseModel,
		stdReq.PromptTokenText,
		stdReq.Thinking,
		stdReq.Search,
		stdReq.ToolNames,
		stdReq.ToolsRaw,
	)
}
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream bool) bool {
@@ -250,6 +320,48 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final
}
}
// buildGeminiGenerateContentResponseFromTurn renders an assistant turn as a
// Gemini generate-content response map: a single candidate at index 0 with
// role "model" and finishReason "STOP", the turn's model as modelVersion, and
// usageMetadata populated from the turn's token counts.
func buildGeminiGenerateContentResponseFromTurn(turn assistantturn.Turn) map[string]any {
	candidate := map[string]any{
		"index": 0,
		"content": map[string]any{
			"role":  "model",
			"parts": buildGeminiPartsFromTurn(turn),
		},
		"finishReason": "STOP",
	}
	usage := map[string]any{
		"promptTokenCount":     turn.Usage.InputTokens,
		"candidatesTokenCount": turn.Usage.OutputTokens,
		"totalTokenCount":      turn.Usage.TotalTokens,
	}
	return map[string]any{
		"candidates":    []map[string]any{candidate},
		"modelVersion":  turn.Model,
		"usageMetadata": usage,
	}
}
// buildGeminiPartsFromTurn converts an assistant turn into Gemini content
// parts. When the turn has tool calls, it returns one "functionCall" part per
// call (name and args) and omits any text. Otherwise it returns a single
// "text" part holding the turn's text, or its thinking when the text is empty.
func buildGeminiPartsFromTurn(turn assistantturn.Turn) []map[string]any {
	if len(turn.ToolCalls) == 0 {
		body := turn.Text
		if body == "" {
			body = turn.Thinking
		}
		return []map[string]any{{"text": body}}
	}

	calls := make([]map[string]any, len(turn.ToolCalls))
	for i, call := range turn.ToolCalls {
		calls[i] = map[string]any{
			"functionCall": map[string]any{
				"name": call.Name,
				"args": call.Input,
			},
		}
	}
	return calls
}
//nolint:unused // retained for native Gemini non-stream handling path.
func buildGeminiUsage(model, finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.CountPromptTokens(finalPrompt, model)

View File

@@ -7,13 +7,14 @@ import (
"strings"
"time"
"ds2api/internal/assistantturn"
dsprotocol "ds2api/internal/deepseek/protocol"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
)
//nolint:unused // retained for native Gemini stream handling path.
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -28,7 +29,7 @@ func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Req
rc := http.NewResponseController(w)
_, canFlush := w.(http.Flusher)
runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames)
runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw)
initialType := "text"
if thinkingEnabled {
@@ -64,9 +65,11 @@ type geminiStreamRuntime struct {
bufferContent bool
stripReferenceMarkers bool
toolNames []string
toolsRaw any
thinking strings.Builder
text strings.Builder
accumulator *assistantturn.Accumulator
contentFilter bool
responseMessageID int
}
//nolint:unused // retained for native Gemini stream handling path.
@@ -80,6 +83,7 @@ func newGeminiStreamRuntime(
searchEnabled bool,
stripReferenceMarkers bool,
toolNames []string,
toolsRaw any,
) *geminiStreamRuntime {
return &geminiStreamRuntime{
w: w,
@@ -92,6 +96,12 @@ func newGeminiStreamRuntime(
bufferContent: len(toolNames) > 0,
stripReferenceMarkers: stripReferenceMarkers,
toolNames: toolNames,
toolsRaw: toolsRaw,
accumulator: assistantturn.NewAccumulator(assistantturn.AccumulatorOptions{
ThinkingEnabled: thinkingEnabled,
SearchEnabled: searchEnabled,
StripReferenceMarkers: stripReferenceMarkers,
}),
}
}
@@ -111,32 +121,24 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
if !parsed.Parsed {
return streamengine.ParsedDecision{}
}
if parsed.ResponseMessageID > 0 {
s.responseMessageID = parsed.ResponseMessageID
}
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
if parsed.ContentFilter {
s.contentFilter = true
}
return streamengine.ParsedDecision{Stop: true}
}
contentSeen := false
for _, p := range parsed.Parts {
cleanedText := cleanVisibleOutput(p.Text, s.stripReferenceMarkers)
if cleanedText == "" {
continue
}
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(cleanedText) {
continue
}
contentSeen = true
accumulated := s.accumulator.Apply(parsed)
for _, p := range accumulated.Parts {
if p.Type == "thinking" {
if s.thinkingEnabled {
if cleanedText != "" {
s.thinking.WriteString(cleanedText)
}
}
continue
}
if cleanedText == "" {
if p.RawText == "" || p.CitationOnly || p.VisibleText == "" {
continue
}
s.text.WriteString(cleanedText)
if s.bufferContent {
continue
}
@@ -146,23 +148,38 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
"index": 0,
"content": map[string]any{
"role": "model",
"parts": []map[string]any{{"text": cleanedText}},
"parts": []map[string]any{{"text": p.VisibleText}},
},
},
},
"modelVersion": s.model,
})
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}
return streamengine.ParsedDecision{ContentSeen: accumulated.ContentSeen}
}
//nolint:unused // retained for native Gemini stream handling path.
func (s *geminiStreamRuntime) finalize() {
finalThinking := s.thinking.String()
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
rawText, text, rawThinking, thinking, detectionThinking := s.accumulator.Snapshot()
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
RawText: rawText,
VisibleText: text,
RawThinking: rawThinking,
VisibleThinking: thinking,
DetectionThinking: detectionThinking,
ContentFilter: s.contentFilter,
ResponseMessageID: s.responseMessageID,
}, assistantturn.BuildOptions{
Model: s.model,
Prompt: s.finalPrompt,
SearchEnabled: s.searchEnabled,
StripReferenceMarkers: s.stripReferenceMarkers,
ToolNames: s.toolNames,
ToolsRaw: s.toolsRaw,
})
if s.bufferContent {
parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
parts := buildGeminiPartsFromTurn(turn)
s.sendChunk(map[string]any{
"candidates": []map[string]any{
{
@@ -190,7 +207,11 @@ func (s *geminiStreamRuntime) finalize() {
"finishReason": "STOP",
},
},
"modelVersion": s.model,
"usageMetadata": buildGeminiUsage(s.model, s.finalPrompt, finalThinking, finalText),
"modelVersion": s.model,
"usageMetadata": map[string]any{
"promptTokenCount": turn.Usage.InputTokens,
"candidatesTokenCount": turn.Usage.OutputTokens,
"totalTokenCount": turn.Usage.TotalTokens,
},
})
}

View File

@@ -5,8 +5,10 @@ import (
"net/http"
"strings"
"ds2api/internal/assistantturn"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/httpapi/openai/shared"
"ds2api/internal/promptcompat"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/toolstream"
@@ -24,6 +26,7 @@ type chatStreamRuntime struct {
refFileTokens int
toolNames []string
toolsRaw any
toolChoice promptcompat.ToolChoicePolicy
thinkingEnabled bool
searchEnabled bool
@@ -89,6 +92,7 @@ func newChatStreamRuntime(
stripReferenceMarkers bool,
toolNames []string,
toolsRaw any,
toolChoice promptcompat.ToolChoicePolicy,
bufferToolContent bool,
emitEarlyToolDeltas bool,
) *chatStreamRuntime {
@@ -102,6 +106,7 @@ func newChatStreamRuntime(
finalPrompt: finalPrompt,
toolNames: toolNames,
toolsRaw: toolsRaw,
toolChoice: toolChoice,
thinkingEnabled: thinkingEnabled,
searchEnabled: searchEnabled,
stripReferenceMarkers: stripReferenceMarkers,
@@ -201,14 +206,33 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
s.finalErrorCode = ""
finalThinking := s.accumulator.Thinking.String()
finalToolDetectionThinking := s.accumulator.ToolDetectionThinking.String()
finalText := cleanVisibleOutput(s.accumulator.Text.String(), s.stripReferenceMarkers)
s.finalThinking = finalThinking
s.finalText = finalText
detected := detectAssistantToolCalls(s.accumulator.RawText.String(), finalText, s.accumulator.RawThinking.String(), finalToolDetectionThinking, s.toolNames)
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
finalText := s.accumulator.Text.String()
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
RawText: s.accumulator.RawText.String(),
VisibleText: finalText,
RawThinking: s.accumulator.RawThinking.String(),
VisibleThinking: finalThinking,
DetectionThinking: finalToolDetectionThinking,
ContentFilter: finishReason == "content_filter",
ResponseMessageID: s.responseMessageID,
AlreadyEmittedCalls: s.toolCallsEmitted,
AlreadyEmittedToolRaw: s.toolCallsDoneEmitted,
}, assistantturn.BuildOptions{
Model: s.model,
Prompt: s.finalPrompt,
RefFileTokens: s.refFileTokens,
SearchEnabled: s.searchEnabled,
StripReferenceMarkers: s.stripReferenceMarkers,
ToolNames: s.toolNames,
ToolsRaw: s.toolsRaw,
ToolChoice: s.toolChoice,
})
s.finalThinking = turn.Thinking
s.finalText = turn.Text
if len(turn.ToolCalls) > 0 && !s.toolCallsDoneEmitted {
finishReason = "tool_calls"
s.sendDelta(map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs, s.toolsRaw),
"tool_calls": formatFinalStreamToolCallsWithStableIDs(turn.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
})
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
@@ -237,11 +261,14 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
batch.flush()
}
if len(detected.Calls) > 0 || s.toolCallsEmitted {
if len(turn.ToolCalls) > 0 || s.toolCallsEmitted {
finishReason = "tool_calls"
}
if len(detected.Calls) == 0 && !s.toolCallsEmitted && strings.TrimSpace(finalText) == "" {
status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", finalText, finalThinking)
if len(turn.ToolCalls) == 0 && !s.toolCallsEmitted && strings.TrimSpace(turn.Text) == "" {
status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", turn.Text, turn.Thinking)
if turn.Error != nil {
status, message, code = turn.Error.Status, turn.Error.Message, turn.Error.Code
}
if deferEmptyOutput {
s.finalErrorStatus = status
s.finalErrorMessage = message
@@ -251,7 +278,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
s.sendFailedChunk(status, message, code)
return true
}
usage := openaifmt.BuildChatUsageForModel(s.model, s.finalPrompt, finalThinking, finalText, s.refFileTokens)
usage := chatUsageFromTurn(turn)
s.finalFinishReason = finishReason
s.finalUsage = usage
s.sendChunk(openaifmt.BuildChatStreamChunk(
@@ -265,6 +292,17 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
return true
}
// chatUsageFromTurn maps a turn's token usage onto the chat-completions usage
// object: prompt, completion, and total token counts, plus the reasoning token
// count nested under "completion_tokens_details".
func chatUsageFromTurn(turn assistantturn.Turn) map[string]any {
	counts := turn.Usage
	details := map[string]any{
		"reasoning_tokens": counts.ReasoningTokens,
	}
	return map[string]any{
		"prompt_tokens":             counts.InputTokens,
		"completion_tokens":         counts.OutputTokens,
		"total_tokens":              counts.TotalTokens,
		"completion_tokens_details": details,
	}
}
func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
if !parsed.Parsed {
return streamengine.ParsedDecision{}

View File

@@ -6,6 +6,8 @@ import (
"strings"
"testing"
"time"
"ds2api/internal/promptcompat"
)
func TestChatStreamKeepAliveEmitsEmptyChoiceDataFrame(t *testing.T) {
@@ -23,6 +25,7 @@ func TestChatStreamKeepAliveEmitsEmptyChoiceDataFrame(t *testing.T) {
true,
nil,
nil,
promptcompat.DefaultToolChoicePolicy(),
false,
false,
)
@@ -51,3 +54,34 @@ func TestChatStreamKeepAliveEmitsEmptyChoiceDataFrame(t *testing.T) {
t.Fatalf("expected empty choices heartbeat, got %#v", choices)
}
}
func TestChatStreamFinalizeEnforcesRequiredToolChoice(t *testing.T) {
rec := httptest.NewRecorder()
runtime := newChatStreamRuntime(
rec,
http.NewResponseController(rec),
true,
"chatcmpl-test",
time.Now().Unix(),
"deepseek-v4-flash",
"prompt",
false,
false,
true,
[]string{"Write"},
nil,
promptcompat.ToolChoicePolicy{Mode: promptcompat.ToolChoiceRequired},
true,
false,
)
if !runtime.finalize("stop", false) {
t.Fatalf("expected terminal error to be written")
}
if runtime.finalErrorCode != "tool_choice_violation" {
t.Fatalf("expected tool_choice_violation, got %q body=%s", runtime.finalErrorCode, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "tool_choice requires") {
t.Fatalf("expected tool choice error in stream body, got %s", rec.Body.String())
}
}

View File

@@ -7,10 +7,12 @@ import (
"strings"
"time"
"ds2api/internal/assistantturn"
"ds2api/internal/auth"
"ds2api/internal/config"
dsprotocol "ds2api/internal/deepseek/protocol"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/promptcompat"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
)
@@ -26,6 +28,7 @@ type chatNonStreamResult struct {
body map[string]any
finishReason string
responseMessageID int
outputError *assistantturn.OutputError
}
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
@@ -86,35 +89,40 @@ func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.
return chatNonStreamResult{}, false
}
result := sse.CollectStream(resp, thinkingEnabled, true)
stripReferenceMarkers := h.compatStripReferenceMarkers()
finalThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers)
finalText := cleanVisibleOutput(result.Text, stripReferenceMarkers)
if searchEnabled {
finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks)
}
detected := detectAssistantToolCalls(result.Text, finalText, result.Thinking, result.ToolDetectionThinking, toolNames)
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls, toolsRaw)
turn := assistantturn.BuildTurnFromCollected(result, assistantturn.BuildOptions{
Model: model,
Prompt: usagePrompt,
SearchEnabled: searchEnabled,
StripReferenceMarkers: h.compatStripReferenceMarkers(),
ToolNames: toolNames,
ToolsRaw: toolsRaw,
})
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, turn.Thinking, turn.Text, turn.ToolCalls, toolsRaw)
return chatNonStreamResult{
rawThinking: result.Thinking,
rawText: result.Text,
thinking: finalThinking,
thinking: turn.Thinking,
toolDetectionThinking: result.ToolDetectionThinking,
text: finalText,
text: turn.Text,
contentFilter: result.ContentFilter,
detectedCalls: len(detected.Calls),
detectedCalls: len(turn.ToolCalls),
body: respBody,
finishReason: chatFinishReason(respBody),
responseMessageID: result.ResponseMessageID,
outputError: turn.Error,
}, true
}
func (h *Handler) finishChatNonStreamResult(w http.ResponseWriter, result chatNonStreamResult, attempts int, usagePrompt string, refFileTokens int, historySession *chatHistorySession) {
if result.detectedCalls == 0 && shouldWriteUpstreamEmptyOutputError(result.text, result.thinking) {
if result.detectedCalls == 0 && strings.TrimSpace(result.text) == "" {
status, message, code := upstreamEmptyOutputDetail(result.contentFilter, result.text, result.thinking)
if result.outputError != nil {
status, message, code = result.outputError.Status, result.outputError.Message, result.outputError.Code
}
if historySession != nil {
historySession.error(status, message, code, result.thinking, result.text)
}
writeUpstreamEmptyOutputError(w, result.text, result.thinking, result.contentFilter)
writeOpenAIErrorWithCode(w, status, message, code)
config.Logger.Info("[openai_empty_retry] terminal empty output", "surface", "chat.completions", "stream", false, "retry_attempts", attempts, "success_source", "none", "content_filter", result.contentFilter)
return
}
@@ -147,8 +155,8 @@ func shouldRetryChatNonStream(result chatNonStreamResult, attempts int) bool {
strings.TrimSpace(result.thinking) == ""
}
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, historySession)
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, historySession *chatHistorySession) {
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, historySession)
if !ok {
return
}
@@ -190,7 +198,7 @@ func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request,
}
}
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
if resp.StatusCode != http.StatusOK {
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
@@ -216,6 +224,7 @@ func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Res
streamRuntime := newChatStreamRuntime(
w, rc, canFlush, completionID, time.Now().Unix(), model, finalPrompt,
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw,
toolChoice,
len(toolNames) > 0, h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
)
streamRuntime.refFileTokens = refFileTokens

View File

@@ -8,6 +8,7 @@ import (
"time"
"ds2api/internal/chathistory"
"ds2api/internal/promptcompat"
"ds2api/internal/stream"
)
@@ -48,6 +49,7 @@ func TestConsumeChatStreamAttemptMarksContextCancelledState(t *testing.T) {
true,
nil,
nil,
promptcompat.DefaultToolChoicePolicy(),
false,
false,
)

View File

@@ -80,6 +80,10 @@ func writeOpenAIError(w http.ResponseWriter, status int, message string) {
shared.WriteOpenAIError(w, status, message)
}
// writeOpenAIErrorWithCode writes an OpenAI-style error response carrying an
// explicit error code. It is a package-local shorthand that delegates
// unchanged to shared.WriteOpenAIErrorWithCode.
func writeOpenAIErrorWithCode(w http.ResponseWriter, status int, message, code string) {
shared.WriteOpenAIErrorWithCode(w, status, message, code)
}
// openAIErrorType returns the error type string for the given HTTP status.
// It is a package-local shorthand that delegates to shared.OpenAIErrorType;
// the status-to-type mapping itself lives in the shared package.
func openAIErrorType(status int) string {
return shared.OpenAIErrorType(status)
}

View File

@@ -9,6 +9,7 @@ import (
"time"
"ds2api/internal/auth"
"ds2api/internal/completionruntime"
"ds2api/internal/config"
dsprotocol "ds2api/internal/deepseek/protocol"
openaifmt "ds2api/internal/format/openai"
@@ -76,44 +77,40 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
}
historySession := startChatHistory(h.ChatHistory, r, a, stdReq)
sessionID, err = h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
if a.UseConfigToken {
if !stdReq.Stream {
result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{
StripReferenceMarkers: h.compatStripReferenceMarkers(),
RetryEnabled: true,
})
sessionID = result.SessionID
if outErr != nil {
if historySession != nil {
historySession.error(http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.", "error", "", "")
historySession.error(outErr.Status, outErr.Message, outErr.Code, result.Turn.Thinking, result.Turn.Text)
}
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
} else {
if historySession != nil {
historySession.error(http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.", "error", "", "")
}
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code)
return
}
respBody := openaifmt.BuildChatCompletionWithToolCalls(result.SessionID, stdReq.ResponseModel, result.Turn.Prompt, result.Turn.Thinking, result.Turn.Text, result.Turn.ToolCalls, stdReq.ToolsRaw)
respBody["usage"] = chatUsageFromTurn(result.Turn)
finishReason := chatFinishReason(respBody)
if historySession != nil {
historySession.success(http.StatusOK, result.Turn.Thinking, result.Turn.Text, finishReason, chatUsageFromTurn(result.Turn))
}
writeJSON(w, http.StatusOK, respBody)
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
sessionID = start.SessionID
if outErr != nil {
if historySession != nil {
historySession.error(http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).", "error", "", "")
historySession.error(outErr.Status, outErr.Message, outErr.Code, "", "")
}
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
if historySession != nil {
historySession.error(http.StatusInternalServerError, "Failed to get completion.", "error", "", "")
}
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code)
return
}
refFileTokens := stdReq.RefFileTokens
if stdReq.Stream {
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
return
}
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
h.handleStreamWithRetry(w, r, a, start.Response, start.Payload, start.Pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, historySession)
}
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
@@ -234,6 +231,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
stripReferenceMarkers,
toolNames,
toolsRaw,
promptcompat.DefaultToolChoicePolicy(),
bufferToolContent,
emitEarlyToolDeltas,
)

View File

@@ -1,7 +1,6 @@
package responses
import (
"context"
"io"
"net/http"
"strings"
@@ -10,129 +9,10 @@ import (
"ds2api/internal/auth"
"ds2api/internal/config"
dsprotocol "ds2api/internal/deepseek/protocol"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/promptcompat"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/toolcall"
)
// responsesNonStreamResult holds the outcome of one collected non-stream
// Responses attempt, in both raw and sanitized form, so the retry loop can
// accumulate thinking across attempts and decide whether to retry.
type responsesNonStreamResult struct {
// rawThinking is the upstream thinking text before visible-output cleaning.
rawThinking string
// rawText is the upstream answer text before visible-output cleaning.
rawText string
// thinking is the thinking text after cleanVisibleOutput.
thinking string
// toolDetectionThinking is the thinking text supplied to tool-call detection.
toolDetectionThinking string
// text is the answer text after cleanVisibleOutput (and citation-link
// replacement when search is enabled).
text string
// contentFilter mirrors the collected stream's content-filter flag.
contentFilter bool
// parsed is the tool-call detection result for this attempt.
parsed toolcall.ToolCallParseResult
// body is the response object built for the client.
body map[string]any
// responseMessageID is the upstream message ID; it is passed along when
// building the payload for an empty-output retry.
responseMessageID int
}
// handleResponsesNonStreamWithRetry collects a non-stream Responses completion
// and, while shouldRetryResponsesNonStream allows it, issues follow-up
// completion calls for an empty output. Thinking text is accumulated across
// attempts; tool calls and the response body are recomputed from the
// accumulated state on every iteration. The final outcome (success or error)
// is written to w by finishResponsesNonStreamResult, or by the failing step.
func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
attempts := 0
currentResp := resp
usagePrompt := finalPrompt
accumulatedThinking := ""
accumulatedRawThinking := ""
accumulatedToolDetectionThinking := ""
for {
// A false ok means the attempt already wrote an error response to w.
result, ok := h.collectResponsesNonStreamAttempt(w, currentResp, responseID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw)
if !ok {
return
}
// Append this attempt's thinking to what earlier attempts produced.
// NOTE(review): TrimContinuationOverlap presumably drops the part of the
// new text that repeats the tail of the accumulated text — confirm in sse.
accumulatedThinking += sse.TrimContinuationOverlap(accumulatedThinking, result.thinking)
accumulatedRawThinking += sse.TrimContinuationOverlap(accumulatedRawThinking, result.rawThinking)
accumulatedToolDetectionThinking += sse.TrimContinuationOverlap(accumulatedToolDetectionThinking, result.toolDetectionThinking)
result.thinking = accumulatedThinking
result.rawThinking = accumulatedRawThinking
result.toolDetectionThinking = accumulatedToolDetectionThinking
// Re-run detection and rebuild the body so both reflect the accumulated
// thinking rather than only this attempt's.
result.parsed = detectAssistantToolCalls(result.rawText, result.text, result.rawThinking, result.toolDetectionThinking, toolNames)
result.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, result.thinking, result.text, result.parsed.Calls, toolsRaw)
if refFileTokens > 0 {
addRefFileTokensToUsage(result.body, refFileTokens)
}
if !shouldRetryResponsesNonStream(result, attempts) {
h.finishResponsesNonStreamResult(w, result, attempts, owner, responseID, toolChoice, traceID)
return
}
attempts++
config.Logger.Info("[openai_empty_retry] attempting synthetic retry", "surface", "responses", "stream", false, "retry_attempt", attempts, "parent_message_id", result.responseMessageID)
// Prefer a fresh PoW for the retry; if fetching one fails, reuse the
// original PoW instead of aborting.
retryPow, powErr := h.DS.GetPow(ctx, a, 3)
if powErr != nil {
config.Logger.Warn("[openai_empty_retry] retry PoW fetch failed, falling back to original PoW", "surface", "responses", "stream", false, "retry_attempt", attempts, "error", powErr)
retryPow = pow
}
// The retry payload is derived from the original payload and the message
// ID of the empty response.
nextResp, err := h.DS.CallCompletion(ctx, a, clonePayloadForEmptyOutputRetry(payload, result.responseMessageID), retryPow, 3)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
config.Logger.Warn("[openai_empty_retry] retry request failed", "surface", "responses", "stream", false, "retry_attempt", attempts, "error", err)
return
}
// The prompt used for usage accounting is updated for each retry.
usagePrompt = usagePromptWithEmptyOutputRetry(usagePrompt, attempts)
currentResp = nextResp
}
}
// collectResponsesNonStreamAttempt drains one upstream completion response and
// turns it into a responsesNonStreamResult. The response body is always closed
// before returning.
//
// On a non-200 upstream status it writes the upstream status and trimmed body
// as an OpenAI error and returns ok=false. Otherwise it collects the stream,
// cleans the visible thinking and text, replaces citation markers with links
// when search is enabled, detects tool calls, builds the response object, and
// returns ok=true.
func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *http.Response, responseID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (responsesNonStreamResult, bool) {
	defer func() { _ = resp.Body.Close() }()

	if status := resp.StatusCode; status != http.StatusOK {
		// Best-effort read: whatever was read is forwarded as the message.
		raw, _ := io.ReadAll(resp.Body)
		writeOpenAIError(w, status, strings.TrimSpace(string(raw)))
		return responsesNonStreamResult{}, false
	}

	collected := sse.CollectStream(resp, thinkingEnabled, false)
	strip := h.compatStripReferenceMarkers()
	visibleThinking := cleanVisibleOutput(collected.Thinking, strip)
	visibleText := cleanVisibleOutput(collected.Text, strip)
	if searchEnabled {
		visibleText = replaceCitationMarkersWithLinks(visibleText, collected.CitationLinks)
	}

	var out responsesNonStreamResult
	out.rawThinking = collected.Thinking
	out.rawText = collected.Text
	out.thinking = visibleThinking
	out.toolDetectionThinking = collected.ToolDetectionThinking
	out.text = visibleText
	out.contentFilter = collected.ContentFilter
	out.parsed = detectAssistantToolCalls(collected.Text, visibleText, collected.Thinking, collected.ToolDetectionThinking, toolNames)
	out.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, visibleThinking, visibleText, out.parsed.Calls, toolsRaw)
	out.responseMessageID = collected.ResponseMessageID
	return out, true
}
// finishResponsesNonStreamResult writes the terminal outcome of a non-stream
// Responses request. In order, it:
//   - stops if there are no tool calls and writeUpstreamEmptyOutputError
//     reports that it wrote an empty-output error;
//   - logs tool-policy rejections, then fails with 422 tool_choice_violation
//     when tool_choice is required but no call was detected;
//   - otherwise stores the response object, writes it with 200, and logs
//     whether success came from the first attempt or a synthetic retry.
func (h *Handler) finishResponsesNonStreamResult(w http.ResponseWriter, result responsesNonStreamResult, attempts int, owner, responseID string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
	noCalls := len(result.parsed.Calls) == 0

	// The error writer is only consulted when no tool call was detected.
	if noCalls && writeUpstreamEmptyOutputError(w, result.text, result.thinking, result.contentFilter) {
		config.Logger.Info("[openai_empty_retry] terminal empty output", "surface", "responses", "stream", false, "retry_attempts", attempts, "success_source", "none", "content_filter", result.contentFilter)
		return
	}

	logResponsesToolPolicyRejection(traceID, toolChoice, result.parsed, "text")
	if toolChoice.IsRequired() && noCalls {
		writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
		return
	}

	h.getResponseStore().put(owner, responseID, result.body)
	writeJSON(w, http.StatusOK, result.body)

	successSource := "first_attempt"
	if attempts > 0 {
		successSource = "synthetic_retry"
	}
	config.Logger.Info("[openai_empty_retry] completed", "surface", "responses", "stream", false, "retry_attempts", attempts, "success_source", successSource)
}
// shouldRetryResponsesNonStream reports whether an empty-output retry should
// be attempted. A retry requires that retries are enabled, the attempt budget
// is not exhausted, the result was not content-filtered, no tool calls were
// detected, and both the text and the thinking are blank after trimming.
func shouldRetryResponsesNonStream(result responsesNonStreamResult, attempts int) bool {
	if !emptyOutputRetryEnabled() {
		return false
	}
	if attempts >= emptyOutputRetryMaxAttempts() {
		return false
	}
	if result.contentFilter || len(result.parsed.Calls) != 0 {
		return false
	}
	return strings.TrimSpace(result.text) == "" && strings.TrimSpace(result.thinking) == ""
}
func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, resp, owner, responseID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, traceID)
if !ok {

View File

@@ -12,6 +12,7 @@ import (
"github.com/google/uuid"
"ds2api/internal/auth"
"ds2api/internal/completionruntime"
"ds2api/internal/config"
dsprotocol "ds2api/internal/deepseek/protocol"
openaifmt "ds2api/internal/format/openai"
@@ -92,34 +93,31 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
return
}
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
if a.UseConfigToken {
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
} else {
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if !stdReq.Stream {
result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{
StripReferenceMarkers: h.compatStripReferenceMarkers(),
RetryEnabled: true,
})
if outErr != nil {
writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code)
return
}
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, stdReq.ResponseModel, result.Turn.Prompt, result.Turn.Thinking, result.Turn.Text, result.Turn.ToolCalls, stdReq.ToolsRaw)
responseObj["usage"] = responsesUsageFromTurn(result.Turn)
h.getResponseStore().put(owner, responseID, responseObj)
writeJSON(w, http.StatusOK, responseObj)
return
}
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
refFileTokens := stdReq.RefFileTokens
if stdReq.Stream {
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
if outErr != nil {
writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code)
return
}
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
refFileTokens := stdReq.RefFileTokens
h.handleResponsesStreamWithRetry(w, r, a, start.Response, start.Payload, start.Pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
}
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {

View File

@@ -1,6 +1,7 @@
package responses
import (
"ds2api/internal/assistantturn"
"ds2api/internal/toolcall"
"net/http"
"strings"
@@ -159,9 +160,29 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput
finalThinking := s.accumulator.Thinking.String()
finalToolDetectionThinking := s.accumulator.ToolDetectionThinking.String()
finalText := cleanVisibleOutput(s.accumulator.Text.String(), s.stripReferenceMarkers)
textParsed := detectAssistantToolCalls(s.accumulator.RawText.String(), finalText, s.accumulator.RawThinking.String(), finalToolDetectionThinking, s.toolNames)
detected := textParsed.Calls
finalText := s.accumulator.Text.String()
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
RawText: s.accumulator.RawText.String(),
VisibleText: finalText,
RawThinking: s.accumulator.RawThinking.String(),
VisibleThinking: finalThinking,
DetectionThinking: finalToolDetectionThinking,
ContentFilter: finishReason == "content_filter",
ResponseMessageID: s.responseMessageID,
AlreadyEmittedCalls: s.toolCallsEmitted,
AlreadyEmittedToolRaw: s.toolCallsDoneEmitted,
}, assistantturn.BuildOptions{
Model: s.model,
Prompt: s.finalPrompt,
RefFileTokens: s.refFileTokens,
SearchEnabled: s.searchEnabled,
StripReferenceMarkers: s.stripReferenceMarkers,
ToolNames: s.toolNames,
ToolsRaw: s.toolsRaw,
ToolChoice: s.toolChoice,
})
textParsed := turn.ParsedToolCalls
detected := turn.ToolCalls
s.logToolPolicyRejections(textParsed)
if len(detected) > 0 {
@@ -173,12 +194,15 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput
s.closeMessageItem()
if s.toolChoice.IsRequired() && len(detected) == 0 {
s.failResponse(http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
if turn.Error != nil && turn.Error.Code == "tool_choice_violation" {
s.failResponse(turn.Error.Status, turn.Error.Message, turn.Error.Code)
return true
}
if len(detected) == 0 && strings.TrimSpace(finalText) == "" {
status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", finalText, finalThinking)
if len(detected) == 0 && strings.TrimSpace(turn.Text) == "" {
status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", turn.Text, turn.Thinking)
if turn.Error != nil {
status, message, code = turn.Error.Status, turn.Error.Message, turn.Error.Code
}
if deferEmptyOutput {
s.finalErrorStatus = status
s.finalErrorMessage = message
@@ -190,7 +214,7 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput
}
s.closeIncompleteFunctionItems()
obj := s.buildCompletedResponseObject(finalThinking, finalText, detected)
obj := s.buildCompletedResponseObject(turn.Thinking, turn.Text, detected)
if s.persistResponse != nil {
s.persistResponse(obj)
}
@@ -199,6 +223,14 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput
return true
}
// responsesUsageFromTurn copies the token counters recorded on turn.Usage
// into a fresh map keyed by "input_tokens", "output_tokens" and
// "total_tokens".
func responsesUsageFromTurn(turn assistantturn.Turn) map[string]any {
	counters := turn.Usage
	usage := make(map[string]any, 3)
	usage["input_tokens"] = counters.InputTokens
	usage["output_tokens"] = counters.OutputTokens
	usage["total_tokens"] = counters.TotalTokens
	return usage
}
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed toolcall.ToolCallParseResult) {
logRejected := func(parsed toolcall.ToolCallParseResult, channel string) {
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)