diff --git a/API.en.md b/API.en.md index 2b83245..0238a42 100644 --- a/API.en.md +++ b/API.en.md @@ -384,6 +384,7 @@ Business auth required. Returns OpenAI-compatible embeddings shape. ## Claude-Compatible API Besides `/anthropic/v1/*`, DS2API also supports shortcut paths: `/v1/messages`, `/messages`, `/v1/messages/count_tokens`, `/messages/count_tokens`. +Implementation-wise this path is unified on the OpenAI Chat Completions parse-and-translate pipeline to avoid maintaining divergent parsing chains. ### `GET /anthropic/v1/models` @@ -518,6 +519,7 @@ Supported paths: - `/v1/models/{model}:streamGenerateContent` (compat path) Authentication is the same as other business routes (`Authorization: Bearer ` or `x-api-key`). +Implementation-wise this path is unified on the OpenAI Chat Completions parse-and-translate pipeline to avoid maintaining divergent parsing chains. ### `POST /v1beta/models/{model}:generateContent` diff --git a/API.md b/API.md index 1caa984..d2eb1f0 100644 --- a/API.md +++ b/API.md @@ -390,6 +390,7 @@ data: [DONE] ## Claude 兼容接口 除标准路径 `/anthropic/v1/*` 外,还支持快捷路径 `/v1/messages`、`/messages`、`/v1/messages/count_tokens`、`/messages/count_tokens`。 +实现上统一走 OpenAI Chat Completions 解析与回译链路,避免多套解析逻辑分叉维护。 ### `GET /anthropic/v1/models` @@ -524,6 +525,7 @@ data: {"type":"message_stop"} - `/v1/models/{model}:streamGenerateContent`(兼容路径) 鉴权方式同业务接口(`Authorization: Bearer ` 或 `x-api-key`)。 +实现上统一走 OpenAI Chat Completions 解析与回译链路,避免多套解析逻辑分叉维护。 ### `POST /v1beta/models/{model}:generateContent` diff --git a/internal/adapter/gemini/handler_generate.go b/internal/adapter/gemini/handler_generate.go index 56cc0e6..b03b3ea 100644 --- a/internal/adapter/gemini/handler_generate.go +++ b/internal/adapter/gemini/handler_generate.go @@ -149,15 +149,14 @@ func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *ht cleanVisibleOutput(result.Thinking, stripReferenceMarkers), cleanVisibleOutput(result.Text, stripReferenceMarkers), toolNames, - result.PromptTokens, result.OutputTokens, )) } //nolint:unused // retained for native Gemini non-stream handling path. -func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string, promptTokens, outputTokens int) map[string]any { +func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string, outputTokens int) map[string]any { parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames) - usage := buildGeminiUsage(finalPrompt, finalThinking, finalText, promptTokens, outputTokens) + usage := buildGeminiUsage(finalPrompt, finalThinking, finalText, outputTokens) return map[string]any{ "candidates": []map[string]any{ { @@ -175,10 +174,8 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final } //nolint:unused // retained for native Gemini non-stream handling path. -func buildGeminiUsage(finalPrompt, finalThinking, finalText string, promptTokens, outputTokens int) map[string]any { - if promptTokens <= 0 { - promptTokens = util.EstimateTokens(finalPrompt) - } +func buildGeminiUsage(finalPrompt, finalThinking, finalText string, outputTokens int) map[string]any { + promptTokens := util.EstimateTokens(finalPrompt) reasoningTokens := util.EstimateTokens(finalThinking) completionTokens := util.EstimateTokens(finalText) if outputTokens > 0 { diff --git a/internal/adapter/gemini/handler_stream_runtime.go b/internal/adapter/gemini/handler_stream_runtime.go index b8d2701..e7c9b87 100644 --- a/internal/adapter/gemini/handler_stream_runtime.go +++ b/internal/adapter/gemini/handler_stream_runtime.go @@ -67,7 +67,6 @@ type geminiStreamRuntime struct { thinking strings.Builder text strings.Builder - promptTokens int outputTokens int } @@ -113,9 +112,6 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse if !parsed.Parsed { return streamengine.ParsedDecision{} } - if parsed.PromptTokens > 0 { - s.promptTokens = parsed.PromptTokens - } if parsed.OutputTokens > 0 { s.outputTokens = parsed.OutputTokens } @@ -202,6 +198,6 @@ func (s *geminiStreamRuntime) finalize() { }, }, "modelVersion": s.model, - "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText, s.promptTokens, s.outputTokens), + "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText, s.outputTokens), }) } diff --git a/internal/adapter/gemini/handler_test.go b/internal/adapter/gemini/handler_test.go index aa3ae46..b7aea1b 100644 --- a/internal/adapter/gemini/handler_test.go +++ b/internal/adapter/gemini/handler_test.go @@ -296,32 +296,6 @@ func TestGenerateContentOpenAIProxyErrorUsesGeminiEnvelope(t *testing.T) { } } -func TestBuildGeminiUsageOverridesPromptAndOutputTokensWhenProvided(t *testing.T) { - usage := buildGeminiUsage("prompt", "thinking", "answer", 11, 29) - if got, _ := usage["promptTokenCount"].(int); got != 11 { - t.Fatalf("expected promptTokenCount=11, got %#v", usage["promptTokenCount"]) - } - if got, _ := usage["candidatesTokenCount"].(int); got != 29 { - t.Fatalf("expected candidatesTokenCount=29, got %#v", usage["candidatesTokenCount"]) - } - if got, _ := usage["totalTokenCount"].(int); got != 40 { - t.Fatalf("expected totalTokenCount=40, got %#v", usage["totalTokenCount"]) - } -} - -func TestBuildGeminiUsageFallsBackToEstimateWhenNoUpstreamUsage(t *testing.T) { - usage := buildGeminiUsage("abcdef", "", "ghijkl", 0, 0) - if got, _ := usage["promptTokenCount"].(int); got <= 0 { - t.Fatalf("expected positive promptTokenCount estimate, got %#v", usage["promptTokenCount"]) - } - if got, _ := usage["candidatesTokenCount"].(int); got <= 0 { - t.Fatalf("expected positive candidatesTokenCount estimate, got %#v", usage["candidatesTokenCount"]) - } - if got, _ := usage["totalTokenCount"].(int); got <= 0 { - t.Fatalf("expected positive totalTokenCount estimate, got %#v", usage["totalTokenCount"]) - } -} - func extractGeminiSSEFrames(t *testing.T, body string) []map[string]any { t.Helper() scanner := bufio.NewScanner(strings.NewReader(body)) diff --git a/internal/translatorcliproxy/bridge_test.go b/internal/translatorcliproxy/bridge_test.go index 5f0979f..cdd9cf7 100644 --- a/internal/translatorcliproxy/bridge_test.go +++ b/internal/translatorcliproxy/bridge_test.go @@ -26,6 +26,26 @@ func TestFromOpenAINonStreamClaude(t *testing.T) { } } +func TestFromOpenAINonStreamClaudePreservesUsageFromOpenAI(t *testing.T) { + original := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`) + translatedReq := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`) + openaibody := []byte(`{"id":"chatcmpl_1","object":"chat.completion","created":1,"model":"claude-sonnet-4-5","choices":[{"index":0,"message":{"role":"assistant","content":"hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"completion_tokens":29,"total_tokens":40}}`) + got := string(FromOpenAINonStream(sdktranslator.FormatClaude, "claude-sonnet-4-5", original, translatedReq, openaibody)) + if !strings.Contains(got, `"input_tokens":11`) || !strings.Contains(got, `"output_tokens":29`) { + t.Fatalf("expected claude usage to preserve prompt/completion tokens, got: %s", got) + } +} + +func TestFromOpenAINonStreamGeminiPreservesUsageFromOpenAI(t *testing.T) { + original := []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`) + translatedReq := []byte(`{"model":"gemini-2.5-pro","messages":[{"role":"user","content":"hi"}],"stream":false}`) + openaibody := []byte(`{"id":"chatcmpl_1","object":"chat.completion","created":1,"model":"gemini-2.5-pro","choices":[{"index":0,"message":{"role":"assistant","content":"hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"completion_tokens":29,"total_tokens":40}}`) + got := string(FromOpenAINonStream(sdktranslator.FormatGemini, "gemini-2.5-pro", original, translatedReq, openaibody)) + if !strings.Contains(got, `"promptTokenCount":11`) || !strings.Contains(got, `"candidatesTokenCount":29`) || !strings.Contains(got, `"totalTokenCount":40`) { + t.Fatalf("expected gemini usageMetadata to preserve prompt/completion tokens, got: %s", got) + } +} + func TestParseFormatAliases(t *testing.T) { cases := map[string]sdktranslator.Format{ "responses": sdktranslator.FormatOpenAIResponse, diff --git a/internal/translatorcliproxy/stream_writer.go b/internal/translatorcliproxy/stream_writer.go index 07c4bcb..b1b8747 100644 --- a/internal/translatorcliproxy/stream_writer.go +++ b/internal/translatorcliproxy/stream_writer.go @@ -3,7 +3,9 @@ package translatorcliproxy import ( "bytes" "context" + "encoding/json" "net/http" + "strings" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" ) @@ -77,7 +79,13 @@ func (w *OpenAIStreamTranslatorWriter) Write(p []byte) (int, error) { if !bytes.HasPrefix(trimmed, []byte("data:")) { continue } + usage, hasUsage := extractOpenAIUsage(trimmed) chunks := sdktranslator.TranslateStream(context.Background(), sdktranslator.FormatOpenAI, w.target, w.model, w.originalReq, w.translatedReq, trimmed, &w.param) + if hasUsage { + for i := range chunks { + chunks[i] = injectStreamUsageMetadata(chunks[i], w.target, usage) + } + } for i := range chunks { if len(chunks[i]) == 0 { continue @@ -118,3 +126,92 @@ func (w *OpenAIStreamTranslatorWriter) readOneLine() ([]byte, bool) { w.lineBuf.Next(idx + 1) return line, true } + +type openAIUsage struct { + PromptTokens int + CompletionTokens int + TotalTokens int +} + +func extractOpenAIUsage(line []byte) (openAIUsage, bool) { + raw := strings.TrimSpace(strings.TrimPrefix(string(line), "data:")) + if raw == "" || raw == "[DONE]" { + return openAIUsage{}, false + } + var payload map[string]any + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return openAIUsage{}, false + } + usageObj, _ := payload["usage"].(map[string]any) + if usageObj == nil { + return openAIUsage{}, false + } + p := toInt(usageObj["prompt_tokens"]) + c := toInt(usageObj["completion_tokens"]) + t := toInt(usageObj["total_tokens"]) + if p <= 0 && c <= 0 && t <= 0 { + return openAIUsage{}, false + } + if t <= 0 { + t = p + c + } + return openAIUsage{PromptTokens: p, CompletionTokens: c, TotalTokens: t}, true +} + +func injectStreamUsageMetadata(chunk []byte, target sdktranslator.Format, usage openAIUsage) []byte { + if target != sdktranslator.FormatGemini { + return chunk + } + text := strings.TrimSpace(string(chunk)) + if text == "" { + return chunk + } + var ( + hasDataPrefix bool + jsonText = text + ) + if strings.HasPrefix(jsonText, "data:") { + hasDataPrefix = true + jsonText = strings.TrimSpace(strings.TrimPrefix(jsonText, "data:")) + } + if jsonText == "" || jsonText == "[DONE]" { + return chunk + } + obj := map[string]any{} + if err := json.Unmarshal([]byte(jsonText), &obj); err != nil { + return chunk + } + if _, ok := obj["candidates"]; !ok { + return chunk + } + obj["usageMetadata"] = map[string]any{ + "promptTokenCount": usage.PromptTokens, + "candidatesTokenCount": usage.CompletionTokens, + "totalTokenCount": usage.TotalTokens, + } + b, err := json.Marshal(obj) + if err != nil { + return chunk + } + if hasDataPrefix { + return []byte("data: " + string(b)) + } + return b +} + +func toInt(v any) int { + switch x := v.(type) { + case int: + return x + case int32: + return int(x) + case int64: + return int(x) + case float64: + return int(x) + case float32: + return int(x) + default: + return 0 + } +} diff --git a/internal/translatorcliproxy/stream_writer_test.go b/internal/translatorcliproxy/stream_writer_test.go index 979d36e..77d2936 100644 --- a/internal/translatorcliproxy/stream_writer_test.go +++ b/internal/translatorcliproxy/stream_writer_test.go @@ -18,12 +18,16 @@ func TestOpenAIStreamTranslatorWriterClaude(t *testing.T) { w.WriteHeader(200) _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4-5\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"},\"finish_reason\":null}]}\n\n")) _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4-5\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"},\"finish_reason\":null}]}\n\n")) + _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4-5\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":29,\"total_tokens\":40}}\n\n")) _, _ = w.Write([]byte("data: [DONE]\n\n")) body := rec.Body.String() if !strings.Contains(body, "event: message_start") { t.Fatalf("expected claude message_start event, got: %s", body) } + if !strings.Contains(body, `"output_tokens":29`) { + t.Fatalf("expected claude stream usage to preserve output tokens, got: %s", body) + } } func TestOpenAIStreamTranslatorWriterGemini(t *testing.T) { @@ -35,12 +39,16 @@ func TestOpenAIStreamTranslatorWriterGemini(t *testing.T) { w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(200) _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gemini-2.5-pro\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"},\"finish_reason\":null}]}\n\n")) + _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gemini-2.5-pro\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":29,\"total_tokens\":40}}\n\n")) _, _ = w.Write([]byte("data: [DONE]\n\n")) body := rec.Body.String() if !strings.Contains(body, "candidates") { t.Fatalf("expected gemini stream payload, got: %s", body) } + if !strings.Contains(body, `"promptTokenCount":11`) || !strings.Contains(body, `"candidatesTokenCount":29`) { + t.Fatalf("expected gemini stream usageMetadata to preserve usage, got: %s", body) + } } func TestOpenAIStreamTranslatorWriterPreservesKeepAliveComment(t *testing.T) {