From dc5bffdf89e0bebe5afad84c75684aad2ed66780 Mon Sep 17 00:00:00 2001 From: CJACK Date: Sat, 2 May 2026 23:28:43 +0800 Subject: [PATCH] refactor: centralize assistant turn semantics and stream accumulation into new assistantturn and completionruntime packages --- VERSION | 2 +- docs/ARCHITECTURE.en.md | 4 + docs/ARCHITECTURE.md | 4 + docs/prompt-compatibility.md | 9 +- internal/assistantturn/stream.go | 64 +++++ internal/assistantturn/turn.go | 227 ++++++++++++++++++ internal/assistantturn/turn_test.go | 100 ++++++++ internal/completionruntime/nonstream.go | 170 +++++++++++++ internal/completionruntime/nonstream_test.go | 120 +++++++++ internal/format/claude/render.go | 42 ++++ internal/httpapi/claude/handler_messages.go | 90 ++++++- .../httpapi/claude/stream_runtime_finalize.go | 33 ++- internal/httpapi/gemini/convert_request.go | 4 + internal/httpapi/gemini/handler_generate.go | 120 ++++++++- .../httpapi/gemini/handler_stream_runtime.go | 77 +++--- .../openai/chat/chat_stream_runtime.go | 58 ++++- .../openai/chat/chat_stream_runtime_test.go | 34 +++ .../openai/chat/empty_retry_runtime.go | 41 ++-- .../openai/chat/empty_retry_runtime_test.go | 2 + internal/httpapi/openai/chat/handler.go | 4 + internal/httpapi/openai/chat/handler_chat.go | 54 ++--- .../openai/responses/empty_retry_runtime.go | 120 --------- .../openai/responses/responses_handler.go | 42 ++-- .../responses_stream_runtime_core.go | 48 +++- 24 files changed, 1215 insertions(+), 254 deletions(-) create mode 100644 internal/assistantturn/stream.go create mode 100644 internal/assistantturn/turn.go create mode 100644 internal/assistantturn/turn_test.go create mode 100644 internal/completionruntime/nonstream.go create mode 100644 internal/completionruntime/nonstream_test.go diff --git a/VERSION b/VERSION index f77856a..fdc6698 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -4.3.1 +4.4.0 diff --git a/docs/ARCHITECTURE.en.md b/docs/ARCHITECTURE.en.md index df755f2..c7ab027 100644 --- a/docs/ARCHITECTURE.en.md +++ b/docs/ARCHITECTURE.en.md @@ -25,6 +25,8 @@ ds2api/ │ ├── chathistory/ # Server-side conversation history storage/query │ ├── claudeconv/ # Claude message conversion helpers │ ├── compat/ # Compatibility and regression helpers +│ ├── assistantturn/ # Upstream output to canonical assistant turn / stream event semantics +│ ├── completionruntime/ # Shared Go DeepSeek completion startup, non-stream collection, and retry │ ├── config/ # Config loading/validation/hot reload │ ├── deepseek/ # DeepSeek upstream client/protocol/transport │ │ ├── client/ # Login/session/completion/upload/delete calls @@ -171,6 +173,8 @@ flowchart LR - `internal/httpapi/openai/*`: OpenAI HTTP surface split into chat, responses, files, embeddings, history, and shared packages; chat/responses share the promptcompat, stream, and toolcall semantics. - `internal/httpapi/{claude,gemini}`: protocol wrappers that normalize into the same prompt compatibility semantics without duplicating upstream execution. - `internal/promptcompat`: compatibility core for turning OpenAI/Claude/Gemini requests into DeepSeek web-chat plain-text context. +- `internal/assistantturn`: Go output-side canonical semantics, converting DeepSeek SSE collection results and stream finalization state into assistant turns and centralizing thinking, tool call, citation, usage, stop/error behavior. +- `internal/completionruntime`: shared Go completion execution helpers for DeepSeek session/PoW/call startup, non-stream collection, and empty-output retry; streaming paths use it to start upstream requests, continue to use `internal/stream` for real-time consumption, and use `assistantturn` during finalization. - `internal/translatorcliproxy`: structure translation between Claude/Gemini and OpenAI. - `internal/deepseek/{client,protocol,transport}`: upstream requests, sessions, PoW adaptation, protocol constants, and transport details. - `internal/js/chat-stream` + `api/chat-stream.js`: Vercel Node streaming bridge; Go prepare/release owns auth, account lease, and completion payload assembly, while Node relays real-time SSE with Go-aligned finalization and tool sieve semantics. diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 8123ba5..a7e7369 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -25,6 +25,8 @@ ds2api/ │ ├── chathistory/ # 服务器端对话记录存储与查询 │ ├── claudeconv/ # Claude 消息格式转换工具 │ ├── compat/ # 兼容性辅助与回归支持 +│ ├── assistantturn/ # 上游输出到统一 assistant turn / stream event 的语义层 +│ ├── completionruntime/ # Go 主路径共享 DeepSeek completion 启动、非流式收集与 retry │ ├── config/ # 配置加载、校验、热更新 │ ├── deepseek/ # DeepSeek 上游 client/protocol/transport │ │ ├── client/ # 登录、会话、completion、上传/删除等上游调用 @@ -171,6 +173,8 @@ flowchart LR - `internal/httpapi/openai/*`:OpenAI HTTP surface,按 chat、responses、files、embeddings、history、shared 拆分;chat/responses 共享 promptcompat、stream、toolcall 等核心语义。 - `internal/httpapi/{claude,gemini}`:协议输入输出适配,归一到同一套 prompt compatibility 语义,不重复实现上游调用逻辑。 - `internal/promptcompat`:OpenAI/Claude/Gemini 请求到 DeepSeek 网页纯文本上下文的兼容内核。 +- `internal/assistantturn`:Go 输出侧统一语义层,把 DeepSeek SSE 收集结果和流式收尾状态归一成 assistant turn,集中处理 thinking、tool call、citation、usage、stop/error 语义。 +- `internal/completionruntime`:Go surface 共享的 completion 执行辅助,负责 DeepSeek session/PoW/call 启动、非流式 collect 和 empty-output retry;流式路径复用它启动上游请求,继续用 `internal/stream` 做实时消费,并在最终收尾阶段接入 `assistantturn`。 - `internal/translatorcliproxy`:Claude/Gemini 与 OpenAI 结构互转。 - `internal/deepseek/{client,protocol,transport}`:上游请求、会话、PoW 适配、协议常量与传输层。 - `internal/js/chat-stream` + `api/chat-stream.js`:Vercel Node 流式桥;Go prepare/release 管理鉴权、账号租约和 completion payload,Node 侧负责实时 SSE 转发并保持 Go 对齐的终结态和 tool sieve 语义。 diff --git a/docs/prompt-compatibility.md b/docs/prompt-compatibility.md index 8c293d3..c0ace30 100644 --- a/docs/prompt-compatibility.md +++ b/docs/prompt-compatibility.md @@ -48,6 +48,8 @@ DS2API 当前的核心思路,不是把客户端传来的 `messages`、`tools` -> 文件收集 / inline 上传 / current input file(OpenAI 链路) -> completion payload -> 下游网页对话接口 + -> assistantturn 输出语义归一(Go 非流式 + 流式收尾) + -> 各协议 renderer(OpenAI / Responses / Claude / Gemini) ``` 对应的关键代码入口: @@ -72,6 +74,10 @@ DS2API 当前的核心思路,不是把客户端传来的 `messages`、`tools` [internal/promptcompat/thinking_injection.go](../internal/promptcompat/thinking_injection.go) - completion payload: [internal/promptcompat/standard_request.go](../internal/promptcompat/standard_request.go) +- Go 输出侧 assistant turn: + [internal/assistantturn/turn.go](../internal/assistantturn/turn.go) +- Go completion runtime: + [internal/completionruntime/nonstream.go](../internal/completionruntime/nonstream.go) ## 4. 下游真正收到的东西 @@ -101,7 +107,8 @@ DS2API 当前的核心思路,不是把客户端传来的 `messages`、`tools` - 对外返回给客户端的 `prompt_tokens` / `input_tokens` / `promptTokenCount` 不再按“最后一条消息”或字符粗估近似返回,而是基于**完整上下文 prompt**做 tokenizer 计数;为了避免上下文实际超限但客户端误以为还能塞下,请求侧上下文 token 会额外保守上浮一点,宁可略大也不低估。 - 当前 `/v1/chat/completions` 业务路径仍是“每次请求新建一个远端 `chat_session_id`,并默认发送 `parent_message_id: null`”;因此 DS2API 对外默认表现为“新会话 + prompt 拼历史”,而不是复用 DeepSeek 原生会话树。 - 但 DeepSeek 远端本身支持同一 `chat_session_id` 的跨轮次持续对话。2026-04-27 已用项目内现有 DeepSeek client 做过一次不改业务代码的双轮实测:同一 `chat_session_id` 下,第 1 轮返回 `request_message_id=1` / `response_message_id=2` / 文本 `SESSION_TEST_ONE`;第 2 轮重新获取一次 PoW,并发送 `parent_message_id=2` 后,成功返回 `request_message_id=3` / `response_message_id=4` / 文本 `SESSION_TEST_TWO`。这说明“同远端会话持续聊天”能力存在,且每轮需要携带正确的 parent/message 链接信息,同时重新获取对应轮次可用的 PoW。 -- OpenAI Chat / Responses 原生走统一 OpenAI 标准化与 DeepSeek payload 组装;Claude / Gemini 会尽量复用 OpenAI prompt/tool 语义,其中 Gemini 直接复用 `promptcompat.BuildOpenAIPromptForAdapter`,Claude 消息接口在可代理场景会转换为 OpenAI chat 形态再执行。 +- OpenAI Chat / Responses 原生走统一 OpenAI 标准化与 DeepSeek payload 组装;Claude / Gemini 会尽量复用 OpenAI prompt/tool 语义,其中 Gemini 直接复用 `promptcompat.BuildOpenAIPromptForAdapter`。Go 主服务新增 `completionruntime` 启动层,统一执行 DeepSeek session/PoW/call;输出侧新增 `assistantturn` 语义层:非流式 OpenAI Chat / Responses / Claude / Gemini 会把 DeepSeek SSE 收集结果先归一成同一份 assistant turn,再分别渲染成各协议原生外形;流式 OpenAI Chat / Responses / Claude / Gemini 继续保持各协议实时 SSE framing,但最终收尾的 tool fallback、schema 归一、usage、empty-output / content-filter 错误语义同样由 `assistantturn` 判定。Claude / Gemini 的常规 Go 主路径不再依赖内部 `httptest` 转发到 OpenAI handler;`translatorcliproxy` 仍保留用于 Vercel bridge、兼容工具和回归测试。 +- Vercel Node 流式路径本轮不迁移,仍使用现有 Node bridge / stream-tool-sieve 实现;后续若变更 Node 流式语义,需要按 `assistantturn` 的 Go canonical 输出语义同步对齐。 - 客户端传入的 thinking / reasoning 开关会被归一到下游 `thinking_enabled`。Gemini `generationConfig.thinkingConfig.thinkingBudget` 会翻译成同一套 thinking 开关;关闭时即使上游返回 `response/thinking_content`,兼容层也不会把它当作可见正文输出。若最终解析出的模型名带 `-nothinking` 后缀,则会无条件强制关闭 thinking,优先级高于请求体中的 `thinking` / `reasoning` / `reasoning_effort`。Claude surface 在流式请求且未显式声明 `thinking` 时,仍按 Anthropic 语义默认关闭;但在非流式代理场景,兼容层会内部开启一次下游 thinking,用于捕获“正文为空、工具调用落在 thinking 里”的情况,随后在回包前剥离用户不可见的 thinking block。 - 对 OpenAI Chat / Responses 的非流式收尾,如果最终可见正文为空,兼容层会优先尝试把思维链中的独立 DSML / XML 工具块当作真实工具调用解析出来。流式链路也会在收尾阶段做同样的 fallback 检测,但不会因为思维链内容去中途拦截或改写流式输出;真正的工具识别始终基于原始上游文本,而不是基于“已经做过可见输出清洗”的版本,因此即使最终可见层会剥离完整 leaked DSML / XML `tool_calls` wrapper、并抑制全空参数或无效 wrapper 块,也不会影响真实工具调用转成结构化 `tool_calls` / `function_call`。补发结果会作为本轮 assistant 的结构化 `tool_calls` / `function_call` 输出返回,而不是塞进 `content` 文本;如果客户端没有开启 thinking / reasoning,思维链只用于检测,不会作为 `reasoning_content` 或可见正文暴露。只有正文为空且思维链里也没有可执行工具调用时,才继续按空回复错误处理。 - OpenAI Chat / Responses 的空回复错误处理之前会默认做一次内部补偿重试:第一次上游完整结束后,如果最终可见正文为空、没有解析到工具调用、也没有已经向客户端流式发出工具调用,并且终止原因不是 `content_filter`,兼容层会复用同一个 `chat_session_id`、账号、token 与工具策略,把原始 completion `prompt` 追加固定后缀 `Previous reply had no visible output. Please regenerate the visible final answer or tool call now.` 后重新提交一次。重试遵循 DeepSeek 多轮对话协议:从第一次上游 SSE 流中提取 `response_message_id`,并在重试 payload 中设置 `parent_message_id` 为该值,使重试成为同一会话的后续轮次而非断裂的根消息;同时重新获取一次 PoW(若 PoW 获取失败则回退到原始 PoW)。该重试不会重新标准化消息、不会新建 session、不会切换账号,也不会向流式客户端插入重试标记;第二次 thinking / reasoning 会按正常增量直接接到第一次之后,并继续使用 overlap trim 去重。若第二次仍为空,终端错误码仍保持现有 `upstream_empty_output`;若任一尝试触发空 `content_filter`,不做补偿重试并保持 `content_filter` 错误。JS Vercel 运行时同样设置 `parent_message_id`,但因无法直接调用 PoW API 而复用原始 PoW。 diff --git a/internal/assistantturn/stream.go b/internal/assistantturn/stream.go new file mode 100644 index 0000000..77c398d --- /dev/null +++ b/internal/assistantturn/stream.go @@ -0,0 +1,64 @@ +package assistantturn + +import ( + "ds2api/internal/httpapi/openai/shared" + "ds2api/internal/sse" +) + +type StreamEventType string + +const ( + StreamEventTextDelta StreamEventType = "text_delta" + StreamEventThinkingDelta StreamEventType = "thinking_delta" + StreamEventToolCall StreamEventType = "tool_call" + StreamEventDone StreamEventType = "done" + StreamEventError StreamEventType = "error" + StreamEventPing StreamEventType = "ping" +) + +type StreamEvent struct { + Type StreamEventType + Text string + Thinking string + ToolCall any + Error *OutputError + Usage *Usage +} + +type Accumulator struct { + inner shared.StreamAccumulator +} + +type AccumulatorOptions struct { + ThinkingEnabled bool + SearchEnabled bool + StripReferenceMarkers bool +} + +func NewAccumulator(opts AccumulatorOptions) *Accumulator { + return &Accumulator{ + inner: shared.StreamAccumulator{ + ThinkingEnabled: opts.ThinkingEnabled, + SearchEnabled: opts.SearchEnabled, + StripReferenceMarkers: opts.StripReferenceMarkers, + }, + } +} + +func (a *Accumulator) Apply(parsed sse.LineResult) shared.StreamAccumulatorResult { + if a == nil { + return shared.StreamAccumulatorResult{} + } + return a.inner.Apply(parsed) +} + +func (a *Accumulator) Snapshot() (rawText, text, rawThinking, thinking, detectionThinking string) { + if a == nil { + return "", "", "", "", "" + } + return a.inner.RawText.String(), + a.inner.Text.String(), + a.inner.RawThinking.String(), + a.inner.Thinking.String(), + a.inner.ToolDetectionThinking.String() +} diff --git a/internal/assistantturn/turn.go b/internal/assistantturn/turn.go new file mode 100644 index 0000000..4f8b36d --- /dev/null +++ b/internal/assistantturn/turn.go @@ -0,0 +1,227 @@ +package assistantturn + +import ( + "net/http" + "strings" + + "ds2api/internal/httpapi/openai/shared" + "ds2api/internal/promptcompat" + "ds2api/internal/sse" + "ds2api/internal/toolcall" + "ds2api/internal/util" +) + +type StopReason string + +const ( + StopReasonStop StopReason = "stop" + StopReasonToolCalls StopReason = "tool_calls" + StopReasonContentFilter StopReason = "content_filter" + StopReasonError StopReason = "error" +) + +type Usage struct { + InputTokens int + OutputTokens int + ReasoningTokens int + TotalTokens int +} + +type OutputError struct { + Status int + Message string + Code string +} + +type Turn struct { + Model string + Prompt string + RawText string + RawThinking string + DetectionThinking string + Text string + Thinking string + ToolCalls []toolcall.ParsedToolCall + ParsedToolCalls toolcall.ToolCallParseResult + CitationLinks map[int]string + ContentFilter bool + ResponseMessageID int + StopReason StopReason + Usage Usage + Error *OutputError +} + +type BuildOptions struct { + Model string + Prompt string + RefFileTokens int + SearchEnabled bool + StripReferenceMarkers bool + ToolNames []string + ToolsRaw any + ToolChoice promptcompat.ToolChoicePolicy +} + +type StreamSnapshot struct { + RawText string + VisibleText string + RawThinking string + VisibleThinking string + DetectionThinking string + ContentFilter bool + CitationLinks map[int]string + ResponseMessageID int + AlreadyEmittedCalls bool + AdditionalToolCalls []toolcall.ParsedToolCall + AlreadyEmittedToolRaw bool +} + +func BuildTurnFromCollected(result sse.CollectResult, opts BuildOptions) Turn { + thinking := shared.CleanVisibleOutput(result.Thinking, opts.StripReferenceMarkers) + text := shared.CleanVisibleOutput(result.Text, opts.StripReferenceMarkers) + if opts.SearchEnabled { + text = shared.ReplaceCitationMarkersWithLinks(text, result.CitationLinks) + } + + parsed := shared.DetectAssistantToolCalls(result.Text, text, result.Thinking, result.ToolDetectionThinking, opts.ToolNames) + calls := toolcall.NormalizeParsedToolCallsForSchemas(parsed.Calls, opts.ToolsRaw) + parsed.Calls = calls + + stopReason := StopReasonStop + if result.ContentFilter { + stopReason = StopReasonContentFilter + } + if len(calls) > 0 { + stopReason = StopReasonToolCalls + } + + turn := Turn{ + Model: opts.Model, + Prompt: opts.Prompt, + RawText: result.Text, + RawThinking: result.Thinking, + DetectionThinking: result.ToolDetectionThinking, + Text: text, + Thinking: thinking, + ToolCalls: calls, + ParsedToolCalls: parsed, + CitationLinks: result.CitationLinks, + ContentFilter: result.ContentFilter, + ResponseMessageID: result.ResponseMessageID, + StopReason: stopReason, + } + turn.Usage = BuildUsage(opts.Model, opts.Prompt, thinking, text, opts.RefFileTokens) + turn.Error = ValidateTurn(turn, opts.ToolChoice) + if turn.Error != nil { + turn.StopReason = StopReasonError + } + return turn +} + +func BuildTurnFromStreamSnapshot(snapshot StreamSnapshot, opts BuildOptions) Turn { + thinking := shared.CleanVisibleOutput(snapshot.VisibleThinking, opts.StripReferenceMarkers) + text := shared.CleanVisibleOutput(snapshot.VisibleText, opts.StripReferenceMarkers) + if opts.SearchEnabled { + text = shared.ReplaceCitationMarkersWithLinks(text, snapshot.CitationLinks) + } + + parsed := shared.DetectAssistantToolCalls(snapshot.RawText, text, snapshot.RawThinking, snapshot.DetectionThinking, opts.ToolNames) + calls := parsed.Calls + if len(calls) == 0 && len(snapshot.AdditionalToolCalls) > 0 { + calls = snapshot.AdditionalToolCalls + } + calls = toolcall.NormalizeParsedToolCallsForSchemas(calls, opts.ToolsRaw) + parsed.Calls = calls + + stopReason := StopReasonStop + if snapshot.ContentFilter { + stopReason = StopReasonContentFilter + } + if len(calls) > 0 || snapshot.AlreadyEmittedCalls || snapshot.AlreadyEmittedToolRaw { + stopReason = StopReasonToolCalls + } + + turn := Turn{ + Model: opts.Model, + Prompt: opts.Prompt, + RawText: snapshot.RawText, + RawThinking: snapshot.RawThinking, + DetectionThinking: snapshot.DetectionThinking, + Text: text, + Thinking: thinking, + ToolCalls: calls, + ParsedToolCalls: parsed, + CitationLinks: snapshot.CitationLinks, + ContentFilter: snapshot.ContentFilter, + ResponseMessageID: snapshot.ResponseMessageID, + StopReason: stopReason, + } + turn.Usage = BuildUsage(opts.Model, opts.Prompt, thinking, text, opts.RefFileTokens) + if !snapshot.AlreadyEmittedCalls && !snapshot.AlreadyEmittedToolRaw { + turn.Error = ValidateTurn(turn, opts.ToolChoice) + } + if turn.Error != nil && len(calls) == 0 { + turn.StopReason = StopReasonError + } + return turn +} + +func BuildUsage(model, prompt, thinking, text string, refFileTokens int) Usage { + inputTokens := util.CountPromptTokens(prompt, model) + refFileTokens + reasoningTokens := util.CountOutputTokens(thinking, model) + outputTokens := reasoningTokens + util.CountOutputTokens(text, model) + return Usage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + ReasoningTokens: reasoningTokens, + TotalTokens: inputTokens + outputTokens, + } +} + +func ValidateTurn(turn Turn, policy promptcompat.ToolChoicePolicy) *OutputError { + if policy.IsRequired() && len(turn.ToolCalls) == 0 { + return &OutputError{ + Status: http.StatusUnprocessableEntity, + Message: "tool_choice requires at least one valid tool call.", + Code: "tool_choice_violation", + } + } + if len(turn.ToolCalls) > 0 { + return nil + } + if strings.TrimSpace(turn.Text) != "" { + return nil + } + status, message, code := UpstreamEmptyOutputDetail(turn.ContentFilter, turn.Text, turn.Thinking) + return &OutputError{Status: status, Message: message, Code: code} +} + +func UpstreamEmptyOutputDetail(contentFilter bool, text, thinking string) (int, string, string) { + _ = text + if contentFilter { + return http.StatusBadRequest, "Upstream content filtered the response and returned no output.", "content_filter" + } + if strings.TrimSpace(thinking) != "" { + return http.StatusTooManyRequests, "Upstream account hit a rate limit and returned reasoning without visible output.", "upstream_empty_output" + } + return http.StatusTooManyRequests, "Upstream account hit a rate limit and returned empty output.", "upstream_empty_output" +} + +func ShouldRetryEmptyOutput(turn Turn, attempts, maxAttempts int) bool { + return attempts < maxAttempts && + !turn.ContentFilter && + len(turn.ToolCalls) == 0 && + strings.TrimSpace(turn.Text) == "" && + strings.TrimSpace(turn.Thinking) == "" +} + +func FinishReason(turn Turn) string { + switch turn.StopReason { + case StopReasonToolCalls: + return "tool_calls" + case StopReasonContentFilter: + return "content_filter" + default: + return "stop" + } +} diff --git a/internal/assistantturn/turn_test.go b/internal/assistantturn/turn_test.go new file mode 100644 index 0000000..4aca558 --- /dev/null +++ b/internal/assistantturn/turn_test.go @@ -0,0 +1,100 @@ +package assistantturn + +import ( + "testing" + + "ds2api/internal/promptcompat" + "ds2api/internal/sse" +) + +func TestBuildTurnFromCollectedTextCitation(t *testing.T) { + turn := BuildTurnFromCollected(sse.CollectResult{ + Text: "See [citation:1]", + CitationLinks: map[int]string{1: "https://example.com"}, + }, BuildOptions{Model: "deepseek-v4-flash", Prompt: "prompt", SearchEnabled: true, StripReferenceMarkers: true}) + if turn.Text != "See [1](https://example.com)" { + t.Fatalf("text mismatch: %q", turn.Text) + } + if turn.StopReason != StopReasonStop { + t.Fatalf("stop reason mismatch: %q", turn.StopReason) + } + if turn.Error != nil { + t.Fatalf("unexpected error: %#v", turn.Error) + } +} + +func TestBuildTurnFromCollectedToolCall(t *testing.T) { + turn := BuildTurnFromCollected(sse.CollectResult{ + Text: `{"x":1}`, + }, BuildOptions{ + ToolNames: []string{"Write"}, + ToolsRaw: []any{map[string]any{ + "name": "Write", + "input_schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "content": map[string]any{"type": "string"}, + }, + }, + }}, + }) + if len(turn.ToolCalls) != 1 { + t.Fatalf("expected one tool call, got %d", len(turn.ToolCalls)) + } + if turn.StopReason != StopReasonToolCalls { + t.Fatalf("stop reason mismatch: %q", turn.StopReason) + } + if _, ok := turn.ToolCalls[0].Input["content"].(string); !ok { + t.Fatalf("expected content coerced to string, got %#v", turn.ToolCalls[0].Input["content"]) + } +} + +func TestBuildTurnFromCollectedThinkingOnlyIsEmptyOutput(t *testing.T) { + turn := BuildTurnFromCollected(sse.CollectResult{Thinking: "hidden"}, BuildOptions{}) + if turn.Error == nil || turn.Error.Code != "upstream_empty_output" { + t.Fatalf("expected empty output error, got %#v", turn.Error) + } +} + +func TestBuildTurnFromCollectedToolChoiceRequired(t *testing.T) { + turn := BuildTurnFromCollected(sse.CollectResult{Text: "hello"}, BuildOptions{ + ToolChoice: promptcompat.ToolChoicePolicy{Mode: promptcompat.ToolChoiceRequired}, + }) + if turn.Error == nil || turn.Error.Code != "tool_choice_violation" { + t.Fatalf("expected tool choice violation, got %#v", turn.Error) + } +} + +func TestBuildTurnFromStreamSnapshotUsesVisibleTextAndRawToolDetection(t *testing.T) { + turn := BuildTurnFromStreamSnapshot(StreamSnapshot{ + RawText: `{"x":1}`, + VisibleText: "", + }, BuildOptions{ + ToolNames: []string{"Write"}, + ToolsRaw: []any{map[string]any{ + "name": "Write", + "schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "content": map[string]any{"type": "string"}, + }, + }, + }}, + }) + if len(turn.ToolCalls) != 1 { + t.Fatalf("expected stream snapshot tool call, got %d", len(turn.ToolCalls)) + } + if _, ok := turn.ToolCalls[0].Input["content"].(string); !ok { + t.Fatalf("expected stream snapshot schema coercion, got %#v", turn.ToolCalls[0].Input["content"]) + } +} + +func TestBuildTurnFromStreamSnapshotAlreadyEmittedToolAvoidsEmptyError(t *testing.T) { + turn := BuildTurnFromStreamSnapshot(StreamSnapshot{AlreadyEmittedCalls: true}, BuildOptions{}) + if turn.Error != nil { + t.Fatalf("unexpected empty-output error after emitted tool call: %#v", turn.Error) + } + if turn.StopReason != StopReasonToolCalls { + t.Fatalf("stop reason mismatch: %q", turn.StopReason) + } +} diff --git a/internal/completionruntime/nonstream.go b/internal/completionruntime/nonstream.go new file mode 100644 index 0000000..1b32969 --- /dev/null +++ b/internal/completionruntime/nonstream.go @@ -0,0 +1,170 @@ +package completionruntime + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + + "ds2api/internal/assistantturn" + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/httpapi/openai/shared" + "ds2api/internal/promptcompat" + "ds2api/internal/sse" +) + +type DeepSeekCaller interface { + CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +type Options struct { + StripReferenceMarkers bool + MaxAttempts int + RetryEnabled bool + RetryMaxAttempts int +} + +type NonStreamResult struct { + SessionID string + Payload map[string]any + Turn assistantturn.Turn + Attempts int +} + +type StartResult struct { + SessionID string + Payload map[string]any + Pow string + Response *http.Response +} + +func StartCompletion(ctx context.Context, ds DeepSeekCaller, a *auth.RequestAuth, stdReq promptcompat.StandardRequest, opts Options) (StartResult, *assistantturn.OutputError) { + maxAttempts := opts.MaxAttempts + if maxAttempts <= 0 { + maxAttempts = 3 + } + sessionID, err := ds.CreateSession(ctx, a, maxAttempts) + if err != nil { + return StartResult{}, authOutputError(a) + } + pow, err := ds.GetPow(ctx, a, maxAttempts) + if err != nil { + return StartResult{SessionID: sessionID}, &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: "Failed to get PoW (invalid token or unknown error).", Code: "error"} + } + payload := stdReq.CompletionPayload(sessionID) + resp, err := ds.CallCompletion(ctx, a, payload, pow, maxAttempts) + if err != nil { + return StartResult{SessionID: sessionID, Payload: payload, Pow: pow}, &assistantturn.OutputError{Status: http.StatusInternalServerError, Message: "Failed to get completion.", Code: "error"} + } + return StartResult{SessionID: sessionID, Payload: payload, Pow: pow, Response: resp}, nil +} + +func ExecuteNonStreamWithRetry(ctx context.Context, ds DeepSeekCaller, a *auth.RequestAuth, stdReq promptcompat.StandardRequest, opts Options) (NonStreamResult, *assistantturn.OutputError) { + start, startErr := StartCompletion(ctx, ds, a, stdReq, opts) + if startErr != nil { + return NonStreamResult{SessionID: start.SessionID, Payload: start.Payload}, startErr + } + maxAttempts := opts.MaxAttempts + if maxAttempts <= 0 { + maxAttempts = 3 + } + sessionID := start.SessionID + payload := start.Payload + pow := start.Pow + + attempts := 0 + currentResp := start.Response + usagePrompt := stdReq.PromptTokenText + accumulatedThinking := "" + accumulatedRawThinking := "" + accumulatedToolDetectionThinking := "" + for { + turn, outErr := collectAttempt(currentResp, stdReq, usagePrompt, opts) + if outErr != nil { + return NonStreamResult{SessionID: sessionID, Payload: payload, Attempts: attempts}, outErr + } + accumulatedThinking += sse.TrimContinuationOverlap(accumulatedThinking, turn.Thinking) + accumulatedRawThinking += sse.TrimContinuationOverlap(accumulatedRawThinking, turn.RawThinking) + accumulatedToolDetectionThinking += sse.TrimContinuationOverlap(accumulatedToolDetectionThinking, turn.DetectionThinking) + turn.Thinking = accumulatedThinking + turn.RawThinking = accumulatedRawThinking + turn.DetectionThinking = accumulatedToolDetectionThinking + turn = assistantturn.BuildTurnFromCollected(sse.CollectResult{ + Text: turn.RawText, + Thinking: turn.RawThinking, + ToolDetectionThinking: turn.DetectionThinking, + ContentFilter: turn.ContentFilter, + CitationLinks: turn.CitationLinks, + ResponseMessageID: turn.ResponseMessageID, + }, buildOptions(stdReq, usagePrompt, opts)) + + retryMax := opts.RetryMaxAttempts + if retryMax <= 0 { + retryMax = shared.EmptyOutputRetryMaxAttempts() + } + if !opts.RetryEnabled || !assistantturn.ShouldRetryEmptyOutput(turn, attempts, retryMax) { + return NonStreamResult{SessionID: sessionID, Payload: payload, Turn: turn, Attempts: attempts}, turn.Error + } + + attempts++ + config.Logger.Info("[completion_runtime_empty_retry] attempting synthetic retry", "surface", stdReq.Surface, "stream", false, "retry_attempt", attempts, "parent_message_id", turn.ResponseMessageID) + retryPow, powErr := ds.GetPow(ctx, a, maxAttempts) + if powErr != nil { + config.Logger.Warn("[completion_runtime_empty_retry] retry PoW fetch failed, falling back to original PoW", "surface", stdReq.Surface, "retry_attempt", attempts, "error", powErr) + retryPow = pow + } + retryPayload := shared.ClonePayloadForEmptyOutputRetry(payload, turn.ResponseMessageID) + nextResp, err := ds.CallCompletion(ctx, a, retryPayload, retryPow, maxAttempts) + if err != nil { + return NonStreamResult{SessionID: sessionID, Payload: payload, Turn: turn, Attempts: attempts}, &assistantturn.OutputError{Status: http.StatusInternalServerError, Message: "Failed to get completion.", Code: "error"} + } + usagePrompt = shared.UsagePromptWithEmptyOutputRetry(usagePrompt, attempts) + currentResp = nextResp + } +} + +func collectAttempt(resp *http.Response, stdReq promptcompat.StandardRequest, usagePrompt string, opts Options) (assistantturn.Turn, *assistantturn.OutputError) { + defer func() { + if err := resp.Body.Close(); err != nil { + config.Logger.Warn("[completion_runtime] response body close failed", "surface", stdReq.Surface, "error", err) + } + }() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + message := strings.TrimSpace(string(body)) + if message == "" { + message = http.StatusText(resp.StatusCode) + } + return assistantturn.Turn{}, &assistantturn.OutputError{Status: resp.StatusCode, Message: message, Code: "error"} + } + result := sse.CollectStream(resp, stdReq.Thinking, false) + return assistantturn.BuildTurnFromCollected(result, buildOptions(stdReq, usagePrompt, opts)), nil +} + +func buildOptions(stdReq promptcompat.StandardRequest, prompt string, opts Options) assistantturn.BuildOptions { + return assistantturn.BuildOptions{ + Model: stdReq.ResponseModel, + Prompt: prompt, + RefFileTokens: stdReq.RefFileTokens, + SearchEnabled: stdReq.Search, + StripReferenceMarkers: opts.StripReferenceMarkers, + ToolNames: stdReq.ToolNames, + ToolsRaw: stdReq.ToolsRaw, + ToolChoice: stdReq.ToolChoice, + } +} + +func authOutputError(a *auth.RequestAuth) *assistantturn.OutputError { + if a != nil && a.UseConfigToken { + return &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: "Account token is invalid. Please re-login the account in admin.", Code: "error"} + } + return &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: "Invalid token. If this should be a DS2API key, add it to config.keys first.", Code: "error"} +} + +func Errorf(status int, format string, args ...any) *assistantturn.OutputError { + return &assistantturn.OutputError{Status: status, Message: fmt.Sprintf(format, args...), Code: "error"} +} diff --git a/internal/completionruntime/nonstream_test.go b/internal/completionruntime/nonstream_test.go new file mode 100644 index 0000000..1428fca --- /dev/null +++ b/internal/completionruntime/nonstream_test.go @@ -0,0 +1,120 @@ +package completionruntime + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + + "ds2api/internal/auth" + "ds2api/internal/promptcompat" +) + +type fakeDeepSeekCaller struct { + responses []*http.Response + payloads []map[string]any +} + +func (f *fakeDeepSeekCaller) CreateSession(context.Context, *auth.RequestAuth, int) (string, error) { + return "session-1", nil +} + +func (f *fakeDeepSeekCaller) GetPow(context.Context, *auth.RequestAuth, int) (string, error) { + return "pow", nil +} + +func (f *fakeDeepSeekCaller) CallCompletion(_ context.Context, _ *auth.RequestAuth, payload map[string]any, _ string, _ int) (*http.Response, error) { + f.payloads = append(f.payloads, payload) + if len(f.responses) == 0 { + return sseHTTPResponse(http.StatusOK, `data: {"p":"response/content","v":"fallback"}`), nil + } + resp := f.responses[0] + f.responses = f.responses[1:] + return resp, nil +} + +func TestExecuteNonStreamWithRetryBuildsCanonicalTurn(t *testing.T) { + ds := &fakeDeepSeekCaller{responses: []*http.Response{sseHTTPResponse( + http.StatusOK, + `data: {"response_message_id":42,"p":"response/content","v":"{\"x\":1}"}`, + )}} + stdReq := promptcompat.StandardRequest{ + Surface: "test", + ResponseModel: "deepseek-v4-flash", + PromptTokenText: "prompt", + FinalPrompt: "final prompt", + ToolNames: []string{"Write"}, + ToolsRaw: []any{map[string]any{ + "name": "Write", + "input_schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "content": map[string]any{"type": "string"}, + }, + }, + }}, + } + + result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, &auth.RequestAuth{}, stdReq, Options{}) + if outErr != nil { + t.Fatalf("unexpected output error: %#v", outErr) + } + if result.SessionID != "session-1" { + t.Fatalf("session mismatch: %q", result.SessionID) + } + if got := result.Turn.ResponseMessageID; got != 42 { + t.Fatalf("response message id mismatch: %d", got) + } + if len(result.Turn.ToolCalls) != 1 { + t.Fatalf("expected one tool call, got %d", len(result.Turn.ToolCalls)) + } + if _, ok := result.Turn.ToolCalls[0].Input["content"].(string); !ok { + t.Fatalf("expected schema-normalized string argument, got %#v", result.Turn.ToolCalls[0].Input["content"]) + } + if result.Turn.Usage.InputTokens == 0 || result.Turn.Usage.TotalTokens == 0 { + t.Fatalf("expected usage to be populated, got %#v", result.Turn.Usage) + } +} + +func TestExecuteNonStreamWithRetryUsesParentMessageForEmptyRetry(t *testing.T) { + ds := &fakeDeepSeekCaller{responses: []*http.Response{ + sseHTTPResponse(http.StatusOK, `data: {"response_message_id":77,"p":"response/status","v":"FINISHED"}`), + sseHTTPResponse(http.StatusOK, `data: {"response_message_id":78,"p":"response/content","v":"ok"}`), + }} + stdReq := promptcompat.StandardRequest{ + Surface: "test", + ResponseModel: "deepseek-v4-flash", + PromptTokenText: "prompt", + FinalPrompt: "final prompt", + } + + result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, &auth.RequestAuth{}, stdReq, Options{RetryEnabled: true}) + if outErr != nil { + t.Fatalf("unexpected output error: %#v", outErr) + } + if result.Attempts != 1 { + t.Fatalf("expected one retry, got %d", result.Attempts) + } + if len(ds.payloads) != 2 { + t.Fatalf("expected two completion calls, got %d", len(ds.payloads)) + } + if got := ds.payloads[1]["parent_message_id"]; got != 77 { + t.Fatalf("retry parent_message_id mismatch: %#v", got) + } + if result.Turn.Text != "ok" { + t.Fatalf("retry text mismatch: %q", result.Turn.Text) + } +} + +func sseHTTPResponse(status int, lines ...string) *http.Response { + body := strings.Join(lines, "\n") + if !strings.HasSuffix(body, "\n") { + body += "\n" + } + return &http.Response{ + StatusCode: status, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} diff --git a/internal/format/claude/render.go b/internal/format/claude/render.go index 694f5fd..3912f41 100644 --- a/internal/format/claude/render.go +++ b/internal/format/claude/render.go @@ -1,6 +1,7 @@ package claude import ( + "ds2api/internal/assistantturn" "ds2api/internal/toolcall" "fmt" "time" @@ -9,6 +10,47 @@ import ( "ds2api/internal/util" ) +func BuildMessageResponseFromTurn(messageID, model string, turn assistantturn.Turn, exposeThinking bool) map[string]any { + content := make([]map[string]any, 0, 4) + if exposeThinking && turn.Thinking != "" { + content = append(content, map[string]any{"type": "thinking", "thinking": turn.Thinking}) + } + stopReason := "end_turn" + if len(turn.ToolCalls) > 0 { + stopReason = "tool_use" + for i, tc := range turn.ToolCalls { + content = append(content, map[string]any{ + "type": "tool_use", + "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), + "name": tc.Name, + "input": tc.Input, + }) + } + } else { + text := turn.Text + if text == "" && exposeThinking { + text = turn.Thinking + } + if text == "" { + text = "抱歉,没有生成有效的响应内容。" + } + content = append(content, map[string]any{"type": "text", "text": text}) + } + return map[string]any{ + "id": messageID, + "type": "message", + "role": "assistant", + "model": model, + "content": content, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": turn.Usage.InputTokens, + "output_tokens": turn.Usage.OutputTokens, + }, + } +} + func BuildMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any { detected := toolcall.ParseToolCalls(finalText, toolNames) if len(detected) == 0 && finalText == "" && finalThinking != "" { diff --git a/internal/httpapi/claude/handler_messages.go b/internal/httpapi/claude/handler_messages.go index ed66475..6a202ad 100644 --- a/internal/httpapi/claude/handler_messages.go +++ b/internal/httpapi/claude/handler_messages.go @@ -4,13 +4,19 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "net/http" "net/http/httptest" "strings" + "time" + "ds2api/internal/auth" + "ds2api/internal/completionruntime" "ds2api/internal/config" + claudefmt "ds2api/internal/format/claude" "ds2api/internal/httpapi/requestbody" + "ds2api/internal/promptcompat" streamengine "ds2api/internal/stream" "ds2api/internal/translatorcliproxy" "ds2api/internal/util" @@ -22,14 +28,90 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" { r.Header.Set("anthropic-version", "2023-06-01") } - if h.OpenAI == nil { - writeClaudeError(w, http.StatusInternalServerError, "OpenAI proxy backend unavailable.") + if isClaudeVercelProxyRequest(r) && h.proxyViaOpenAI(w, r, h.Store) { return } - if h.proxyViaOpenAI(w, r, h.Store) { + if h.Auth == nil || h.DS == nil { + if h.OpenAI != nil && h.proxyViaOpenAI(w, r, h.Store) { + return + } + writeClaudeError(w, http.StatusInternalServerError, "Claude runtime backend unavailable.") return } - writeClaudeError(w, http.StatusBadGateway, "Failed to proxy Claude request.") + if h.handleClaudeDirect(w, r) { + return + } + writeClaudeError(w, http.StatusBadGateway, "Failed to handle Claude request.") +} + +func isClaudeVercelProxyRequest(r *http.Request) bool { + if r == nil || r.URL == nil { + return false + } + return strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1" || + strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1" +} + +func (h *Handler) handleClaudeDirect(w http.ResponseWriter, r *http.Request) bool { + raw, err := io.ReadAll(r.Body) + if err != nil { + if errors.Is(err, requestbody.ErrInvalidUTF8Body) { + writeClaudeError(w, http.StatusBadRequest, "invalid json") + } else { + writeClaudeError(w, http.StatusBadRequest, "invalid body") + } + return true + } + var req map[string]any + if err := json.Unmarshal(raw, &req); err != nil { + writeClaudeError(w, http.StatusBadRequest, "invalid json") + return true + } + exposeThinking := false + if enabled, ok := util.ResolveThinkingOverride(req); ok && enabled { + exposeThinking = true + } else if _, ok := util.ResolveThinkingOverride(req); !ok && !util.ToBool(req["stream"]) { + req["thinking"] = map[string]any{"type": "enabled"} + } + norm, err := normalizeClaudeRequest(h.Store, req) + if err != nil { + writeClaudeError(w, http.StatusBadRequest, err.Error()) + return true + } + a, err := h.Auth.Determine(r) + if err != nil { + writeClaudeError(w, http.StatusUnauthorized, err.Error()) + return true + } + defer h.Auth.Release(a) + if norm.Standard.Stream { + h.handleClaudeDirectStream(w, r, a, norm.Standard) + return true + } + result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, norm.Standard, completionruntime.Options{ + StripReferenceMarkers: h.compatStripReferenceMarkers(), + RetryEnabled: true, + }) + if outErr != nil { + writeClaudeError(w, outErr.Status, outErr.Message) + return true + } + writeJSON(w, http.StatusOK, claudefmt.BuildMessageResponseFromTurn( + fmt.Sprintf("msg_%d", time.Now().UnixNano()), + norm.Standard.ResponseModel, + result.Turn, + exposeThinking, + )) + return true +} + +func (h *Handler) handleClaudeDirectStream(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, stdReq promptcompat.StandardRequest) { + start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{}) + if outErr != nil { + writeClaudeError(w, outErr.Status, outErr.Message) + return + } + h.handleClaudeStreamRealtime(w, r, start.Response, stdReq.ResponseModel, stdReq.Messages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw) } func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store ConfigReader) bool { diff --git a/internal/httpapi/claude/stream_runtime_finalize.go b/internal/httpapi/claude/stream_runtime_finalize.go index 757cfa9..28e276f 100644 --- a/internal/httpapi/claude/stream_runtime_finalize.go +++ b/internal/httpapi/claude/stream_runtime_finalize.go @@ -1,6 +1,7 @@ package claude import ( + "ds2api/internal/assistantturn" "ds2api/internal/sse" "ds2api/internal/toolcall" "ds2api/internal/toolstream" @@ -9,7 +10,6 @@ import ( "time" streamengine "ds2api/internal/stream" - "ds2api/internal/util" ) func (s *claudeStreamRuntime) closeThinkingBlock() { @@ -115,18 +115,28 @@ func (s *claudeStreamRuntime) finalize(stopReason string) { s.closeTextBlock() - finalThinking := s.thinking.String() - finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers) + turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{ + RawText: s.rawText.String(), + VisibleText: s.text.String(), + RawThinking: s.rawThinking.String(), + VisibleThinking: s.thinking.String(), + DetectionThinking: s.toolDetectionThinking.String(), + AlreadyEmittedCalls: s.toolCallsDetected, + AlreadyEmittedToolRaw: s.toolCallsDetected, + }, assistantturn.BuildOptions{ + Model: s.model, + Prompt: s.promptTokenText, + SearchEnabled: s.searchEnabled, + StripReferenceMarkers: s.stripReferenceMarkers, + ToolNames: s.toolNames, + ToolsRaw: s.toolsRaw, + }) + finalText := turn.Text if s.bufferToolContent && !s.toolCallsDetected { - detected := toolcall.ParseStandaloneToolCallsDetailed(s.rawText.String(), s.toolNames) - if len(detected.Calls) == 0 { - detected = toolcall.ParseStandaloneToolCallsDetailed(s.rawThinking.String(), s.toolNames) - } - if len(detected.Calls) > 0 { - normalized := toolcall.NormalizeParsedToolCallsForSchemas(detected.Calls, s.toolsRaw) + if len(turn.ToolCalls) > 0 { stopReason = "tool_use" - for _, tc := range normalized { + for _, tc := range turn.ToolCalls { idx := s.nextBlockIndex s.nextBlockIndex++ s.sendToolUseBlock(idx, tc) @@ -161,7 +171,6 @@ func (s *claudeStreamRuntime) finalize(stopReason string) { stopReason = "tool_use" } - outputTokens := util.CountOutputTokens(finalThinking, s.model) + util.CountOutputTokens(finalText, s.model) s.send("message_delta", map[string]any{ "type": "message_delta", "delta": map[string]any{ @@ -169,7 +178,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string) { "stop_sequence": nil, }, "usage": map[string]any{ - "output_tokens": outputTokens, + "output_tokens": turn.Usage.OutputTokens, }, }) s.send("message_stop", map[string]any{"type": "message_stop"}) diff --git a/internal/httpapi/gemini/convert_request.go b/internal/httpapi/gemini/convert_request.go index 43697e7..4ad5238 100644 --- a/internal/httpapi/gemini/convert_request.go +++ b/internal/httpapi/gemini/convert_request.go @@ -33,6 +33,9 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin toolsRaw := convertGeminiTools(req["tools"]) finalPrompt, toolNames := promptcompat.BuildOpenAIPromptForAdapter(messagesRaw, toolsRaw, "", thinkingEnabled) + if len(toolNames) == 0 && len(toolsRaw) > 0 { + toolNames = []string{"__any_tool__"} + } passThrough := collectGeminiPassThrough(req) return promptcompat.StandardRequest{ @@ -42,6 +45,7 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin ResponseModel: requestedModel, Messages: messagesRaw, PromptTokenText: finalPrompt, + ToolsRaw: toolsRaw, FinalPrompt: finalPrompt, ToolNames: toolNames, Stream: stream, diff --git a/internal/httpapi/gemini/handler_generate.go b/internal/httpapi/gemini/handler_generate.go index 085a29c..9661ab9 100644 --- a/internal/httpapi/gemini/handler_generate.go +++ b/internal/httpapi/gemini/handler_generate.go @@ -11,7 +11,11 @@ import ( "github.com/go-chi/chi/v5" + "ds2api/internal/assistantturn" + "ds2api/internal/auth" + "ds2api/internal/completionruntime" "ds2api/internal/httpapi/requestbody" + "ds2api/internal/promptcompat" "ds2api/internal/sse" "ds2api/internal/toolcall" "ds2api/internal/translatorcliproxy" @@ -21,14 +25,80 @@ import ( ) func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) { - if h.OpenAI == nil { - writeGeminiError(w, http.StatusInternalServerError, "OpenAI proxy backend unavailable.") + if isGeminiVercelProxyRequest(r) && h.proxyViaOpenAI(w, r, stream) { return } - if h.proxyViaOpenAI(w, r, stream) { + if h.Auth == nil || h.DS == nil { + if h.OpenAI != nil && h.proxyViaOpenAI(w, r, stream) { + return + } + writeGeminiError(w, http.StatusInternalServerError, "Gemini runtime backend unavailable.") return } - writeGeminiError(w, http.StatusBadGateway, "Failed to proxy Gemini request.") + if h.handleGeminiDirect(w, r, stream) { + return + } + writeGeminiError(w, http.StatusBadGateway, "Failed to handle Gemini request.") +} + +func isGeminiVercelProxyRequest(r *http.Request) bool { + if r == nil || r.URL == nil { + return false + } + return strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1" || + strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1" +} + +func (h *Handler) handleGeminiDirect(w http.ResponseWriter, r *http.Request, stream bool) bool { + raw, err := io.ReadAll(r.Body) + if err != nil { + if errors.Is(err, requestbody.ErrInvalidUTF8Body) { + writeGeminiError(w, http.StatusBadRequest, "invalid json") + } else { + writeGeminiError(w, http.StatusBadRequest, "invalid body") + } + return true + } + routeModel := strings.TrimSpace(chi.URLParam(r, "model")) + var req map[string]any + if err := json.Unmarshal(raw, &req); err != nil { + writeGeminiError(w, http.StatusBadRequest, "invalid json") + return true + } + stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream) + if err != nil { + writeGeminiError(w, http.StatusBadRequest, err.Error()) + return true + } + a, err := h.Auth.Determine(r) + if err != nil { + writeGeminiError(w, http.StatusUnauthorized, err.Error()) + return true + } + defer h.Auth.Release(a) + if stream { + h.handleGeminiDirectStream(w, r, a, stdReq) + return true + } + result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{ + StripReferenceMarkers: h.compatStripReferenceMarkers(), + RetryEnabled: true, + }) + if outErr != nil { + writeGeminiError(w, outErr.Status, outErr.Message) + return true + } + writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponseFromTurn(result.Turn)) + return true +} + +func (h *Handler) handleGeminiDirectStream(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, stdReq promptcompat.StandardRequest) { + start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{}) + if outErr != nil { + writeGeminiError(w, outErr.Status, outErr.Message) + return + } + h.handleStreamGenerateContent(w, r, start.Response, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw) } func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream bool) bool { @@ -250,6 +320,48 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final } } +func buildGeminiGenerateContentResponseFromTurn(turn assistantturn.Turn) map[string]any { + parts := buildGeminiPartsFromTurn(turn) + return map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": parts, + }, + "finishReason": "STOP", + }, + }, + "modelVersion": turn.Model, + "usageMetadata": map[string]any{ + "promptTokenCount": turn.Usage.InputTokens, + "candidatesTokenCount": turn.Usage.OutputTokens, + "totalTokenCount": turn.Usage.TotalTokens, + }, + } +} + +func buildGeminiPartsFromTurn(turn assistantturn.Turn) []map[string]any { + if len(turn.ToolCalls) > 0 { + parts := make([]map[string]any, 0, len(turn.ToolCalls)) + for _, tc := range turn.ToolCalls { + parts = append(parts, map[string]any{ + "functionCall": map[string]any{ + "name": tc.Name, + "args": tc.Input, + }, + }) + } + return parts + } + text := turn.Text + if text == "" { + text = turn.Thinking + } + return []map[string]any{{"text": text}} +} + //nolint:unused // retained for native Gemini non-stream handling path. func buildGeminiUsage(model, finalPrompt, finalThinking, finalText string) map[string]any { promptTokens := util.CountPromptTokens(finalPrompt, model) diff --git a/internal/httpapi/gemini/handler_stream_runtime.go b/internal/httpapi/gemini/handler_stream_runtime.go index ba76335..ee00106 100644 --- a/internal/httpapi/gemini/handler_stream_runtime.go +++ b/internal/httpapi/gemini/handler_stream_runtime.go @@ -7,13 +7,14 @@ import ( "strings" "time" + "ds2api/internal/assistantturn" dsprotocol "ds2api/internal/deepseek/protocol" "ds2api/internal/sse" streamengine "ds2api/internal/stream" ) //nolint:unused // retained for native Gemini stream handling path. -func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { +func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -28,7 +29,7 @@ func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Req rc := http.NewResponseController(w) _, canFlush := w.(http.Flusher) - runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames) + runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw) initialType := "text" if thinkingEnabled { @@ -64,9 +65,11 @@ type geminiStreamRuntime struct { bufferContent bool stripReferenceMarkers bool toolNames []string + toolsRaw any - thinking strings.Builder - text strings.Builder + accumulator *assistantturn.Accumulator + contentFilter bool + responseMessageID int } //nolint:unused // retained for native Gemini stream handling path. @@ -80,6 +83,7 @@ func newGeminiStreamRuntime( searchEnabled bool, stripReferenceMarkers bool, toolNames []string, + toolsRaw any, ) *geminiStreamRuntime { return &geminiStreamRuntime{ w: w, @@ -92,6 +96,12 @@ func newGeminiStreamRuntime( bufferContent: len(toolNames) > 0, stripReferenceMarkers: stripReferenceMarkers, toolNames: toolNames, + toolsRaw: toolsRaw, + accumulator: assistantturn.NewAccumulator(assistantturn.AccumulatorOptions{ + ThinkingEnabled: thinkingEnabled, + SearchEnabled: searchEnabled, + StripReferenceMarkers: stripReferenceMarkers, + }), } } @@ -111,32 +121,24 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse if !parsed.Parsed { return streamengine.ParsedDecision{} } + if parsed.ResponseMessageID > 0 { + s.responseMessageID = parsed.ResponseMessageID + } if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { + if parsed.ContentFilter { + s.contentFilter = true + } return streamengine.ParsedDecision{Stop: true} } - contentSeen := false - for _, p := range parsed.Parts { - cleanedText := cleanVisibleOutput(p.Text, s.stripReferenceMarkers) - if cleanedText == "" { - continue - } - if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(cleanedText) { - continue - } - contentSeen = true + accumulated := s.accumulator.Apply(parsed) + for _, p := range accumulated.Parts { if p.Type == "thinking" { - if s.thinkingEnabled { - if cleanedText != "" { - s.thinking.WriteString(cleanedText) - } - } continue } - if cleanedText == "" { + if p.RawText == "" || p.CitationOnly || p.VisibleText == "" { continue } - s.text.WriteString(cleanedText) if s.bufferContent { continue } @@ -146,23 +148,38 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse "index": 0, "content": map[string]any{ "role": "model", - "parts": []map[string]any{{"text": cleanedText}}, + "parts": []map[string]any{{"text": p.VisibleText}}, }, }, }, "modelVersion": s.model, }) } - return streamengine.ParsedDecision{ContentSeen: contentSeen} + return streamengine.ParsedDecision{ContentSeen: accumulated.ContentSeen} } //nolint:unused // retained for native Gemini stream handling path. func (s *geminiStreamRuntime) finalize() { - finalThinking := s.thinking.String() - finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers) + rawText, text, rawThinking, thinking, detectionThinking := s.accumulator.Snapshot() + turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{ + RawText: rawText, + VisibleText: text, + RawThinking: rawThinking, + VisibleThinking: thinking, + DetectionThinking: detectionThinking, + ContentFilter: s.contentFilter, + ResponseMessageID: s.responseMessageID, + }, assistantturn.BuildOptions{ + Model: s.model, + Prompt: s.finalPrompt, + SearchEnabled: s.searchEnabled, + StripReferenceMarkers: s.stripReferenceMarkers, + ToolNames: s.toolNames, + ToolsRaw: s.toolsRaw, + }) if s.bufferContent { - parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames) + parts := buildGeminiPartsFromTurn(turn) s.sendChunk(map[string]any{ "candidates": []map[string]any{ { @@ -190,7 +207,11 @@ func (s *geminiStreamRuntime) finalize() { "finishReason": "STOP", }, }, - "modelVersion": s.model, - "usageMetadata": buildGeminiUsage(s.model, s.finalPrompt, finalThinking, finalText), + "modelVersion": s.model, + "usageMetadata": map[string]any{ + "promptTokenCount": turn.Usage.InputTokens, + "candidatesTokenCount": turn.Usage.OutputTokens, + "totalTokenCount": turn.Usage.TotalTokens, + }, }) } diff --git a/internal/httpapi/openai/chat/chat_stream_runtime.go b/internal/httpapi/openai/chat/chat_stream_runtime.go index ed5034f..14183ae 100644 --- a/internal/httpapi/openai/chat/chat_stream_runtime.go +++ b/internal/httpapi/openai/chat/chat_stream_runtime.go @@ -5,8 +5,10 @@ import ( "net/http" "strings" + "ds2api/internal/assistantturn" openaifmt "ds2api/internal/format/openai" "ds2api/internal/httpapi/openai/shared" + "ds2api/internal/promptcompat" "ds2api/internal/sse" streamengine "ds2api/internal/stream" "ds2api/internal/toolstream" @@ -24,6 +26,7 @@ type chatStreamRuntime struct { refFileTokens int toolNames []string toolsRaw any + toolChoice promptcompat.ToolChoicePolicy thinkingEnabled bool searchEnabled bool @@ -89,6 +92,7 @@ func newChatStreamRuntime( stripReferenceMarkers bool, toolNames []string, toolsRaw any, + toolChoice promptcompat.ToolChoicePolicy, bufferToolContent bool, emitEarlyToolDeltas bool, ) *chatStreamRuntime { @@ -102,6 +106,7 @@ func newChatStreamRuntime( finalPrompt: finalPrompt, toolNames: toolNames, toolsRaw: toolsRaw, + toolChoice: toolChoice, thinkingEnabled: thinkingEnabled, searchEnabled: searchEnabled, stripReferenceMarkers: stripReferenceMarkers, @@ -201,14 +206,33 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool) s.finalErrorCode = "" finalThinking := s.accumulator.Thinking.String() finalToolDetectionThinking := s.accumulator.ToolDetectionThinking.String() - finalText := cleanVisibleOutput(s.accumulator.Text.String(), s.stripReferenceMarkers) - s.finalThinking = finalThinking - s.finalText = finalText - detected := detectAssistantToolCalls(s.accumulator.RawText.String(), finalText, s.accumulator.RawThinking.String(), finalToolDetectionThinking, s.toolNames) - if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted { + finalText := s.accumulator.Text.String() + turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{ + RawText: s.accumulator.RawText.String(), + VisibleText: finalText, + RawThinking: s.accumulator.RawThinking.String(), + VisibleThinking: finalThinking, + DetectionThinking: finalToolDetectionThinking, + ContentFilter: finishReason == "content_filter", + ResponseMessageID: s.responseMessageID, + AlreadyEmittedCalls: s.toolCallsEmitted, + AlreadyEmittedToolRaw: s.toolCallsDoneEmitted, + }, assistantturn.BuildOptions{ + Model: s.model, + Prompt: s.finalPrompt, + RefFileTokens: s.refFileTokens, + SearchEnabled: s.searchEnabled, + StripReferenceMarkers: s.stripReferenceMarkers, + ToolNames: s.toolNames, + ToolsRaw: s.toolsRaw, + ToolChoice: s.toolChoice, + }) + s.finalThinking = turn.Thinking + s.finalText = turn.Text + if len(turn.ToolCalls) > 0 && !s.toolCallsDoneEmitted { finishReason = "tool_calls" s.sendDelta(map[string]any{ - "tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs, s.toolsRaw), + "tool_calls": formatFinalStreamToolCallsWithStableIDs(turn.ToolCalls, s.streamToolCallIDs, s.toolsRaw), }) s.toolCallsEmitted = true s.toolCallsDoneEmitted = true @@ -237,11 +261,14 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool) batch.flush() } - if len(detected.Calls) > 0 || s.toolCallsEmitted { + if len(turn.ToolCalls) > 0 || s.toolCallsEmitted { finishReason = "tool_calls" } - if len(detected.Calls) == 0 && !s.toolCallsEmitted && strings.TrimSpace(finalText) == "" { - status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", finalText, finalThinking) + if len(turn.ToolCalls) == 0 && !s.toolCallsEmitted && strings.TrimSpace(turn.Text) == "" { + status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", turn.Text, turn.Thinking) + if turn.Error != nil { + status, message, code = turn.Error.Status, turn.Error.Message, turn.Error.Code + } if deferEmptyOutput { s.finalErrorStatus = status s.finalErrorMessage = message @@ -251,7 +278,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool) s.sendFailedChunk(status, message, code) return true } - usage := openaifmt.BuildChatUsageForModel(s.model, s.finalPrompt, finalThinking, finalText, s.refFileTokens) + usage := chatUsageFromTurn(turn) s.finalFinishReason = finishReason s.finalUsage = usage s.sendChunk(openaifmt.BuildChatStreamChunk( @@ -265,6 +292,17 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool) return true } +func chatUsageFromTurn(turn assistantturn.Turn) map[string]any { + return map[string]any{ + "prompt_tokens": turn.Usage.InputTokens, + "completion_tokens": turn.Usage.OutputTokens, + "total_tokens": turn.Usage.TotalTokens, + "completion_tokens_details": map[string]any{ + "reasoning_tokens": turn.Usage.ReasoningTokens, + }, + } +} + func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { if !parsed.Parsed { return streamengine.ParsedDecision{} diff --git a/internal/httpapi/openai/chat/chat_stream_runtime_test.go b/internal/httpapi/openai/chat/chat_stream_runtime_test.go index db3026f..022e535 100644 --- a/internal/httpapi/openai/chat/chat_stream_runtime_test.go +++ b/internal/httpapi/openai/chat/chat_stream_runtime_test.go @@ -6,6 +6,8 @@ import ( "strings" "testing" "time" + + "ds2api/internal/promptcompat" ) func TestChatStreamKeepAliveEmitsEmptyChoiceDataFrame(t *testing.T) { @@ -23,6 +25,7 @@ func TestChatStreamKeepAliveEmitsEmptyChoiceDataFrame(t *testing.T) { true, nil, nil, + promptcompat.DefaultToolChoicePolicy(), false, false, ) @@ -51,3 +54,34 @@ func TestChatStreamKeepAliveEmitsEmptyChoiceDataFrame(t *testing.T) { t.Fatalf("expected empty choices heartbeat, got %#v", choices) } } + +func TestChatStreamFinalizeEnforcesRequiredToolChoice(t *testing.T) { + rec := httptest.NewRecorder() + runtime := newChatStreamRuntime( + rec, + http.NewResponseController(rec), + true, + "chatcmpl-test", + time.Now().Unix(), + "deepseek-v4-flash", + "prompt", + false, + false, + true, + []string{"Write"}, + nil, + promptcompat.ToolChoicePolicy{Mode: promptcompat.ToolChoiceRequired}, + true, + false, + ) + + if !runtime.finalize("stop", false) { + t.Fatalf("expected terminal error to be written") + } + if runtime.finalErrorCode != "tool_choice_violation" { + t.Fatalf("expected tool_choice_violation, got %q body=%s", runtime.finalErrorCode, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "tool_choice requires") { + t.Fatalf("expected tool choice error in stream body, got %s", rec.Body.String()) + } +} diff --git a/internal/httpapi/openai/chat/empty_retry_runtime.go b/internal/httpapi/openai/chat/empty_retry_runtime.go index 68a570d..54ce00e 100644 --- a/internal/httpapi/openai/chat/empty_retry_runtime.go +++ b/internal/httpapi/openai/chat/empty_retry_runtime.go @@ -7,10 +7,12 @@ import ( "strings" "time" + "ds2api/internal/assistantturn" "ds2api/internal/auth" "ds2api/internal/config" dsprotocol "ds2api/internal/deepseek/protocol" openaifmt "ds2api/internal/format/openai" + "ds2api/internal/promptcompat" "ds2api/internal/sse" streamengine "ds2api/internal/stream" ) @@ -26,6 +28,7 @@ type chatNonStreamResult struct { body map[string]any finishReason string responseMessageID int + outputError *assistantturn.OutputError } func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) { @@ -86,35 +89,40 @@ func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http. return chatNonStreamResult{}, false } result := sse.CollectStream(resp, thinkingEnabled, true) - stripReferenceMarkers := h.compatStripReferenceMarkers() - finalThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers) - finalText := cleanVisibleOutput(result.Text, stripReferenceMarkers) - if searchEnabled { - finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks) - } - detected := detectAssistantToolCalls(result.Text, finalText, result.Thinking, result.ToolDetectionThinking, toolNames) - respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls, toolsRaw) + turn := assistantturn.BuildTurnFromCollected(result, assistantturn.BuildOptions{ + Model: model, + Prompt: usagePrompt, + SearchEnabled: searchEnabled, + StripReferenceMarkers: h.compatStripReferenceMarkers(), + ToolNames: toolNames, + ToolsRaw: toolsRaw, + }) + respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, turn.Thinking, turn.Text, turn.ToolCalls, toolsRaw) return chatNonStreamResult{ rawThinking: result.Thinking, rawText: result.Text, - thinking: finalThinking, + thinking: turn.Thinking, toolDetectionThinking: result.ToolDetectionThinking, - text: finalText, + text: turn.Text, contentFilter: result.ContentFilter, - detectedCalls: len(detected.Calls), + detectedCalls: len(turn.ToolCalls), body: respBody, finishReason: chatFinishReason(respBody), responseMessageID: result.ResponseMessageID, + outputError: turn.Error, }, true } func (h *Handler) finishChatNonStreamResult(w http.ResponseWriter, result chatNonStreamResult, attempts int, usagePrompt string, refFileTokens int, historySession *chatHistorySession) { - if result.detectedCalls == 0 && shouldWriteUpstreamEmptyOutputError(result.text, result.thinking) { + if result.detectedCalls == 0 && strings.TrimSpace(result.text) == "" { status, message, code := upstreamEmptyOutputDetail(result.contentFilter, result.text, result.thinking) + if result.outputError != nil { + status, message, code = result.outputError.Status, result.outputError.Message, result.outputError.Code + } if historySession != nil { historySession.error(status, message, code, result.thinking, result.text) } - writeUpstreamEmptyOutputError(w, result.text, result.thinking, result.contentFilter) + writeOpenAIErrorWithCode(w, status, message, code) config.Logger.Info("[openai_empty_retry] terminal empty output", "surface", "chat.completions", "stream", false, "retry_attempts", attempts, "success_source", "none", "content_filter", result.contentFilter) return } @@ -147,8 +155,8 @@ func shouldRetryChatNonStream(result chatNonStreamResult, attempts int) bool { strings.TrimSpace(result.thinking) == "" } -func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) { - streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, historySession) +func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, historySession *chatHistorySession) { + streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, historySession) if !ok { return } @@ -190,7 +198,7 @@ func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, } } -func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) { +func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) { if resp.StatusCode != http.StatusOK { defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(resp.Body) @@ -216,6 +224,7 @@ func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Res streamRuntime := newChatStreamRuntime( w, rc, canFlush, completionID, time.Now().Unix(), model, finalPrompt, thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw, + toolChoice, len(toolNames) > 0, h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(), ) streamRuntime.refFileTokens = refFileTokens diff --git a/internal/httpapi/openai/chat/empty_retry_runtime_test.go b/internal/httpapi/openai/chat/empty_retry_runtime_test.go index ff8155f..9cf5d39 100644 --- a/internal/httpapi/openai/chat/empty_retry_runtime_test.go +++ b/internal/httpapi/openai/chat/empty_retry_runtime_test.go @@ -8,6 +8,7 @@ import ( "time" "ds2api/internal/chathistory" + "ds2api/internal/promptcompat" "ds2api/internal/stream" ) @@ -48,6 +49,7 @@ func TestConsumeChatStreamAttemptMarksContextCancelledState(t *testing.T) { true, nil, nil, + promptcompat.DefaultToolChoicePolicy(), false, false, ) diff --git a/internal/httpapi/openai/chat/handler.go b/internal/httpapi/openai/chat/handler.go index bdb8bdf..522fbcb 100644 --- a/internal/httpapi/openai/chat/handler.go +++ b/internal/httpapi/openai/chat/handler.go @@ -80,6 +80,10 @@ func writeOpenAIError(w http.ResponseWriter, status int, message string) { shared.WriteOpenAIError(w, status, message) } +func writeOpenAIErrorWithCode(w http.ResponseWriter, status int, message, code string) { + shared.WriteOpenAIErrorWithCode(w, status, message, code) +} + func openAIErrorType(status int) string { return shared.OpenAIErrorType(status) } diff --git a/internal/httpapi/openai/chat/handler_chat.go b/internal/httpapi/openai/chat/handler_chat.go index 0d960ca..bf1d16e 100644 --- a/internal/httpapi/openai/chat/handler_chat.go +++ b/internal/httpapi/openai/chat/handler_chat.go @@ -9,6 +9,7 @@ import ( "time" "ds2api/internal/auth" + "ds2api/internal/completionruntime" "ds2api/internal/config" dsprotocol "ds2api/internal/deepseek/protocol" openaifmt "ds2api/internal/format/openai" @@ -76,44 +77,40 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { } historySession := startChatHistory(h.ChatHistory, r, a, stdReq) - sessionID, err = h.DS.CreateSession(r.Context(), a, 3) - if err != nil { - if a.UseConfigToken { + if !stdReq.Stream { + result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{ + StripReferenceMarkers: h.compatStripReferenceMarkers(), + RetryEnabled: true, + }) + sessionID = result.SessionID + if outErr != nil { if historySession != nil { - historySession.error(http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.", "error", "", "") + historySession.error(outErr.Status, outErr.Message, outErr.Code, result.Turn.Thinking, result.Turn.Text) } - writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") - } else { - if historySession != nil { - historySession.error(http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.", "error", "", "") - } - writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.") + writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code) + return } + respBody := openaifmt.BuildChatCompletionWithToolCalls(result.SessionID, stdReq.ResponseModel, result.Turn.Prompt, result.Turn.Thinking, result.Turn.Text, result.Turn.ToolCalls, stdReq.ToolsRaw) + respBody["usage"] = chatUsageFromTurn(result.Turn) + finishReason := chatFinishReason(respBody) + if historySession != nil { + historySession.success(http.StatusOK, result.Turn.Thinking, result.Turn.Text, finishReason, chatUsageFromTurn(result.Turn)) + } + writeJSON(w, http.StatusOK, respBody) return } - pow, err := h.DS.GetPow(r.Context(), a, 3) - if err != nil { + + start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{}) + sessionID = start.SessionID + if outErr != nil { if historySession != nil { - historySession.error(http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).", "error", "", "") + historySession.error(outErr.Status, outErr.Message, outErr.Code, "", "") } - writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") - return - } - payload := stdReq.CompletionPayload(sessionID) - resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) - if err != nil { - if historySession != nil { - historySession.error(http.StatusInternalServerError, "Failed to get completion.", "error", "", "") - } - writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") + writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code) return } refFileTokens := stdReq.RefFileTokens - if stdReq.Stream { - h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession) - return - } - h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession) + h.handleStreamWithRetry(w, r, a, start.Response, start.Payload, start.Pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, historySession) } func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) { @@ -234,6 +231,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt stripReferenceMarkers, toolNames, toolsRaw, + promptcompat.DefaultToolChoicePolicy(), bufferToolContent, emitEarlyToolDeltas, ) diff --git a/internal/httpapi/openai/responses/empty_retry_runtime.go b/internal/httpapi/openai/responses/empty_retry_runtime.go index 4eec74f..74546b7 100644 --- a/internal/httpapi/openai/responses/empty_retry_runtime.go +++ b/internal/httpapi/openai/responses/empty_retry_runtime.go @@ -1,7 +1,6 @@ package responses import ( - "context" "io" "net/http" "strings" @@ -10,129 +9,10 @@ import ( "ds2api/internal/auth" "ds2api/internal/config" dsprotocol "ds2api/internal/deepseek/protocol" - openaifmt "ds2api/internal/format/openai" "ds2api/internal/promptcompat" - "ds2api/internal/sse" streamengine "ds2api/internal/stream" - "ds2api/internal/toolcall" ) -type responsesNonStreamResult struct { - rawThinking string - rawText string - thinking string - toolDetectionThinking string - text string - contentFilter bool - parsed toolcall.ToolCallParseResult - body map[string]any - responseMessageID int -} - -func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) { - attempts := 0 - currentResp := resp - usagePrompt := finalPrompt - accumulatedThinking := "" - accumulatedRawThinking := "" - accumulatedToolDetectionThinking := "" - for { - result, ok := h.collectResponsesNonStreamAttempt(w, currentResp, responseID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw) - if !ok { - return - } - accumulatedThinking += sse.TrimContinuationOverlap(accumulatedThinking, result.thinking) - accumulatedRawThinking += sse.TrimContinuationOverlap(accumulatedRawThinking, result.rawThinking) - accumulatedToolDetectionThinking += sse.TrimContinuationOverlap(accumulatedToolDetectionThinking, result.toolDetectionThinking) - result.thinking = accumulatedThinking - result.rawThinking = accumulatedRawThinking - result.toolDetectionThinking = accumulatedToolDetectionThinking - result.parsed = detectAssistantToolCalls(result.rawText, result.text, result.rawThinking, result.toolDetectionThinking, toolNames) - result.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, result.thinking, result.text, result.parsed.Calls, toolsRaw) - if refFileTokens > 0 { - addRefFileTokensToUsage(result.body, refFileTokens) - } - - if !shouldRetryResponsesNonStream(result, attempts) { - h.finishResponsesNonStreamResult(w, result, attempts, owner, responseID, toolChoice, traceID) - return - } - - attempts++ - config.Logger.Info("[openai_empty_retry] attempting synthetic retry", "surface", "responses", "stream", false, "retry_attempt", attempts, "parent_message_id", result.responseMessageID) - retryPow, powErr := h.DS.GetPow(ctx, a, 3) - if powErr != nil { - config.Logger.Warn("[openai_empty_retry] retry PoW fetch failed, falling back to original PoW", "surface", "responses", "stream", false, "retry_attempt", attempts, "error", powErr) - retryPow = pow - } - nextResp, err := h.DS.CallCompletion(ctx, a, clonePayloadForEmptyOutputRetry(payload, result.responseMessageID), retryPow, 3) - if err != nil { - writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") - config.Logger.Warn("[openai_empty_retry] retry request failed", "surface", "responses", "stream", false, "retry_attempt", attempts, "error", err) - return - } - usagePrompt = usagePromptWithEmptyOutputRetry(usagePrompt, attempts) - currentResp = nextResp - } -} - -func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *http.Response, responseID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (responsesNonStreamResult, bool) { - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body))) - return responsesNonStreamResult{}, false - } - result := sse.CollectStream(resp, thinkingEnabled, false) - stripReferenceMarkers := h.compatStripReferenceMarkers() - sanitizedThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers) - sanitizedText := cleanVisibleOutput(result.Text, stripReferenceMarkers) - if searchEnabled { - sanitizedText = replaceCitationMarkersWithLinks(sanitizedText, result.CitationLinks) - } - textParsed := detectAssistantToolCalls(result.Text, sanitizedText, result.Thinking, result.ToolDetectionThinking, toolNames) - responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, sanitizedThinking, sanitizedText, textParsed.Calls, toolsRaw) - return responsesNonStreamResult{ - rawThinking: result.Thinking, - rawText: result.Text, - thinking: sanitizedThinking, - toolDetectionThinking: result.ToolDetectionThinking, - text: sanitizedText, - contentFilter: result.ContentFilter, - parsed: textParsed, - body: responseObj, - responseMessageID: result.ResponseMessageID, - }, true -} - -func (h *Handler) finishResponsesNonStreamResult(w http.ResponseWriter, result responsesNonStreamResult, attempts int, owner, responseID string, toolChoice promptcompat.ToolChoicePolicy, traceID string) { - if len(result.parsed.Calls) == 0 && writeUpstreamEmptyOutputError(w, result.text, result.thinking, result.contentFilter) { - config.Logger.Info("[openai_empty_retry] terminal empty output", "surface", "responses", "stream", false, "retry_attempts", attempts, "success_source", "none", "content_filter", result.contentFilter) - return - } - logResponsesToolPolicyRejection(traceID, toolChoice, result.parsed, "text") - if toolChoice.IsRequired() && len(result.parsed.Calls) == 0 { - writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation") - return - } - h.getResponseStore().put(owner, responseID, result.body) - writeJSON(w, http.StatusOK, result.body) - source := "first_attempt" - if attempts > 0 { - source = "synthetic_retry" - } - config.Logger.Info("[openai_empty_retry] completed", "surface", "responses", "stream", false, "retry_attempts", attempts, "success_source", source) -} - -func shouldRetryResponsesNonStream(result responsesNonStreamResult, attempts int) bool { - return emptyOutputRetryEnabled() && - attempts < emptyOutputRetryMaxAttempts() && - !result.contentFilter && - len(result.parsed.Calls) == 0 && - strings.TrimSpace(result.text) == "" && - strings.TrimSpace(result.thinking) == "" -} - func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) { streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, resp, owner, responseID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, traceID) if !ok { diff --git a/internal/httpapi/openai/responses/responses_handler.go b/internal/httpapi/openai/responses/responses_handler.go index 3fc1561..5ec5efe 100644 --- a/internal/httpapi/openai/responses/responses_handler.go +++ b/internal/httpapi/openai/responses/responses_handler.go @@ -12,6 +12,7 @@ import ( "github.com/google/uuid" "ds2api/internal/auth" + "ds2api/internal/completionruntime" "ds2api/internal/config" dsprotocol "ds2api/internal/deepseek/protocol" openaifmt "ds2api/internal/format/openai" @@ -92,34 +93,31 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { return } - sessionID, err := h.DS.CreateSession(r.Context(), a, 3) - if err != nil { - if a.UseConfigToken { - writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") - } else { - writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.") + responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "") + if !stdReq.Stream { + result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{ + StripReferenceMarkers: h.compatStripReferenceMarkers(), + RetryEnabled: true, + }) + if outErr != nil { + writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code) + return } - return - } - pow, err := h.DS.GetPow(r.Context(), a, 3) - if err != nil { - writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") - return - } - payload := stdReq.CompletionPayload(sessionID) - resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) - if err != nil { - writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") + responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, stdReq.ResponseModel, result.Turn.Prompt, result.Turn.Thinking, result.Turn.Text, result.Turn.ToolCalls, stdReq.ToolsRaw) + responseObj["usage"] = responsesUsageFromTurn(result.Turn) + h.getResponseStore().put(owner, responseID, responseObj) + writeJSON(w, http.StatusOK, responseObj) return } - responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "") - refFileTokens := stdReq.RefFileTokens - if stdReq.Stream { - h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID) + start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{}) + if outErr != nil { + writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code) return } - h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID) + + refFileTokens := stdReq.RefFileTokens + h.handleResponsesStreamWithRetry(w, r, a, start.Response, start.Payload, start.Pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID) } func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) { diff --git a/internal/httpapi/openai/responses/responses_stream_runtime_core.go b/internal/httpapi/openai/responses/responses_stream_runtime_core.go index cfe9d5d..c043dc1 100644 --- a/internal/httpapi/openai/responses/responses_stream_runtime_core.go +++ b/internal/httpapi/openai/responses/responses_stream_runtime_core.go @@ -1,6 +1,7 @@ package responses import ( + "ds2api/internal/assistantturn" "ds2api/internal/toolcall" "net/http" "strings" @@ -159,9 +160,29 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput finalThinking := s.accumulator.Thinking.String() finalToolDetectionThinking := s.accumulator.ToolDetectionThinking.String() - finalText := cleanVisibleOutput(s.accumulator.Text.String(), s.stripReferenceMarkers) - textParsed := detectAssistantToolCalls(s.accumulator.RawText.String(), finalText, s.accumulator.RawThinking.String(), finalToolDetectionThinking, s.toolNames) - detected := textParsed.Calls + finalText := s.accumulator.Text.String() + turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{ + RawText: s.accumulator.RawText.String(), + VisibleText: finalText, + RawThinking: s.accumulator.RawThinking.String(), + VisibleThinking: finalThinking, + DetectionThinking: finalToolDetectionThinking, + ContentFilter: finishReason == "content_filter", + ResponseMessageID: s.responseMessageID, + AlreadyEmittedCalls: s.toolCallsEmitted, + AlreadyEmittedToolRaw: s.toolCallsDoneEmitted, + }, assistantturn.BuildOptions{ + Model: s.model, + Prompt: s.finalPrompt, + RefFileTokens: s.refFileTokens, + SearchEnabled: s.searchEnabled, + StripReferenceMarkers: s.stripReferenceMarkers, + ToolNames: s.toolNames, + ToolsRaw: s.toolsRaw, + ToolChoice: s.toolChoice, + }) + textParsed := turn.ParsedToolCalls + detected := turn.ToolCalls s.logToolPolicyRejections(textParsed) if len(detected) > 0 { @@ -173,12 +194,15 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput s.closeMessageItem() - if s.toolChoice.IsRequired() && len(detected) == 0 { - s.failResponse(http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation") + if turn.Error != nil && turn.Error.Code == "tool_choice_violation" { + s.failResponse(turn.Error.Status, turn.Error.Message, turn.Error.Code) return true } - if len(detected) == 0 && strings.TrimSpace(finalText) == "" { - status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", finalText, finalThinking) + if len(detected) == 0 && strings.TrimSpace(turn.Text) == "" { + status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", turn.Text, turn.Thinking) + if turn.Error != nil { + status, message, code = turn.Error.Status, turn.Error.Message, turn.Error.Code + } if deferEmptyOutput { s.finalErrorStatus = status s.finalErrorMessage = message @@ -190,7 +214,7 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput } s.closeIncompleteFunctionItems() - obj := s.buildCompletedResponseObject(finalThinking, finalText, detected) + obj := s.buildCompletedResponseObject(turn.Thinking, turn.Text, detected) if s.persistResponse != nil { s.persistResponse(obj) } @@ -199,6 +223,14 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput return true } +func responsesUsageFromTurn(turn assistantturn.Turn) map[string]any { + return map[string]any{ + "input_tokens": turn.Usage.InputTokens, + "output_tokens": turn.Usage.OutputTokens, + "total_tokens": turn.Usage.TotalTokens, + } +} + func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed toolcall.ToolCallParseResult) { logRejected := func(parsed toolcall.ToolCallParseResult, channel string) { rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)