diff --git a/internal/adapter/claude/handler_messages.go b/internal/adapter/claude/handler_messages.go index 5b553dc..ced0dc1 100644 --- a/internal/adapter/claude/handler_messages.go +++ b/internal/adapter/claude/handler_messages.go @@ -90,6 +90,11 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { result.Text, stdReq.ToolNames, ) + if result.OutputTokens > 0 { + if usage, ok := respBody["usage"].(map[string]any); ok { + usage["output_tokens"] = result.OutputTokens + } + } writeJSON(w, http.StatusOK, respBody) } diff --git a/internal/adapter/claude/stream_runtime_core.go b/internal/adapter/claude/stream_runtime_core.go index a3dd649..e5be865 100644 --- a/internal/adapter/claude/stream_runtime_core.go +++ b/internal/adapter/claude/stream_runtime_core.go @@ -26,6 +26,7 @@ type claudeStreamRuntime struct { messageID string thinking strings.Builder text strings.Builder + outputTokens int nextBlockIndex int thinkingBlockOpen bool @@ -66,6 +67,9 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse if !parsed.Parsed { return streamengine.ParsedDecision{} } + if parsed.OutputTokens > 0 { + s.outputTokens = parsed.OutputTokens + } if parsed.ErrorMessage != "" { s.upstreamErr = parsed.ErrorMessage return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")} diff --git a/internal/adapter/claude/stream_runtime_finalize.go b/internal/adapter/claude/stream_runtime_finalize.go index 0aff357..6a020ef 100644 --- a/internal/adapter/claude/stream_runtime_finalize.go +++ b/internal/adapter/claude/stream_runtime_finalize.go @@ -108,6 +108,9 @@ func (s *claudeStreamRuntime) finalize(stopReason string) { } outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText) + if s.outputTokens > 0 { + outputTokens = s.outputTokens + } s.send("message_delta", map[string]any{ "type": "message_delta", "delta": map[string]any{ diff --git a/internal/adapter/gemini/handler_generate.go b/internal/adapter/gemini/handler_generate.go index a6d85b5..d2f33f1 100644 --- a/internal/adapter/gemini/handler_generate.go +++ b/internal/adapter/gemini/handler_generate.go @@ -174,12 +174,12 @@ func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *ht } result := sse.CollectStream(resp, thinkingEnabled, true) - writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames)) + writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames, result.OutputTokens)) } -func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { +func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string, outputTokens int) map[string]any { parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames) - usage := buildGeminiUsage(finalPrompt, finalThinking, finalText) + usage := buildGeminiUsage(finalPrompt, finalThinking, finalText, outputTokens) return map[string]any{ "candidates": []map[string]any{ { @@ -196,10 +196,14 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final } } -func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any { +func buildGeminiUsage(finalPrompt, finalThinking, finalText string, outputTokens int) map[string]any { promptTokens := util.EstimateTokens(finalPrompt) reasoningTokens := util.EstimateTokens(finalThinking) completionTokens := util.EstimateTokens(finalText) + if outputTokens > 0 { + completionTokens = outputTokens + reasoningTokens = 0 + } return map[string]any{ "promptTokenCount": promptTokens, "candidatesTokenCount": reasoningTokens + completionTokens, diff --git a/internal/adapter/gemini/handler_stream_runtime.go b/internal/adapter/gemini/handler_stream_runtime.go index c6a6bcd..1fd9021 100644 --- a/internal/adapter/gemini/handler_stream_runtime.go +++ b/internal/adapter/gemini/handler_stream_runtime.go @@ -64,6 +64,7 @@ type geminiStreamRuntime struct { thinking strings.Builder text strings.Builder + outputTokens int } func newGeminiStreamRuntime( @@ -103,6 +104,9 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse if !parsed.Parsed { return streamengine.ParsedDecision{} } + if parsed.OutputTokens > 0 { + s.outputTokens = parsed.OutputTokens + } if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { return streamengine.ParsedDecision{Stop: true} } @@ -176,6 +180,6 @@ func (s *geminiStreamRuntime) finalize() { }, }, "modelVersion": s.model, - "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText), + "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText, s.outputTokens), }) } diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go index 3a25f79..563b0f2 100644 --- a/internal/adapter/openai/chat_stream_runtime.go +++ b/internal/adapter/openai/chat_stream_runtime.go @@ -36,6 +36,7 @@ type chatStreamRuntime struct { streamToolNames map[int]string thinking strings.Builder text strings.Builder + outputTokens int } func newChatStreamRuntime( @@ -165,12 +166,19 @@ func (s *chatStreamRuntime) finalize(finishReason string) { if len(detected.Calls) > 0 || s.toolCallsEmitted { finishReason = "tool_calls" } + usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText) + if s.outputTokens > 0 { + usage["completion_tokens"] = s.outputTokens + if prompt, ok := usage["prompt_tokens"].(int); ok { + usage["total_tokens"] = prompt + s.outputTokens + } + } s.sendChunk(openaifmt.BuildChatStreamChunk( s.completionID, s.created, s.model, []map[string]any{openaifmt.BuildChatStreamFinishChoice(0, finishReason)}, - openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText), + usage, )) s.sendDone() } @@ -179,7 +187,13 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD if !parsed.Parsed { return streamengine.ParsedDecision{} } - if parsed.ContentFilter || parsed.ErrorMessage != "" { + if parsed.OutputTokens > 0 { + s.outputTokens = parsed.OutputTokens + } + if parsed.ContentFilter { + return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReasonHandlerRequested} + } + if parsed.ErrorMessage != "" { return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("content_filter")} } if parsed.Stop { diff --git a/internal/adapter/openai/handler_chat.go b/internal/adapter/openai/handler_chat.go index 8847097..1b2fec4 100644 --- a/internal/adapter/openai/handler_chat.go +++ b/internal/adapter/openai/handler_chat.go @@ -107,6 +107,14 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re finalThinking := result.Thinking finalText := sanitizeLeakedOutput(result.Text) respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames) + if result.OutputTokens > 0 { + if usage, ok := respBody["usage"].(map[string]any); ok { + usage["completion_tokens"] = result.OutputTokens + if prompt, ok := usage["prompt_tokens"].(int); ok { + usage["total_tokens"] = prompt + result.OutputTokens + } + } + } writeJSON(w, http.StatusOK, respBody) } diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index a7d0828..ed2c715 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -124,6 +124,14 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res } responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, sanitizedText, toolNames) + if result.OutputTokens > 0 { + if usage, ok := responseObj["usage"].(map[string]any); ok { + usage["output_tokens"] = result.OutputTokens + if input, ok := usage["input_tokens"].(int); ok { + usage["total_tokens"] = input + result.OutputTokens + } + } + } h.getResponseStore().put(owner, responseID, responseObj) writeJSON(w, http.StatusOK, responseObj) } diff --git a/internal/adapter/openai/responses_stream_runtime_core.go b/internal/adapter/openai/responses_stream_runtime_core.go index 460ce2a..eaae51b 100644 --- a/internal/adapter/openai/responses_stream_runtime_core.go +++ b/internal/adapter/openai/responses_stream_runtime_core.go @@ -49,6 +49,7 @@ type responsesStreamRuntime struct { messagePartAdded bool sequence int failed bool + outputTokens int persistResponse func(obj map[string]any) } @@ -144,6 +145,14 @@ func (s *responsesStreamRuntime) finalize() { s.closeIncompleteFunctionItems() obj := s.buildCompletedResponseObject(finalThinking, finalText, detected) + if s.outputTokens > 0 { + if usage, ok := obj["usage"].(map[string]any); ok { + usage["output_tokens"] = s.outputTokens + if input, ok := usage["input_tokens"].(int); ok { + usage["total_tokens"] = input + s.outputTokens + } + } + } if s.persistResponse != nil { s.persistResponse(obj) } @@ -172,6 +181,9 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa if !parsed.Parsed { return streamengine.ParsedDecision{} } + if parsed.OutputTokens > 0 { + s.outputTokens = parsed.OutputTokens + } if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { return streamengine.ParsedDecision{Stop: true} } diff --git a/internal/adapter/openai/stream_status_test.go b/internal/adapter/openai/stream_status_test.go index c76d881..2a3584b 100644 --- a/internal/adapter/openai/stream_status_test.go +++ b/internal/adapter/openai/stream_status_test.go @@ -183,3 +183,53 @@ func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) { t.Fatalf("expected function_call output item, got %#v", output) } } + +func TestChatCompletionsStreamContentFilterStopsNormallyWithoutLeak(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse( + `data: {"p":"response/content","v":"合法前缀"}`, + `data: {"p":"response/status","v":"CONTENT_FILTER","accumulated_token_usage":77}`, + `data: {"p":"response/content","v":"CONTENT_FILTER你好,这个问题我暂时无法回答,让我们换个话题再聊聊吧。"}`, + )}, + } + r := chi.NewRouter() + r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}` + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 || statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200, got %#v", statuses) + } + if strings.Contains(rec.Body.String(), "这个问题我暂时无法回答") { + t.Fatalf("expected leaked content-filter suffix to be hidden, body=%s", rec.Body.String()) + } + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if len(frames) == 0 { + t.Fatalf("expected at least one json frame, body=%s", rec.Body.String()) + } + last := frames[len(frames)-1] + choices, _ := last["choices"].([]any) + if len(choices) != 1 { + t.Fatalf("expected one choice in final frame, got %#v", last) + } + choice, _ := choices[0].(map[string]any) + if choice["finish_reason"] != "stop" { + t.Fatalf("expected finish_reason=stop for content-filter upstream stop, got %#v", choice["finish_reason"]) + } +} diff --git a/internal/sse/consumer.go b/internal/sse/consumer.go index 9e0e180..c4a1e00 100644 --- a/internal/sse/consumer.go +++ b/internal/sse/consumer.go @@ -10,8 +10,9 @@ import ( // CollectResult holds the aggregated text and thinking content from a // DeepSeek SSE stream, consumed to completion (non-streaming use case). type CollectResult struct { - Text string - Thinking string + Text string + Thinking string + OutputTokens int } // CollectStream fully consumes a DeepSeek SSE response and separates @@ -26,6 +27,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co } text := strings.Builder{} thinking := strings.Builder{} + outputTokens := 0 currentType := "text" if thinkingEnabled { currentType = "thinking" @@ -37,8 +39,14 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co return true } if result.Stop { + if result.OutputTokens > 0 { + outputTokens = result.OutputTokens + } return false } + if result.OutputTokens > 0 { + outputTokens = result.OutputTokens + } for _, p := range result.Parts { if p.Type == "thinking" { thinking.WriteString(p.Text) @@ -48,5 +56,5 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co } return true }) - return CollectResult{Text: text.String(), Thinking: thinking.String()} + return CollectResult{Text: text.String(), Thinking: thinking.String(), OutputTokens: outputTokens} } diff --git a/internal/sse/consumer_edge_test.go b/internal/sse/consumer_edge_test.go index 8f78f01..54f841b 100644 --- a/internal/sse/consumer_edge_test.go +++ b/internal/sse/consumer_edge_test.go @@ -138,3 +138,15 @@ func TestCollectStreamStatusFinished(t *testing.T) { t.Fatalf("expected 'Hello', got %q", result.Text) } } + +func TestCollectStreamStopsOnContentFilterStatus(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"safe\"}\n" + + "data: {\"p\":\"response/status\",\"v\":\"CONTENT_FILTER\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"blocked\"}\n", + ) + result := CollectStream(resp, false, false) + if result.Text != "safe" { + t.Fatalf("expected stream to stop before blocked tail, got %q", result.Text) + } +} diff --git a/internal/sse/line.go b/internal/sse/line.go index e63f378..1d9ddae 100644 --- a/internal/sse/line.go +++ b/internal/sse/line.go @@ -10,6 +10,7 @@ type LineResult struct { ErrorMessage string Parts []ContentPart NextType string + OutputTokens int } // ParseDeepSeekContentLine centralizes one-line DeepSeek SSE parsing for both @@ -35,8 +36,17 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri Parsed: true, Stop: true, ContentFilter: true, - ErrorMessage: "content filtered by upstream", NextType: currentType, + OutputTokens: extractAccumulatedTokenUsage(chunk), + } + } + if hasContentFilterStatus(chunk) { + return LineResult{ + Parsed: true, + Stop: true, + ContentFilter: true, + NextType: currentType, + OutputTokens: extractAccumulatedTokenUsage(chunk), } } parts, finished, nextType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType) @@ -46,5 +56,6 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri Stop: finished, Parts: parts, NextType: nextType, + OutputTokens: extractAccumulatedTokenUsage(chunk), } } diff --git a/internal/sse/line_edge_test.go b/internal/sse/line_edge_test.go index 2ae53a6..4d507fc 100644 --- a/internal/sse/line_edge_test.go +++ b/internal/sse/line_edge_test.go @@ -40,8 +40,8 @@ func TestParseDeepSeekContentLineContentFilterMessage(t *testing.T) { if !res.ContentFilter { t.Fatal("expected content filter flag") } - if res.ErrorMessage == "" { - t.Fatal("expected error message on content filter") + if res.ErrorMessage != "" { + t.Fatalf("expected empty error message on content filter, got %q", res.ErrorMessage) } } diff --git a/internal/sse/line_test.go b/internal/sse/line_test.go index 4e1d22a..7f2baa6 100644 --- a/internal/sse/line_test.go +++ b/internal/sse/line_test.go @@ -26,6 +26,33 @@ func TestParseDeepSeekContentLineContentFilter(t *testing.T) { } } +func TestParseDeepSeekContentLineContentFilterCodeIncludesOutputTokens(t *testing.T) { + res := ParseDeepSeekContentLine( + []byte(`data: {"code":"content_filter","accumulated_token_usage":99}`), + false, "text", + ) + if !res.Parsed || !res.Stop || !res.ContentFilter { + t.Fatalf("expected content-filter stop result: %#v", res) + } + if res.OutputTokens != 99 { + t.Fatalf("expected output token usage 99, got %d", res.OutputTokens) + } +} + +func TestParseDeepSeekContentLineContentFilterStatus(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/status","v":"CONTENT_FILTER"}`), false, "text") + if !res.Parsed || !res.Stop || !res.ContentFilter { + t.Fatalf("expected status-based content-filter stop result: %#v", res) + } +} + +func TestParseDeepSeekContentLineCapturesAccumulatedTokenUsage(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"p":"response","o":"BATCH","v":[{"p":"accumulated_token_usage","v":1383},{"p":"quasi_status","v":"FINISHED"}]}`), false, "text") + if res.OutputTokens != 1383 { + t.Fatalf("expected output token usage 1383, got %d", res.OutputTokens) + } +} + func TestParseDeepSeekContentLineContent(t *testing.T) { res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"hi"}`), false, "text") if !res.Parsed || res.Stop { @@ -65,3 +92,13 @@ func TestParseDeepSeekContentLineTrimsFromContentFilterKeyword(t *testing.T) { t.Fatalf("unexpected parts after filter: %#v", res.Parts) } } + +func TestParseDeepSeekContentLineContentTextEqualContentFilterDoesNotStop(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"content_filter"}`), false, "text") + if !res.Parsed { + t.Fatalf("expected parsed result: %#v", res) + } + if res.Stop || res.ContentFilter { + t.Fatalf("did not expect content-filter stop for content text: %#v", res) + } +} diff --git a/internal/sse/parser.go b/internal/sse/parser.go index c20bc79..1074a34 100644 --- a/internal/sse/parser.go +++ b/internal/sse/parser.go @@ -3,6 +3,7 @@ package sse import ( "bytes" "encoding/json" + "math" "strings" "ds2api/internal/deepseek" @@ -287,3 +288,90 @@ func extractContentRecursive(items []any, defaultType string) ([]ContentPart, bo func IsCitation(text string) bool { return bytes.HasPrefix([]byte(strings.TrimSpace(text)), []byte("[citation:")) } + +func hasContentFilterStatus(chunk map[string]any) bool { + if code, _ := chunk["code"].(string); strings.EqualFold(strings.TrimSpace(code), "content_filter") { + return true + } + return hasContentFilterStatusValue(chunk) +} + +func hasContentFilterStatusValue(v any) bool { + switch x := v.(type) { + case []any: + for _, item := range x { + if hasContentFilterStatusValue(item) { + return true + } + } + case map[string]any: + if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "status") { + if s, _ := x["v"].(string); strings.EqualFold(strings.TrimSpace(s), "content_filter") { + return true + } + } + if code, _ := x["code"].(string); strings.EqualFold(strings.TrimSpace(code), "content_filter") { + return true + } + for _, vv := range x { + if hasContentFilterStatusValue(vv) { + return true + } + } + } + return false +} + +func extractAccumulatedTokenUsage(chunk map[string]any) int { + return findAccumulatedTokenUsage(chunk) +} + +func findAccumulatedTokenUsage(v any) int { + switch x := v.(type) { + case map[string]any: + if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "accumulated_token_usage") { + if n, ok := toInt(x["v"]); ok && n > 0 { + return n + } + } + if n, ok := toInt(x["accumulated_token_usage"]); ok && n > 0 { + return n + } + for _, vv := range x { + if n := findAccumulatedTokenUsage(vv); n > 0 { + return n + } + } + case []any: + for _, item := range x { + if n := findAccumulatedTokenUsage(item); n > 0 { + return n + } + } + } + return 0 +} + +func toInt(v any) (int, bool) { + switch x := v.(type) { + case int: + return x, true + case int32: + return int(x), true + case int64: + return int(x), true + case float64: + if math.IsNaN(x) || math.IsInf(x, 0) { + return 0, false + } + return int(x), true + case json.Number: + i, err := x.Int64() + if err != nil { + return 0, false + } + return int(i), true + default: + return 0, false + } +} diff --git a/internal/translatorcliproxy/stream_writer.go b/internal/translatorcliproxy/stream_writer.go index b1285b1..07c4bcb 100644 --- a/internal/translatorcliproxy/stream_writer.go +++ b/internal/translatorcliproxy/stream_writer.go @@ -62,6 +62,18 @@ func (w *OpenAIStreamTranslatorWriter) Write(p []byte) (int, error) { if len(trimmed) == 0 { continue } + if bytes.HasPrefix(trimmed, []byte(":")) { + if _, err := w.dst.Write(trimmed); err != nil { + return len(p), err + } + if _, err := w.dst.Write([]byte("\n\n")); err != nil { + return len(p), err + } + if f, ok := w.dst.(http.Flusher); ok { + f.Flush() + } + continue + } if !bytes.HasPrefix(trimmed, []byte("data:")) { continue } diff --git a/internal/translatorcliproxy/stream_writer_test.go b/internal/translatorcliproxy/stream_writer_test.go index 31a4aa3..979d36e 100644 --- a/internal/translatorcliproxy/stream_writer_test.go +++ b/internal/translatorcliproxy/stream_writer_test.go @@ -42,3 +42,16 @@ func TestOpenAIStreamTranslatorWriterGemini(t *testing.T) { t.Fatalf("expected gemini stream payload, got: %s", body) } } + +func TestOpenAIStreamTranslatorWriterPreservesKeepAliveComment(t *testing.T) { + rec := httptest.NewRecorder() + w := NewOpenAIStreamTranslatorWriter(rec, sdktranslator.FormatGemini, "gemini-2.5-pro", []byte(`{}`), []byte(`{}`)) + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(200) + _, _ = w.Write([]byte(": keep-alive\n\n")) + + body := rec.Body.String() + if !strings.Contains(body, ": keep-alive\n\n") { + t.Fatalf("expected keep-alive comment passthrough, got %q", body) + } +}