From b79a13efd56f43d91cb1f4611a650d9568b70071 Mon Sep 17 00:00:00 2001
From: CJACK
Date: Tue, 7 Apr 2026 01:39:27 +0800
Subject: [PATCH] feat: support explicit prompt token tracking in SSE parsing
 and stream handlers

---
 .../adapter/openai/chat_stream_runtime.go     | 17 ++++--
 internal/adapter/openai/handler_chat.go       | 13 ++--
 .../adapter/openai/handler_toolcall_format.go |  2 +-
 .../openai/responses_stream_runtime_events.go |  2 +-
 internal/deepseek/constants_shared.json       |  1 -
 internal/js/chat-stream/sse_parse_impl.js     | 60 ++++++++++++++-----
 internal/js/chat-stream/token_usage.js        | 10 ++--
 internal/js/chat-stream/vercel_stream_impl.js |  6 +-
 internal/sse/consumer.go                      | 15 +++--
 internal/sse/line.go                          |  9 ++-
 internal/sse/parser.go                        | 34 ++++++++---
 internal/sse/parser_edge_test.go              | 16 +----
 internal/sse/parser_test.go                   | 14 +++++
 13 files changed, 136 insertions(+), 63 deletions(-)

diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go
index a199882..47f483a 100644
--- a/internal/adapter/openai/chat_stream_runtime.go
+++ b/internal/adapter/openai/chat_stream_runtime.go
@@ -37,6 +37,7 @@ type chatStreamRuntime struct {
 	streamToolNames map[int]string
 	thinking        strings.Builder
 	text            strings.Builder
+	promptTokens    int
 	outputTokens    int
 }
 
@@ -170,11 +171,16 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
 		finishReason = "tool_calls"
 	}
 	usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText)
+	if s.promptTokens > 0 {
+		usage["prompt_tokens"] = s.promptTokens
+	}
 	if s.outputTokens > 0 {
 		usage["completion_tokens"] = s.outputTokens
-		if prompt, ok := usage["prompt_tokens"].(int); ok {
-			usage["total_tokens"] = prompt + s.outputTokens
-		}
+	}
+	if s.promptTokens > 0 || s.outputTokens > 0 {
+		p, _ := usage["prompt_tokens"].(int)
+		c, _ := usage["completion_tokens"].(int)
+		usage["total_tokens"] = p + c
 	}
 	s.sendChunk(openaifmt.BuildChatStreamChunk(
 		s.completionID,
@@ -190,6 +196,9 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
 	if !parsed.Parsed {
 		return streamengine.ParsedDecision{}
 	}
+	if parsed.PromptTokens > 0 {
+		s.promptTokens = parsed.PromptTokens
+	}
 	if parsed.OutputTokens > 0 {
 		s.outputTokens = parsed.OutputTokens
 	}
@@ -243,7 +252,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
 			if !s.emitEarlyToolDeltas {
 				continue
 			}
-			filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.streamToolNames)
+			filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.streamToolNames)
 			if len(filtered) == 0 {
 				continue
 			}
diff --git a/internal/adapter/openai/handler_chat.go b/internal/adapter/openai/handler_chat.go
index 95337b6..e28886d 100644
--- a/internal/adapter/openai/handler_chat.go
+++ b/internal/adapter/openai/handler_chat.go
@@ -131,12 +131,17 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
 		return
 	}
 	respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
-	if result.OutputTokens > 0 {
+	if result.PromptTokens > 0 || result.OutputTokens > 0 {
 		if usage, ok := respBody["usage"].(map[string]any); ok {
-			usage["completion_tokens"] = result.OutputTokens
-			if prompt, ok := usage["prompt_tokens"].(int); ok {
-				usage["total_tokens"] = prompt + result.OutputTokens
+			if result.PromptTokens > 0 {
+				usage["prompt_tokens"] = result.PromptTokens
 			}
+			if result.OutputTokens > 0 {
+				usage["completion_tokens"] = result.OutputTokens
+			}
+			p, _ := usage["prompt_tokens"].(int)
+			c, _ := usage["completion_tokens"].(int)
+			usage["total_tokens"] = p + c
 		}
 	}
 	writeJSON(w, http.StatusOK, respBody)
diff --git a/internal/adapter/openai/handler_toolcall_format.go b/internal/adapter/openai/handler_toolcall_format.go
index 44eb4d1..3937610 100644
--- a/internal/adapter/openai/handler_toolcall_format.go
+++ b/internal/adapter/openai/handler_toolcall_format.go
@@ -113,7 +113,7 @@ func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]s
 	return out
 }
 
-func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNames []string, seenNames map[int]string) []toolCallDelta {
+func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, seenNames map[int]string) []toolCallDelta {
 	if len(deltas) == 0 {
 		return nil
 	}
diff --git a/internal/adapter/openai/responses_stream_runtime_events.go b/internal/adapter/openai/responses_stream_runtime_events.go
index 792d0ce..21e15d1 100644
--- a/internal/adapter/openai/responses_stream_runtime_events.go
+++ b/internal/adapter/openai/responses_stream_runtime_events.go
@@ -48,7 +48,7 @@ func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEven
 			if !s.emitEarlyToolDeltas {
 				continue
 			}
-			filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.functionNames)
+			filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.functionNames)
 			if len(filtered) == 0 {
 				continue
 			}
diff --git a/internal/deepseek/constants_shared.json b/internal/deepseek/constants_shared.json
index a71ca02..56950ca 100644
--- a/internal/deepseek/constants_shared.json
+++ b/internal/deepseek/constants_shared.json
@@ -12,7 +12,6 @@
   "skip_contains_patterns": [
     "quasi_status",
    "elapsed_secs",
-    "token_usage",
     "pending_fragment",
     "conversation_mode",
     "fragments/-1/status",
diff --git a/internal/js/chat-stream/sse_parse_impl.js b/internal/js/chat-stream/sse_parse_impl.js
index f24ee6d..10b85f0 100644
--- a/internal/js/chat-stream/sse_parse_impl.js
+++ b/internal/js/chat-stream/sse_parse_impl.js
@@ -20,7 +20,9 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
     };
   }
 
-  const outputTokens = extractAccumulatedTokenUsage(chunk);
+  const usage = extractAccumulatedTokenUsage(chunk);
+  const promptTokens = usage.prompt;
+  const outputTokens = usage.output;
 
   if (Object.prototype.hasOwnProperty.call(chunk, 'error')) {
     return {
@@ -29,7 +31,8 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
       finished: true,
       contentFilter: false,
       errorMessage: formatErrorMessage(chunk.error),
-      outputTokens: 0,
+      promptTokens,
+      outputTokens,
       newType: currentType,
     };
   }
@@ -43,6 +46,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
       finished: true,
       contentFilter: true,
       errorMessage: '',
+      promptTokens,
       outputTokens,
       newType: currentType,
     };
@@ -55,6 +59,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
       finished: false,
      contentFilter: false,
       errorMessage: '',
+      promptTokens,
       outputTokens,
       newType: currentType,
     };
@@ -67,6 +72,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
       finished: true,
       contentFilter: false,
       errorMessage: '',
+      promptTokens,
       outputTokens,
       newType: currentType,
     };
@@ -77,6 +83,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
       finished: false,
       contentFilter: false,
       errorMessage: '',
+      promptTokens,
       outputTokens,
       newType: currentType,
     };
@@ -89,6 +96,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
       finished: false,
       contentFilter: false,
       errorMessage: '',
+      promptTokens,
       outputTokens,
       newType: currentType,
     };
@@ -157,6 +165,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
       finished: true,
       contentFilter: false,
       errorMessage: '',
+      promptTokens,
       outputTokens,
       newType,
     };
@@ -168,6 +177,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
       finished: false,
       contentFilter: false,
       errorMessage: '',
+      promptTokens,
       outputTokens,
       newType,
     };
@@ -182,6 +192,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
       finished: false,
       contentFilter: false,
       errorMessage: '',
+      promptTokens,
       outputTokens,
       newType,
     };
@@ -196,6 +207,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
       finished: true,
       contentFilter: false,
       errorMessage: '',
+      promptTokens,
       outputTokens,
       newType,
     };
@@ -207,6 +219,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
       finished: false,
       contentFilter: false,
       errorMessage: '',
+      promptTokens,
       outputTokens,
       newType,
     };
@@ -242,6 +255,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
       finished: false,
       contentFilter: false,
       errorMessage: '',
+      promptTokens,
       outputTokens,
       newType,
     };
@@ -429,40 +443,54 @@ function hasContentFilterStatusValue(v) {
 }
 
 function extractAccumulatedTokenUsage(chunk) {
-  return findAccumulatedTokenUsage(chunk);
+  const usage = findAccumulatedTokenUsage(chunk);
+  return usage || { prompt: 0, output: 0 };
 }
 
 function findAccumulatedTokenUsage(v) {
   if (Array.isArray(v)) {
     for (const item of v) {
-      const n = findAccumulatedTokenUsage(item);
-      if (n > 0) {
-        return n;
-      }
+      const u = findAccumulatedTokenUsage(item);
+      if (u) return u;
     }
-    return 0;
+    return null;
   }
   if (!v || typeof v !== 'object') {
-    return 0;
+    return null;
   }
   const pathValue = asString(v.p);
   if (pathValue && pathValue.toLowerCase().includes('accumulated_token_usage')) {
     const n = toInt(v.v);
     if (n > 0) {
-      return n;
+      return { prompt: 0, output: n };
+    }
+  }
+  if (pathValue && pathValue.toLowerCase().includes('token_usage')) {
+    const u = v.v;
+    if (u && typeof u === 'object') {
+      const p = toInt(u.prompt_tokens);
+      const c = toInt(u.completion_tokens);
+      if (p > 0 || c > 0) {
+        return { prompt: p, output: c };
+      }
     }
   }
   const direct = toInt(v.accumulated_token_usage);
   if (direct > 0) {
-    return direct;
+    return { prompt: 0, output: direct };
   }
-  for (const value of Object.values(v)) {
-    const n = findAccumulatedTokenUsage(value);
-    if (n > 0) {
-      return n;
+  if (v.token_usage && typeof v.token_usage === 'object') {
+    const p = toInt(v.token_usage.prompt_tokens);
+    const c = toInt(v.token_usage.completion_tokens);
+    if (p > 0 || c > 0) {
+      return { prompt: p, output: c };
     }
   }
-  return 0;
+  for (const value of Object.values(v)) {
+    const u = findAccumulatedTokenUsage(value);
+    if (u) return u;
+  }
+  return null;
 }
 
 function toInt(v) {
diff --git a/internal/js/chat-stream/token_usage.js b/internal/js/chat-stream/token_usage.js
index 0f71c5f..82e12e8 100644
--- a/internal/js/chat-stream/token_usage.js
+++ b/internal/js/chat-stream/token_usage.js
@@ -1,15 +1,17 @@
 'use strict';
 
-function buildUsage(prompt, thinking, output, outputTokens = 0) {
-  const promptTokens = estimateTokens(prompt);
+function buildUsage(prompt, thinking, output, outputTokens = 0, providedPromptTokens = 0) {
   const reasoningTokens = estimateTokens(thinking);
   const completionTokens = estimateTokens(output);
+
+  const finalPromptTokens = Number.isFinite(providedPromptTokens) && providedPromptTokens > 0 ? Math.trunc(providedPromptTokens) : estimateTokens(prompt);
+
   const overriddenCompletionTokens = Number.isFinite(outputTokens) && outputTokens > 0 ? Math.trunc(outputTokens) : 0;
   const finalCompletionTokens = overriddenCompletionTokens > 0 ? overriddenCompletionTokens : reasoningTokens + completionTokens;
   return {
-    prompt_tokens: promptTokens,
+    prompt_tokens: finalPromptTokens,
     completion_tokens: finalCompletionTokens,
-    total_tokens: promptTokens + finalCompletionTokens,
+    total_tokens: finalPromptTokens + finalCompletionTokens,
     completion_tokens_details: {
       reasoning_tokens: reasoningTokens,
     },
diff --git a/internal/js/chat-stream/vercel_stream_impl.js b/internal/js/chat-stream/vercel_stream_impl.js
index e46b530..7c39313 100644
--- a/internal/js/chat-stream/vercel_stream_impl.js
+++ b/internal/js/chat-stream/vercel_stream_impl.js
@@ -125,6 +125,7 @@ async function handleVercelStream(req, res, rawBody, payload) {
   let currentType = thinkingEnabled ? 'thinking' : 'text';
   let thinkingText = '';
   let outputText = '';
+  let promptTokens = 0;
   let outputTokens = 0;
   const toolSieveEnabled = toolPolicy.toolSieveEnabled;
   const toolSieveState = createToolSieveState();
@@ -178,7 +179,7 @@ async function handleVercelStream(req, res, rawBody, payload) {
       created,
       model,
       choices: [{ delta: {}, index: 0, finish_reason: reason }],
-      usage: buildUsage(finalPrompt, thinkingText, outputText, outputTokens),
+      usage: buildUsage(finalPrompt, thinkingText, outputText, outputTokens, promptTokens),
     });
     if (!res.writableEnded && !res.destroyed) {
       res.write('data: [DONE]\n\n');
@@ -227,6 +228,9 @@ async function handleVercelStream(req, res, rawBody, payload) {
       if (!parsed.parsed) {
         continue;
       }
+      if (parsed.promptTokens > 0) {
+        promptTokens = parsed.promptTokens;
+      }
       if (parsed.outputTokens > 0) {
         outputTokens = parsed.outputTokens;
       }
diff --git a/internal/sse/consumer.go b/internal/sse/consumer.go
index 141bd93..f11e942 100644
--- a/internal/sse/consumer.go
+++ b/internal/sse/consumer.go
@@ -12,6 +12,7 @@ import (
 type CollectResult struct {
 	Text          string
 	Thinking      string
+	PromptTokens  int
 	OutputTokens  int
 	ContentFilter bool
 }
@@ -28,6 +29,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
 	}
 	text := strings.Builder{}
 	thinking := strings.Builder{}
+	promptTokens := 0
 	outputTokens := 0
 	contentFilter := false
 	currentType := "text"
@@ -40,18 +42,18 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
 		if !result.Parsed {
 			return true
 		}
+		if result.PromptTokens > 0 {
+			promptTokens = result.PromptTokens
+		}
+		if result.OutputTokens > 0 {
+			outputTokens = result.OutputTokens
+		}
 		if result.Stop {
 			if result.ContentFilter {
 				contentFilter = true
 			}
-			if result.OutputTokens > 0 {
-				outputTokens = result.OutputTokens
-			}
 			return false
 		}
-		if result.OutputTokens > 0 {
-			outputTokens = result.OutputTokens
-		}
 		for _, p := range result.Parts {
 			if p.Type == "thinking" {
 				trimmed := TrimContinuationOverlap(thinking.String(), p.Text)
@@ -66,6 +68,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
 	return CollectResult{
 		Text:          text.String(),
 		Thinking:      thinking.String(),
+		PromptTokens:  promptTokens,
 		OutputTokens:  outputTokens,
 		ContentFilter: contentFilter,
 	}
diff --git a/internal/sse/line.go b/internal/sse/line.go
index d55f9e5..a63563b 100644
--- a/internal/sse/line.go
+++ b/internal/sse/line.go
@@ -10,6 +10,7 @@ type LineResult struct {
 	ErrorMessage  string
 	Parts         []ContentPart
 	NextType      string
+	PromptTokens  int
 	OutputTokens  int
 }
 
@@ -20,9 +21,9 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
 	if !parsed {
 		return LineResult{NextType: currentType}
 	}
-	outputTokens := extractAccumulatedTokenUsage(chunk)
+	promptTokens, outputTokens := extractAccumulatedTokenUsage(chunk)
 	if done {
-		return LineResult{Parsed: true, Stop: true, NextType: currentType, OutputTokens: outputTokens}
+		return LineResult{Parsed: true, Stop: true, NextType: currentType, PromptTokens: promptTokens, OutputTokens: outputTokens}
 	}
 	if errObj, hasErr := chunk["error"]; hasErr {
 		return LineResult{
@@ -30,6 +31,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
 			Stop:         true,
 			ErrorMessage: fmt.Sprintf("%v", errObj),
 			NextType:     currentType,
+			PromptTokens: promptTokens,
 			OutputTokens: outputTokens,
 		}
 	}
@@ -39,6 +41,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
 			Stop:          true,
 			ContentFilter: true,
 			NextType:      currentType,
+			PromptTokens:  promptTokens,
 			OutputTokens:  outputTokens,
 		}
 	}
@@ -48,6 +51,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
 			Stop:          true,
 			ContentFilter: true,
 			NextType:      currentType,
+			PromptTokens:  promptTokens,
 			OutputTokens:  outputTokens,
 		}
 	}
@@ -58,6 +62,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
 		Stop:          finished,
 		Parts:         parts,
 		NextType:      nextType,
+		PromptTokens:  promptTokens,
 		OutputTokens:  outputTokens,
 	}
 }
diff --git a/internal/sse/parser.go b/internal/sse/parser.go
index eee46f9..051619e 100644
--- a/internal/sse/parser.go
+++ b/internal/sse/parser.go
@@ -364,34 +364,50 @@ func hasContentFilterStatusValue(v any) bool {
 	return false
 }
 
-func extractAccumulatedTokenUsage(chunk map[string]any) int {
+func extractAccumulatedTokenUsage(chunk map[string]any) (int, int) {
 	return findAccumulatedTokenUsage(chunk)
 }
 
-func findAccumulatedTokenUsage(v any) int {
+func findAccumulatedTokenUsage(v any) (int, int) {
 	switch x := v.(type) {
 	case map[string]any:
 		if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "accumulated_token_usage") {
 			if n, ok := toInt(x["v"]); ok && n > 0 {
-				return n
+				return 0, n
+			}
+		}
+		if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "token_usage") {
+			if m, ok := x["v"].(map[string]any); ok {
+				p, _ := toInt(m["prompt_tokens"])
+				c, _ := toInt(m["completion_tokens"])
+				if p > 0 || c > 0 {
+					return p, c
+				}
 			}
 		}
 		if n, ok := toInt(x["accumulated_token_usage"]); ok && n > 0 {
-			return n
+			return 0, n
+		}
+		if usage, ok := x["token_usage"].(map[string]any); ok {
+			p, _ := toInt(usage["prompt_tokens"])
+			c, _ := toInt(usage["completion_tokens"])
+			if p > 0 || c > 0 {
+				return p, c
+			}
 		}
 		for _, vv := range x {
-			if n := findAccumulatedTokenUsage(vv); n > 0 {
-				return n
+			if p, c := findAccumulatedTokenUsage(vv); p > 0 || c > 0 {
+				return p, c
 			}
 		}
 	case []any:
 		for _, item := range x {
-			if n := findAccumulatedTokenUsage(item); n > 0 {
-				return n
+			if p, c := findAccumulatedTokenUsage(item); p > 0 || c > 0 {
+				return p, c
 			}
 		}
 	}
-	return 0
+	return 0, 0
 }
 
 func toInt(v any) (int, bool) {
diff --git a/internal/sse/parser_edge_test.go b/internal/sse/parser_edge_test.go
index ba1c723..f0e7f9a 100644
--- a/internal/sse/parser_edge_test.go
+++ b/internal/sse/parser_edge_test.go
@@ -50,18 +50,6 @@ func TestShouldSkipPathQuasiStatus(t *testing.T) {
 	}
 }
 
-func TestShouldSkipPathElapsedSecs(t *testing.T) {
-	if !shouldSkipPath("response/elapsed_secs") {
-		t.Fatal("expected skip for elapsed_secs path")
-	}
-}
-
-func TestShouldSkipPathTokenUsage(t *testing.T) {
-	if !shouldSkipPath("response/token_usage") {
-		t.Fatal("expected skip for token_usage path")
-	}
-}
-
 func TestShouldSkipPathPendingFragment(t *testing.T) {
 	if !shouldSkipPath("response/pending_fragment") {
 		t.Fatal("expected skip for pending_fragment path")
@@ -127,7 +115,7 @@ func TestParseSSEChunkForContentNoVField(t *testing.T) {
 
 func TestParseSSEChunkForContentSkippedPath(t *testing.T) {
 	parts, finished, nextType := ParseSSEChunkForContent(map[string]any{
-		"p": "response/token_usage",
+		"p": "response/quasi_status",
 		"v": "some data",
 	}, false, "text")
 	if finished || len(parts) > 0 {
@@ -498,7 +486,7 @@ func TestExtractContentRecursiveFinishedStatus(t *testing.T) {
 
 func TestExtractContentRecursiveSkipsPath(t *testing.T) {
 	items := []any{
-		map[string]any{"p": "token_usage", "v": "data"},
+		map[string]any{"p": "quasi_status", "v": "data"},
 	}
 	parts, finished := extractContentRecursive(items, "text")
 	if finished {
diff --git a/internal/sse/parser_test.go b/internal/sse/parser_test.go
index b036f57..89c5356 100644
--- a/internal/sse/parser_test.go
+++ b/internal/sse/parser_test.go
@@ -19,6 +19,20 @@ func TestParseDeepSeekSSELineDone(t *testing.T) {
 	}
 }
 
+func TestExtractTokenUsage(t *testing.T) {
+	chunk := map[string]any{
+		"p": "response/token_usage",
+		"v": map[string]any{
+			"prompt_tokens":     123,
+			"completion_tokens": 456,
+		},
+	}
+	p, c := extractAccumulatedTokenUsage(chunk)
+	if p != 123 || c != 456 {
+		t.Fatalf("expected 123/456, got %d/%d", p, c)
+	}
+}
+
 func TestParseSSEChunkForContentSimple(t *testing.T) {
 	parts, finished, _ := ParseSSEChunkForContent(map[string]any{"v": "hello"}, false, "text") 
	if finished {
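
--

The patch keeps the legacy `accumulated_token_usage` integer path and adds a structured `token_usage` object alongside it, so the extractor now recognizes two wire shapes. Below is a minimal, self-contained sketch of that two-shape lookup, condensed from the patched `findAccumulatedTokenUsage` above; the `findTokenUsage` name and the sample payloads are illustrative, not part of the patch or captured traffic:

    package main

    import (
        "fmt"
        "strings"
    )

    // toInt coerces JSON-decoded numbers (float64) as well as plain ints.
    func toInt(v any) (int, bool) {
        switch n := v.(type) {
        case float64:
            return int(n), true
        case int:
            return n, true
        }
        return 0, false
    }

    // findTokenUsage walks a decoded SSE chunk and returns
    // (promptTokens, completionTokens). A structured token_usage object yields
    // both counts; a bare accumulated_token_usage integer yields only an output
    // count, matching the (0, n) returns in the patched parser.
    func findTokenUsage(v any) (int, int) {
        switch x := v.(type) {
        case map[string]any:
            // "accumulated_token_usage" contains "token_usage", so one substring
            // check covers both path variants here.
            if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "token_usage") {
                if m, ok := x["v"].(map[string]any); ok {
                    pt, _ := toInt(m["prompt_tokens"])
                    ct, _ := toInt(m["completion_tokens"])
                    if pt > 0 || ct > 0 {
                        return pt, ct
                    }
                }
                if n, ok := toInt(x["v"]); ok && n > 0 {
                    return 0, n // accumulated form: a single output-token count
                }
            }
            for _, vv := range x {
                if pt, ct := findTokenUsage(vv); pt > 0 || ct > 0 {
                    return pt, ct
                }
            }
        case []any:
            for _, item := range x {
                if pt, ct := findTokenUsage(item); pt > 0 || ct > 0 {
                    return pt, ct
                }
            }
        }
        return 0, 0
    }

    func main() {
        // Structured form: explicit prompt and completion counts.
        structured := map[string]any{
            "p": "response/token_usage",
            "v": map[string]any{"prompt_tokens": 123.0, "completion_tokens": 456.0},
        }
        // Legacy form: one accumulated output-token integer.
        legacy := map[string]any{"p": "response/accumulated_token_usage", "v": 456.0}

        fmt.Println(findTokenUsage(structured)) // 123 456
        fmt.Println(findTokenUsage(legacy))     // 0 456
    }

When neither shape appears in the stream, the callers fall back to estimates (estimateTokens in token_usage.js, openaifmt.BuildChatUsage on the Go side), so provided counts override estimated ones but never the reverse.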