Limit HTML unescape to markup tool-call parsing

This commit is contained in:
CJACK.
2026-04-07 12:55:06 +08:00
parent 77a401fb19
commit da7c46b278
4 changed files with 41 additions and 52 deletions

View File

@@ -2,6 +2,7 @@ package toolcall
import (
"encoding/json"
"html"
"regexp"
"strings"
)
@@ -92,7 +93,7 @@ func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall {
}
func parseMarkupInput(raw string) map[string]any {
raw = strings.TrimSpace(raw)
raw = strings.TrimSpace(html.UnescapeString(raw))
if raw == "" {
return map[string]any{}
}
@@ -102,7 +103,7 @@ func parseMarkupInput(raw string) map[string]any {
if kv := parseMarkupKVObject(raw); len(kv) > 0 {
return kv
}
return map[string]any{"_raw": stripTagText(raw)}
return map[string]any{"_raw": html.UnescapeString(stripTagText(raw))}
}
func parseMarkupKVObject(text string) map[string]any {
@@ -123,7 +124,7 @@ func parseMarkupKVObject(text string) map[string]any {
if !strings.EqualFold(key, endKey) {
continue
}
value := strings.TrimSpace(stripTagText(m[2]))
value := strings.TrimSpace(html.UnescapeString(stripTagText(m[2])))
if value == "" {
continue
}

View File

@@ -2,7 +2,6 @@ package toolcall
import (
"encoding/json"
"html"
"strings"
)
@@ -157,39 +156,14 @@ func filterToolCallsDetailed(parsed []ParsedToolCall) ([]ParsedToolCall, []strin
if tc.Name == "" {
continue
}
tc.Name = html.UnescapeString(tc.Name)
if tc.Input == nil {
tc.Input = map[string]any{}
}
for k, v := range tc.Input {
tc.Input[k] = unescapeHTMLValue(v)
}
out = append(out, tc)
}
return out, nil
}
func unescapeHTMLValue(v any) any {
switch x := v.(type) {
case string:
return html.UnescapeString(x)
case []any:
out := make([]any, len(x))
for i := range x {
out[i] = unescapeHTMLValue(x[i])
}
return out
case map[string]any:
out := make(map[string]any, len(x))
for k, vv := range x {
out[k] = unescapeHTMLValue(vv)
}
return out
default:
return v
}
}
//nolint:unused // retained for policy-level tool-name matching compatibility.
func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
return resolveAllowedToolNameWithLooseMatch(name, allowed, allowedCanonical)

View File

@@ -3,6 +3,7 @@ package toolcall
import (
"encoding/json"
"encoding/xml"
"html"
"regexp"
"strings"
)
@@ -114,10 +115,11 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
if err := dec.DecodeElement(&node, &t); err == nil {
inner := strings.TrimSpace(node.Inner)
if inner != "" {
if parsed := parseToolCallInput(inner); len(parsed) > 0 {
unescapedInner := html.UnescapeString(inner)
if parsed := parseToolCallInput(unescapedInner); len(parsed) > 0 {
if len(parsed) == 1 {
if _, onlyRaw := parsed["_raw"]; onlyRaw {
if kv := parseMarkupKVObject(inner); len(kv) > 0 {
if kv := parseMarkupKVObject(unescapedInner); len(kv) > 0 {
for k, vv := range kv {
params[k] = vv
}
@@ -128,7 +130,7 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
for k, vv := range parsed {
params[k] = vv
}
} else if kv := parseMarkupKVObject(inner); len(kv) > 0 {
} else if kv := parseMarkupKVObject(unescapedInner); len(kv) > 0 {
for k, vv := range kv {
params[k] = vv
}
@@ -143,12 +145,12 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
params[t.Name.Local] = strings.TrimSpace(v)
break
}
name = strings.TrimSpace(v)
name = strings.TrimSpace(html.UnescapeString(v))
}
case "input", "arguments", "argument", "args", "params":
var v string
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
if parsed := parseToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 {
if parsed := parseToolCallInput(strings.TrimSpace(html.UnescapeString(v))); len(parsed) > 0 {
for k, vv := range parsed {
params[k] = vv
}
@@ -158,7 +160,7 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
if inParams || inTool {
var v string
if err := dec.DecodeElement(&v, &t); err == nil {
params[t.Name.Local] = strings.TrimSpace(v)
params[t.Name.Local] = strings.TrimSpace(html.UnescapeString(v))
}
}
}
@@ -173,12 +175,12 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
}
}
if strings.TrimSpace(name) == "" {
name = strings.TrimSpace(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner)))
name = strings.TrimSpace(html.UnescapeString(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner))))
}
if strings.TrimSpace(name) == "" {
return ParsedToolCall{}, false
}
return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true
return ParsedToolCall{Name: strings.TrimSpace(html.UnescapeString(name)), Input: params}, true
}
func stripTopLevelXMLParameters(inner string) string {
@@ -231,7 +233,7 @@ func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
if len(m) < 2 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
@@ -241,7 +243,7 @@ func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
continue
}
key := strings.TrimSpace(pm[1])
val := strings.TrimSpace(pm[2])
val := strings.TrimSpace(html.UnescapeString(pm[2]))
if key != "" {
input[key] = val
}
@@ -270,11 +272,11 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
body := strings.TrimSpace(m[2])
body := strings.TrimSpace(html.UnescapeString(m[2]))
input := map[string]any{}
if strings.HasPrefix(body, "{") {
if err := json.Unmarshal([]byte(body), &input); err == nil {
@@ -291,7 +293,7 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
continue
}
k := strings.TrimSpace(am[1])
v := strings.TrimSpace(am[2])
v := strings.TrimSpace(html.UnescapeString(am[2]))
if k != "" {
input[k] = v
}
@@ -304,7 +306,7 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
@@ -314,7 +316,7 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
continue
}
k := strings.TrimSpace(pm[1])
v := strings.TrimSpace(pm[2])
v := strings.TrimSpace(html.UnescapeString(pm[2]))
if k != "" {
input[k] = v
}
@@ -334,7 +336,7 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) {
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
@@ -345,7 +347,7 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) {
continue
}
k := strings.TrimSpace(pm[1])
v := strings.TrimSpace(pm[2])
v := strings.TrimSpace(html.UnescapeString(pm[2]))
if k != "" {
input[k] = v
}
@@ -358,11 +360,11 @@ func parseToolUseNameParametersStyle(text string) (ParsedToolCall, bool) {
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
raw := strings.TrimSpace(m[2])
raw := strings.TrimSpace(html.UnescapeString(m[2]))
input := map[string]any{}
if raw != "" {
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
@@ -379,11 +381,11 @@ func parseToolUseFunctionNameParametersStyle(text string) (ParsedToolCall, bool)
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
raw := strings.TrimSpace(m[2])
raw := strings.TrimSpace(html.UnescapeString(m[2]))
input := map[string]any{}
if raw != "" {
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
@@ -400,11 +402,11 @@ func parseToolUseToolNameBodyStyle(text string) (ParsedToolCall, bool) {
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
body := strings.TrimSpace(m[2])
body := strings.TrimSpace(html.UnescapeString(m[2]))
input := map[string]any{}
if body != "" {
if kv := parseXMLChildKV(body); len(kv) > 0 {

View File

@@ -703,3 +703,15 @@ func TestParseToolCallsUnescapesHTMLEntityArguments(t *testing.T) {
t.Fatalf("expected html entities to be unescaped in command, got %q", cmd)
}
}
func TestParseToolCallsJSONPayloadKeepsLiteralEntities(t *testing.T) {
text := `{"tool_calls":[{"name":"bash","input":{"command":"echo &gt; literally"}}]}`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected one call, got %#v", calls)
}
cmd, _ := calls[0].Input["command"].(string)
if cmd != "echo &gt; literally" {
t.Fatalf("expected json payload to keep literal entities, got %q", cmd)
}
}