diff --git a/internal/toolcall/toolcalls_parse.go b/internal/toolcall/toolcalls_parse.go index 400fd86..bb1250e 100644 --- a/internal/toolcall/toolcalls_parse.go +++ b/internal/toolcall/toolcalls_parse.go @@ -2,6 +2,7 @@ package toolcall import ( "encoding/json" + "html" "strings" ) @@ -156,14 +157,39 @@ func filterToolCallsDetailed(parsed []ParsedToolCall) ([]ParsedToolCall, []strin if tc.Name == "" { continue } + tc.Name = html.UnescapeString(tc.Name) if tc.Input == nil { tc.Input = map[string]any{} } + for k, v := range tc.Input { + tc.Input[k] = unescapeHTMLValue(v) + } out = append(out, tc) } return out, nil } +func unescapeHTMLValue(v any) any { + switch x := v.(type) { + case string: + return html.UnescapeString(x) + case []any: + out := make([]any, len(x)) + for i := range x { + out[i] = unescapeHTMLValue(x[i]) + } + return out + case map[string]any: + out := make(map[string]any, len(x)) + for k, vv := range x { + out[k] = unescapeHTMLValue(vv) + } + return out + default: + return v + } +} + //nolint:unused // retained for policy-level tool-name matching compatibility. 
func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string { return resolveAllowedToolNameWithLooseMatch(name, allowed, allowedCanonical) diff --git a/internal/toolcall/toolcalls_test.go b/internal/toolcall/toolcalls_test.go index 663b895..e9c23c5 100644 --- a/internal/toolcall/toolcalls_test.go +++ b/internal/toolcall/toolcalls_test.go @@ -691,3 +691,15 @@ func TestRepairLooseJSONWithNestedObjects(t *testing.T) { } } } + +func TestParseToolCallsUnescapesHTMLEntityArguments(t *testing.T) { + text := `Bash{"command":"echo a &gt; out.txt"}` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected one call, got %#v", calls) + } + cmd, _ := calls[0].Input["command"].(string) + if cmd != "echo a > out.txt" { + t.Fatalf("expected html entities to be unescaped in command, got %q", cmd) + } +} diff --git a/internal/translatorcliproxy/bridge.go b/internal/translatorcliproxy/bridge.go index e5dc5ac..c5d6741 100644 --- a/internal/translatorcliproxy/bridge.go +++ b/internal/translatorcliproxy/bridge.go @@ -3,6 +3,7 @@ package translatorcliproxy import ( "bytes" "context" + "encoding/json" "strings" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" @@ -15,7 +16,12 @@ func ToOpenAI(from sdktranslator.Format, model string, raw []byte, stream bool) func FromOpenAINonStream(to sdktranslator.Format, model string, originalReq, translatedReq, raw []byte) []byte { var param any - return sdktranslator.TranslateNonStream(context.Background(), sdktranslator.FormatOpenAI, to, model, originalReq, translatedReq, raw, &param) + converted := sdktranslator.TranslateNonStream(context.Background(), sdktranslator.FormatOpenAI, to, model, originalReq, translatedReq, raw, &param) + usage, ok := extractOpenAIUsageFromJSON(raw) + if !ok { + return converted + } + return injectNonStreamUsageMetadata(converted, to, usage) } func FromOpenAIStream(to sdktranslator.Format, model string, originalReq, 
translatedReq, streamBody []byte) []byte { @@ -65,3 +71,57 @@ func ParseFormat(name string) sdktranslator.Format { func ToOpenAIByName(formatName, model string, raw []byte, stream bool) []byte { return ToOpenAI(ParseFormat(formatName), model, raw, stream) } + +func extractOpenAIUsageFromJSON(raw []byte) (openAIUsage, bool) { + payload := map[string]any{} + if err := json.Unmarshal(raw, &payload); err != nil { + return openAIUsage{}, false + } + usageObj, _ := payload["usage"].(map[string]any) + if usageObj == nil { + return openAIUsage{}, false + } + p := toInt(usageObj["prompt_tokens"]) + c := toInt(usageObj["completion_tokens"]) + t := toInt(usageObj["total_tokens"]) + if p <= 0 { + p = toInt(usageObj["input_tokens"]) + } + if c <= 0 { + c = toInt(usageObj["output_tokens"]) + } + if t <= 0 { + t = p + c + } + if p <= 0 && c <= 0 && t <= 0 { + return openAIUsage{}, false + } + return openAIUsage{PromptTokens: p, CompletionTokens: c, TotalTokens: t}, true +} + +func injectNonStreamUsageMetadata(converted []byte, target sdktranslator.Format, usage openAIUsage) []byte { + obj := map[string]any{} + if err := json.Unmarshal(converted, &obj); err != nil { + return converted + } + switch target { + case sdktranslator.FormatClaude: + obj["usage"] = map[string]any{ + "input_tokens": usage.PromptTokens, + "output_tokens": usage.CompletionTokens, + } + case sdktranslator.FormatGemini: + obj["usageMetadata"] = map[string]any{ + "promptTokenCount": usage.PromptTokens, + "candidatesTokenCount": usage.CompletionTokens, + "totalTokenCount": usage.TotalTokens, + } + default: + return converted + } + out, err := json.Marshal(obj) + if err != nil { + return converted + } + return out +} diff --git a/internal/translatorcliproxy/bridge_test.go b/internal/translatorcliproxy/bridge_test.go index cdd9cf7..9dbfe30 100644 --- a/internal/translatorcliproxy/bridge_test.go +++ b/internal/translatorcliproxy/bridge_test.go @@ -46,6 +46,22 @@ func 
TestFromOpenAINonStreamGeminiPreservesUsageFromOpenAI(t *testing.T) { } } +func TestFromOpenAINonStreamPreservesResponsesUsageShape(t *testing.T) { + original := []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`) + translatedReq := []byte(`{"model":"gemini-2.5-pro","messages":[{"role":"user","content":"hi"}],"stream":false}`) + openaibody := []byte(`{"id":"resp_1","object":"response","model":"gemini-2.5-pro","usage":{"input_tokens":"11","output_tokens":"29","total_tokens":"40"}}`) + gotGemini := string(FromOpenAINonStream(sdktranslator.FormatGemini, "gemini-2.5-pro", original, translatedReq, openaibody)) + if !strings.Contains(gotGemini, `"promptTokenCount":11`) || !strings.Contains(gotGemini, `"candidatesTokenCount":29`) || !strings.Contains(gotGemini, `"totalTokenCount":40`) { + t.Fatalf("expected gemini usageMetadata from input/output usage fields, got: %s", gotGemini) + } + + origClaude := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`) + gotClaude := string(FromOpenAINonStream(sdktranslator.FormatClaude, "claude-sonnet-4-5", origClaude, origClaude, openaibody)) + if !strings.Contains(gotClaude, `"input_tokens":11`) || !strings.Contains(gotClaude, `"output_tokens":29`) { + t.Fatalf("expected claude usage from input/output usage fields, got: %s", gotClaude) + } +} + func TestParseFormatAliases(t *testing.T) { cases := map[string]sdktranslator.Format{ "responses": sdktranslator.FormatOpenAIResponse, diff --git a/internal/translatorcliproxy/stream_writer.go b/internal/translatorcliproxy/stream_writer.go index e80ce69..ac7fc41 100644 --- a/internal/translatorcliproxy/stream_writer.go +++ b/internal/translatorcliproxy/stream_writer.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "net/http" + "strconv" "strings" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" @@ -149,6 +150,12 @@ func extractOpenAIUsage(line []byte) (openAIUsage, bool) { p := toInt(usageObj["prompt_tokens"]) 
c := toInt(usageObj["completion_tokens"]) t := toInt(usageObj["total_tokens"]) + if p <= 0 { + p = toInt(usageObj["input_tokens"]) + } + if c <= 0 { + c = toInt(usageObj["output_tokens"]) + } if p <= 0 && c <= 0 && t <= 0 { return openAIUsage{}, false } @@ -221,6 +228,12 @@ func toInt(v any) int { return int(x) case float32: return int(x) + case string: + n, err := strconv.Atoi(strings.TrimSpace(x)) + if err != nil { + return 0 + } + return n default: return 0 } diff --git a/internal/translatorcliproxy/stream_writer_test.go b/internal/translatorcliproxy/stream_writer_test.go index 94d70b8..f4758d4 100644 --- a/internal/translatorcliproxy/stream_writer_test.go +++ b/internal/translatorcliproxy/stream_writer_test.go @@ -75,3 +75,14 @@ func TestInjectStreamUsageMetadataPreservesSSEFrameTerminator(t *testing.T) { t.Fatalf("expected usageMetadata injected, got %q", string(got)) } } + +func TestExtractOpenAIUsageSupportsResponsesUsageFields(t *testing.T) { + line := []byte(`data: {"usage":{"input_tokens":"11","output_tokens":"29","total_tokens":"40"}}`) + got, ok := extractOpenAIUsage(line) + if !ok { + t.Fatal("expected usage extracted from input/output usage fields") + } + if got.PromptTokens != 11 || got.CompletionTokens != 29 || got.TotalTokens != 40 { + t.Fatalf("unexpected usage extracted: %#v", got) + } +}