From da7c46b2786767576bd4d5c389beae814b2cc815 Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Tue, 7 Apr 2026 12:55:06 +0800 Subject: [PATCH] Limit HTML unescape to markup tool-call parsing --- internal/toolcall/toolcalls_markup.go | 7 +-- internal/toolcall/toolcalls_parse.go | 26 ----------- internal/toolcall/toolcalls_parse_markup.go | 48 +++++++++++---------- internal/toolcall/toolcalls_test.go | 12 ++++++ 4 files changed, 41 insertions(+), 52 deletions(-) diff --git a/internal/toolcall/toolcalls_markup.go b/internal/toolcall/toolcalls_markup.go index 9a6ad2c..56546eb 100644 --- a/internal/toolcall/toolcalls_markup.go +++ b/internal/toolcall/toolcalls_markup.go @@ -2,6 +2,7 @@ package toolcall import ( "encoding/json" + "html" "regexp" "strings" ) @@ -92,7 +93,7 @@ func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall { } func parseMarkupInput(raw string) map[string]any { - raw = strings.TrimSpace(raw) + raw = strings.TrimSpace(html.UnescapeString(raw)) if raw == "" { return map[string]any{} } @@ -102,7 +103,7 @@ func parseMarkupInput(raw string) map[string]any { if kv := parseMarkupKVObject(raw); len(kv) > 0 { return kv } - return map[string]any{"_raw": stripTagText(raw)} + return map[string]any{"_raw": html.UnescapeString(stripTagText(raw))} } func parseMarkupKVObject(text string) map[string]any { @@ -123,7 +124,7 @@ func parseMarkupKVObject(text string) map[string]any { if !strings.EqualFold(key, endKey) { continue } - value := strings.TrimSpace(stripTagText(m[2])) + value := strings.TrimSpace(html.UnescapeString(stripTagText(m[2]))) if value == "" { continue } diff --git a/internal/toolcall/toolcalls_parse.go b/internal/toolcall/toolcalls_parse.go index bb1250e..400fd86 100644 --- a/internal/toolcall/toolcalls_parse.go +++ b/internal/toolcall/toolcalls_parse.go @@ -2,7 +2,6 @@ package toolcall import ( "encoding/json" - "html" "strings" ) @@ -157,39 +156,14 @@ func filterToolCallsDetailed(parsed []ParsedToolCall) ([]ParsedToolCall, []strin if tc.Name == "" { continue } - tc.Name = html.UnescapeString(tc.Name) if tc.Input == nil { tc.Input = map[string]any{} } - for k, v := range tc.Input { - tc.Input[k] = unescapeHTMLValue(v) - } out = append(out, tc) } return out, nil } -func unescapeHTMLValue(v any) any { - switch x := v.(type) { - case string: - return html.UnescapeString(x) - case []any: - out := make([]any, len(x)) - for i := range x { - out[i] = unescapeHTMLValue(x[i]) - } - return out - case map[string]any: - out := make(map[string]any, len(x)) - for k, vv := range x { - out[k] = unescapeHTMLValue(vv) - } - return out - default: - return v - } -} - //nolint:unused // retained for policy-level tool-name matching compatibility. func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string { return resolveAllowedToolNameWithLooseMatch(name, allowed, allowedCanonical) diff --git a/internal/toolcall/toolcalls_parse_markup.go b/internal/toolcall/toolcalls_parse_markup.go index fa41036..d269e40 100644 --- a/internal/toolcall/toolcalls_parse_markup.go +++ b/internal/toolcall/toolcalls_parse_markup.go @@ -3,6 +3,7 @@ package toolcall import ( "encoding/json" "encoding/xml" + "html" "regexp" "strings" ) @@ -114,10 +115,11 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { if err := dec.DecodeElement(&node, &t); err == nil { inner := strings.TrimSpace(node.Inner) if inner != "" { - if parsed := parseToolCallInput(inner); len(parsed) > 0 { + unescapedInner := html.UnescapeString(inner) + if parsed := parseToolCallInput(unescapedInner); len(parsed) > 0 { if len(parsed) == 1 { if _, onlyRaw := parsed["_raw"]; onlyRaw { - if kv := parseMarkupKVObject(inner); len(kv) > 0 { + if kv := parseMarkupKVObject(unescapedInner); len(kv) > 0 { for k, vv := range kv { params[k] = vv } @@ -128,7 +130,7 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { for k, vv := range parsed { params[k] = vv } - } else if kv := parseMarkupKVObject(inner); len(kv) > 0 { + } else if kv := parseMarkupKVObject(unescapedInner); len(kv) > 0 { for k, vv := range kv { params[k] = vv } @@ -143,12 +145,12 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { params[t.Name.Local] = strings.TrimSpace(v) break } - name = strings.TrimSpace(v) + name = strings.TrimSpace(html.UnescapeString(v)) } case "input", "arguments", "argument", "args", "params": var v string if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" { - if parsed := parseToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 { + if parsed := parseToolCallInput(strings.TrimSpace(html.UnescapeString(v))); len(parsed) > 0 { for k, vv := range parsed { params[k] = vv } @@ -158,7 +160,7 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { if inParams || inTool { var v string if err := dec.DecodeElement(&v, &t); err == nil { - params[t.Name.Local] = strings.TrimSpace(v) + params[t.Name.Local] = strings.TrimSpace(html.UnescapeString(v)) } } } @@ -173,12 +175,12 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { } } if strings.TrimSpace(name) == "" { - name = strings.TrimSpace(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner))) + name = strings.TrimSpace(html.UnescapeString(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner)))) } if strings.TrimSpace(name) == "" { return ParsedToolCall{}, false } - return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true + return ParsedToolCall{Name: strings.TrimSpace(html.UnescapeString(name)), Input: params}, true } func stripTopLevelXMLParameters(inner string) string { @@ -231,7 +233,7 @@ func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) { if len(m) < 2 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } @@ -241,7 +243,7 @@ func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) { continue } key := strings.TrimSpace(pm[1]) - val := strings.TrimSpace(pm[2]) + val := strings.TrimSpace(html.UnescapeString(pm[2])) if key != "" { input[key] = val } @@ -270,11 +272,11 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) { if len(m) < 3 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } - body := strings.TrimSpace(m[2]) + body := strings.TrimSpace(html.UnescapeString(m[2])) input := map[string]any{} if strings.HasPrefix(body, "{") { if err := json.Unmarshal([]byte(body), &input); err == nil { @@ -291,7 +293,7 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) { continue } k := strings.TrimSpace(am[1]) - v := strings.TrimSpace(am[2]) + v := strings.TrimSpace(html.UnescapeString(am[2])) if k != "" { input[k] = v } @@ -304,7 +306,7 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) { if len(m) < 3 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } @@ -314,7 +316,7 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) { continue } k := strings.TrimSpace(pm[1]) - v := strings.TrimSpace(pm[2]) + v := strings.TrimSpace(html.UnescapeString(pm[2])) if k != "" { input[k] = v } @@ -334,7 +336,7 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) { if len(m) < 3 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } @@ -345,7 +347,7 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) { continue } k := strings.TrimSpace(pm[1]) - v := strings.TrimSpace(pm[2]) + v := strings.TrimSpace(html.UnescapeString(pm[2])) if k != "" { input[k] = v } @@ -358,11 +360,11 @@ func parseToolUseNameParametersStyle(text string) (ParsedToolCall, bool) { if len(m) < 3 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } - raw := strings.TrimSpace(m[2]) + raw := strings.TrimSpace(html.UnescapeString(m[2])) input := map[string]any{} if raw != "" { if parsed := parseToolCallInput(raw); len(parsed) > 0 { @@ -379,11 +381,11 @@ func parseToolUseFunctionNameParametersStyle(text string) (ParsedToolCall, bool) if len(m) < 3 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } - raw := strings.TrimSpace(m[2]) + raw := strings.TrimSpace(html.UnescapeString(m[2])) input := map[string]any{} if raw != "" { if parsed := parseToolCallInput(raw); len(parsed) > 0 { @@ -400,11 +402,11 @@ func parseToolUseToolNameBodyStyle(text string) (ParsedToolCall, bool) { if len(m) < 3 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } - body := strings.TrimSpace(m[2]) + body := strings.TrimSpace(html.UnescapeString(m[2])) input := map[string]any{} if body != "" { if kv := parseXMLChildKV(body); len(kv) > 0 { diff --git a/internal/toolcall/toolcalls_test.go b/internal/toolcall/toolcalls_test.go index e9c23c5..faa7322 100644 --- a/internal/toolcall/toolcalls_test.go +++ b/internal/toolcall/toolcalls_test.go @@ -703,3 +703,15 @@ func TestParseToolCallsUnescapesHTMLEntityArguments(t *testing.T) { t.Fatalf("expected html entities to be unescaped in command, got %q", cmd) } } + +func TestParseToolCallsJSONPayloadKeepsLiteralEntities(t *testing.T) { + text := `{"tool_calls":[{"name":"bash","input":{"command":"echo > literally"}}]}` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected one call, got %#v", calls) + } + cmd, _ := calls[0].Input["command"].(string) + if cmd != "echo > literally" { + t.Fatalf("expected json payload to keep literal entities, got %q", cmd) + } +}