mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 00:45:29 +08:00
Merge pull request #239 from CJackHwang/codex/fix-escaping-issues-and-token-counting
Fix HTML-escaped tool-call args and preserve upstream token usage (stream & non-stream)
This commit is contained in:
@@ -2,6 +2,7 @@ package toolcall
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"html"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
@@ -92,7 +93,7 @@ func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall {
|
||||
}
|
||||
|
||||
func parseMarkupInput(raw string) map[string]any {
|
||||
raw = strings.TrimSpace(raw)
|
||||
raw = strings.TrimSpace(html.UnescapeString(raw))
|
||||
if raw == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
@@ -102,7 +103,7 @@ func parseMarkupInput(raw string) map[string]any {
|
||||
if kv := parseMarkupKVObject(raw); len(kv) > 0 {
|
||||
return kv
|
||||
}
|
||||
return map[string]any{"_raw": stripTagText(raw)}
|
||||
return map[string]any{"_raw": html.UnescapeString(stripTagText(raw))}
|
||||
}
|
||||
|
||||
func parseMarkupKVObject(text string) map[string]any {
|
||||
@@ -123,7 +124,7 @@ func parseMarkupKVObject(text string) map[string]any {
|
||||
if !strings.EqualFold(key, endKey) {
|
||||
continue
|
||||
}
|
||||
value := strings.TrimSpace(stripTagText(m[2]))
|
||||
value := strings.TrimSpace(html.UnescapeString(stripTagText(m[2])))
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package toolcall
|
||||
import (
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"html"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
@@ -114,10 +115,11 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
if err := dec.DecodeElement(&node, &t); err == nil {
|
||||
inner := strings.TrimSpace(node.Inner)
|
||||
if inner != "" {
|
||||
if parsed := parseToolCallInput(inner); len(parsed) > 0 {
|
||||
unescapedInner := html.UnescapeString(inner)
|
||||
if parsed := parseToolCallInput(unescapedInner); len(parsed) > 0 {
|
||||
if len(parsed) == 1 {
|
||||
if _, onlyRaw := parsed["_raw"]; onlyRaw {
|
||||
if kv := parseMarkupKVObject(inner); len(kv) > 0 {
|
||||
if kv := parseMarkupKVObject(unescapedInner); len(kv) > 0 {
|
||||
for k, vv := range kv {
|
||||
params[k] = vv
|
||||
}
|
||||
@@ -128,7 +130,7 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
for k, vv := range parsed {
|
||||
params[k] = vv
|
||||
}
|
||||
} else if kv := parseMarkupKVObject(inner); len(kv) > 0 {
|
||||
} else if kv := parseMarkupKVObject(unescapedInner); len(kv) > 0 {
|
||||
for k, vv := range kv {
|
||||
params[k] = vv
|
||||
}
|
||||
@@ -143,12 +145,12 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
params[t.Name.Local] = strings.TrimSpace(v)
|
||||
break
|
||||
}
|
||||
name = strings.TrimSpace(v)
|
||||
name = strings.TrimSpace(html.UnescapeString(v))
|
||||
}
|
||||
case "input", "arguments", "argument", "args", "params":
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
|
||||
if parsed := parseToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 {
|
||||
if parsed := parseToolCallInput(strings.TrimSpace(html.UnescapeString(v))); len(parsed) > 0 {
|
||||
for k, vv := range parsed {
|
||||
params[k] = vv
|
||||
}
|
||||
@@ -158,7 +160,7 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
if inParams || inTool {
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &t); err == nil {
|
||||
params[t.Name.Local] = strings.TrimSpace(v)
|
||||
params[t.Name.Local] = strings.TrimSpace(html.UnescapeString(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -173,12 +175,12 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
name = strings.TrimSpace(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner)))
|
||||
name = strings.TrimSpace(html.UnescapeString(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner))))
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true
|
||||
return ParsedToolCall{Name: strings.TrimSpace(html.UnescapeString(name)), Input: params}, true
|
||||
}
|
||||
|
||||
func stripTopLevelXMLParameters(inner string) string {
|
||||
@@ -231,7 +233,7 @@ func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
|
||||
if len(m) < 2 {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
name := strings.TrimSpace(m[1])
|
||||
name := strings.TrimSpace(html.UnescapeString(m[1]))
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
@@ -241,7 +243,7 @@ func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(pm[1])
|
||||
val := strings.TrimSpace(pm[2])
|
||||
val := strings.TrimSpace(html.UnescapeString(pm[2]))
|
||||
if key != "" {
|
||||
input[key] = val
|
||||
}
|
||||
@@ -270,11 +272,11 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
|
||||
if len(m) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
name := strings.TrimSpace(m[1])
|
||||
name := strings.TrimSpace(html.UnescapeString(m[1]))
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
body := strings.TrimSpace(m[2])
|
||||
body := strings.TrimSpace(html.UnescapeString(m[2]))
|
||||
input := map[string]any{}
|
||||
if strings.HasPrefix(body, "{") {
|
||||
if err := json.Unmarshal([]byte(body), &input); err == nil {
|
||||
@@ -291,7 +293,7 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
|
||||
continue
|
||||
}
|
||||
k := strings.TrimSpace(am[1])
|
||||
v := strings.TrimSpace(am[2])
|
||||
v := strings.TrimSpace(html.UnescapeString(am[2]))
|
||||
if k != "" {
|
||||
input[k] = v
|
||||
}
|
||||
@@ -304,7 +306,7 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
|
||||
if len(m) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
name := strings.TrimSpace(m[1])
|
||||
name := strings.TrimSpace(html.UnescapeString(m[1]))
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
@@ -314,7 +316,7 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
|
||||
continue
|
||||
}
|
||||
k := strings.TrimSpace(pm[1])
|
||||
v := strings.TrimSpace(pm[2])
|
||||
v := strings.TrimSpace(html.UnescapeString(pm[2]))
|
||||
if k != "" {
|
||||
input[k] = v
|
||||
}
|
||||
@@ -334,7 +336,7 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) {
|
||||
if len(m) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
name := strings.TrimSpace(m[1])
|
||||
name := strings.TrimSpace(html.UnescapeString(m[1]))
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
@@ -345,7 +347,7 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) {
|
||||
continue
|
||||
}
|
||||
k := strings.TrimSpace(pm[1])
|
||||
v := strings.TrimSpace(pm[2])
|
||||
v := strings.TrimSpace(html.UnescapeString(pm[2]))
|
||||
if k != "" {
|
||||
input[k] = v
|
||||
}
|
||||
@@ -358,11 +360,11 @@ func parseToolUseNameParametersStyle(text string) (ParsedToolCall, bool) {
|
||||
if len(m) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
name := strings.TrimSpace(m[1])
|
||||
name := strings.TrimSpace(html.UnescapeString(m[1]))
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
raw := strings.TrimSpace(m[2])
|
||||
raw := strings.TrimSpace(html.UnescapeString(m[2]))
|
||||
input := map[string]any{}
|
||||
if raw != "" {
|
||||
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
|
||||
@@ -379,11 +381,11 @@ func parseToolUseFunctionNameParametersStyle(text string) (ParsedToolCall, bool)
|
||||
if len(m) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
name := strings.TrimSpace(m[1])
|
||||
name := strings.TrimSpace(html.UnescapeString(m[1]))
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
raw := strings.TrimSpace(m[2])
|
||||
raw := strings.TrimSpace(html.UnescapeString(m[2]))
|
||||
input := map[string]any{}
|
||||
if raw != "" {
|
||||
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
|
||||
@@ -400,11 +402,11 @@ func parseToolUseToolNameBodyStyle(text string) (ParsedToolCall, bool) {
|
||||
if len(m) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
name := strings.TrimSpace(m[1])
|
||||
name := strings.TrimSpace(html.UnescapeString(m[1]))
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
body := strings.TrimSpace(m[2])
|
||||
body := strings.TrimSpace(html.UnescapeString(m[2]))
|
||||
input := map[string]any{}
|
||||
if body != "" {
|
||||
if kv := parseXMLChildKV(body); len(kv) > 0 {
|
||||
|
||||
@@ -691,3 +691,27 @@ func TestRepairLooseJSONWithNestedObjects(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsUnescapesHTMLEntityArguments(t *testing.T) {
|
||||
text := `<tool_call><tool_name>Bash</tool_name><parameters>{"command":"echo a > out.txt"}</parameters></tool_call>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected one call, got %#v", calls)
|
||||
}
|
||||
cmd, _ := calls[0].Input["command"].(string)
|
||||
if cmd != "echo a > out.txt" {
|
||||
t.Fatalf("expected html entities to be unescaped in command, got %q", cmd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsJSONPayloadKeepsLiteralEntities(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"bash","input":{"command":"echo > literally"}}]}`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected one call, got %#v", calls)
|
||||
}
|
||||
cmd, _ := calls[0].Input["command"].(string)
|
||||
if cmd != "echo > literally" {
|
||||
t.Fatalf("expected json payload to keep literal entities, got %q", cmd)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package translatorcliproxy
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -15,7 +16,12 @@ func ToOpenAI(from sdktranslator.Format, model string, raw []byte, stream bool)
|
||||
|
||||
func FromOpenAINonStream(to sdktranslator.Format, model string, originalReq, translatedReq, raw []byte) []byte {
|
||||
var param any
|
||||
return sdktranslator.TranslateNonStream(context.Background(), sdktranslator.FormatOpenAI, to, model, originalReq, translatedReq, raw, ¶m)
|
||||
converted := sdktranslator.TranslateNonStream(context.Background(), sdktranslator.FormatOpenAI, to, model, originalReq, translatedReq, raw, ¶m)
|
||||
usage, ok := extractOpenAIUsageFromJSON(raw)
|
||||
if !ok {
|
||||
return converted
|
||||
}
|
||||
return injectNonStreamUsageMetadata(converted, to, usage)
|
||||
}
|
||||
|
||||
func FromOpenAIStream(to sdktranslator.Format, model string, originalReq, translatedReq, streamBody []byte) []byte {
|
||||
@@ -65,3 +71,57 @@ func ParseFormat(name string) sdktranslator.Format {
|
||||
func ToOpenAIByName(formatName, model string, raw []byte, stream bool) []byte {
|
||||
return ToOpenAI(ParseFormat(formatName), model, raw, stream)
|
||||
}
|
||||
|
||||
func extractOpenAIUsageFromJSON(raw []byte) (openAIUsage, bool) {
|
||||
payload := map[string]any{}
|
||||
if err := json.Unmarshal(raw, &payload); err != nil {
|
||||
return openAIUsage{}, false
|
||||
}
|
||||
usageObj, _ := payload["usage"].(map[string]any)
|
||||
if usageObj == nil {
|
||||
return openAIUsage{}, false
|
||||
}
|
||||
p := toInt(usageObj["prompt_tokens"])
|
||||
c := toInt(usageObj["completion_tokens"])
|
||||
t := toInt(usageObj["total_tokens"])
|
||||
if p <= 0 {
|
||||
p = toInt(usageObj["input_tokens"])
|
||||
}
|
||||
if c <= 0 {
|
||||
c = toInt(usageObj["output_tokens"])
|
||||
}
|
||||
if t <= 0 {
|
||||
t = p + c
|
||||
}
|
||||
if p <= 0 && c <= 0 && t <= 0 {
|
||||
return openAIUsage{}, false
|
||||
}
|
||||
return openAIUsage{PromptTokens: p, CompletionTokens: c, TotalTokens: t}, true
|
||||
}
|
||||
|
||||
func injectNonStreamUsageMetadata(converted []byte, target sdktranslator.Format, usage openAIUsage) []byte {
|
||||
obj := map[string]any{}
|
||||
if err := json.Unmarshal(converted, &obj); err != nil {
|
||||
return converted
|
||||
}
|
||||
switch target {
|
||||
case sdktranslator.FormatClaude:
|
||||
obj["usage"] = map[string]any{
|
||||
"input_tokens": usage.PromptTokens,
|
||||
"output_tokens": usage.CompletionTokens,
|
||||
}
|
||||
case sdktranslator.FormatGemini:
|
||||
obj["usageMetadata"] = map[string]any{
|
||||
"promptTokenCount": usage.PromptTokens,
|
||||
"candidatesTokenCount": usage.CompletionTokens,
|
||||
"totalTokenCount": usage.TotalTokens,
|
||||
}
|
||||
default:
|
||||
return converted
|
||||
}
|
||||
out, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return converted
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -46,6 +46,22 @@ func TestFromOpenAINonStreamGeminiPreservesUsageFromOpenAI(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromOpenAINonStreamPreservesResponsesUsageShape(t *testing.T) {
|
||||
original := []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`)
|
||||
translatedReq := []byte(`{"model":"gemini-2.5-pro","messages":[{"role":"user","content":"hi"}],"stream":false}`)
|
||||
openaibody := []byte(`{"id":"resp_1","object":"response","model":"gemini-2.5-pro","usage":{"input_tokens":"11","output_tokens":"29","total_tokens":"40"}}`)
|
||||
gotGemini := string(FromOpenAINonStream(sdktranslator.FormatGemini, "gemini-2.5-pro", original, translatedReq, openaibody))
|
||||
if !strings.Contains(gotGemini, `"promptTokenCount":11`) || !strings.Contains(gotGemini, `"candidatesTokenCount":29`) || !strings.Contains(gotGemini, `"totalTokenCount":40`) {
|
||||
t.Fatalf("expected gemini usageMetadata from input/output usage fields, got: %s", gotGemini)
|
||||
}
|
||||
|
||||
origClaude := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`)
|
||||
gotClaude := string(FromOpenAINonStream(sdktranslator.FormatClaude, "claude-sonnet-4-5", origClaude, origClaude, openaibody))
|
||||
if !strings.Contains(gotClaude, `"input_tokens":11`) || !strings.Contains(gotClaude, `"output_tokens":29`) {
|
||||
t.Fatalf("expected claude usage from input/output usage fields, got: %s", gotClaude)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFormatAliases(t *testing.T) {
|
||||
cases := map[string]sdktranslator.Format{
|
||||
"responses": sdktranslator.FormatOpenAIResponse,
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -149,6 +150,12 @@ func extractOpenAIUsage(line []byte) (openAIUsage, bool) {
|
||||
p := toInt(usageObj["prompt_tokens"])
|
||||
c := toInt(usageObj["completion_tokens"])
|
||||
t := toInt(usageObj["total_tokens"])
|
||||
if p <= 0 {
|
||||
p = toInt(usageObj["input_tokens"])
|
||||
}
|
||||
if c <= 0 {
|
||||
c = toInt(usageObj["output_tokens"])
|
||||
}
|
||||
if p <= 0 && c <= 0 && t <= 0 {
|
||||
return openAIUsage{}, false
|
||||
}
|
||||
@@ -221,6 +228,12 @@ func toInt(v any) int {
|
||||
return int(x)
|
||||
case float32:
|
||||
return int(x)
|
||||
case string:
|
||||
n, err := strconv.Atoi(strings.TrimSpace(x))
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return n
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -75,3 +75,14 @@ func TestInjectStreamUsageMetadataPreservesSSEFrameTerminator(t *testing.T) {
|
||||
t.Fatalf("expected usageMetadata injected, got %q", string(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractOpenAIUsageSupportsResponsesUsageFields(t *testing.T) {
|
||||
line := []byte(`data: {"usage":{"input_tokens":"11","output_tokens":"29","total_tokens":"40"}}`)
|
||||
got, ok := extractOpenAIUsage(line)
|
||||
if !ok {
|
||||
t.Fatal("expected usage extracted from input/output usage fields")
|
||||
}
|
||||
if got.PromptTokens != 11 || got.CompletionTokens != 29 || got.TotalTokens != 40 {
|
||||
t.Fatalf("unexpected usage extracted: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user