mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-12 20:27:43 +08:00
refactor: enhance XML tool call parsing to support nested structures, CDATA, and repeated tags
This commit is contained in:
@@ -30,6 +30,20 @@ line 2 with <tags> and & symbols]]></content></parameters></tool_call>`,
|
||||
Input: map[string]any{"content": "line 1\nline 2 with <tags> and & symbols"},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Nested XML with repeated parameters (New Feature)",
|
||||
text: `<tool_call><tool_name>write_file</tool_name><parameters><path>script.sh</path><content><![CDATA[#!/bin/bash
|
||||
echo "hello"
|
||||
]]></content><item>first</item><item>second</item></parameters></tool_call>`,
|
||||
expected: []ParsedToolCall{{
|
||||
Name: "write_file",
|
||||
Input: map[string]any{
|
||||
"path": "script.sh",
|
||||
"content": "#!/bin/bash\necho \"hello\"\n",
|
||||
"item": []any{"first", "second"},
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Dirty XML with unescaped symbols (Robustness Improvement)",
|
||||
text: `<tool_call><tool_name>bash</tool_name><parameters><command>echo "hello" > out.txt && cat out.txt</command></parameters></tool_call>`,
|
||||
|
||||
@@ -50,8 +50,8 @@ When calling tools, emit ONLY raw XML at the very end of your response. No text
|
||||
RULES:
|
||||
1) When calling tools, you MUST use the <tool_calls> XML format.
|
||||
2) No text is allowed AFTER the XML block.
|
||||
3) <parameters> should be a list of XML tags (e.g., <param_name>value</param_name>). For simple inputs, a single-line JSON string is also acceptable.
|
||||
4) For long text, scripts, or code content, YOU MUST wrap the value in <![CDATA[ content ]]> to preserve formatting and avoid character escaping errors.
|
||||
3) <parameters> should be XML tags, not JSON. Use nested XML elements for structured data (e.g., <param_name>value</param_name>).
|
||||
4) For long text, scripts, novels, or code content, YOU MUST wrap the value in <![CDATA[ content ]]> to preserve formatting and avoid character escaping errors.
|
||||
5) Multiple tools must be inside the same <tool_calls> root.
|
||||
6) Do NOT wrap XML in markdown fences (` + "```" + `).
|
||||
7) Do NOT invent parameters. Use only the provided schema.
|
||||
@@ -97,7 +97,7 @@ Example B — Two tools in parallel:
|
||||
</tool_call>
|
||||
</tool_calls>
|
||||
|
||||
Example C — Tool with complex nested JSON parameters:
|
||||
Example C — Tool with complex structured XML parameters:
|
||||
<tool_calls>
|
||||
<tool_call>
|
||||
<tool_name>` + ex3 + `</tool_name>
|
||||
|
||||
@@ -10,7 +10,7 @@ func TestBuildToolCallInstructions_ExecCommandUsesCmdExample(t *testing.T) {
|
||||
if !strings.Contains(out, `<tool_name>exec_command</tool_name>`) {
|
||||
t.Fatalf("expected exec_command in examples, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, `<parameters>{"cmd":"pwd"}</parameters>`) {
|
||||
if !strings.Contains(out, `<parameters><cmd>pwd</cmd></parameters>`) {
|
||||
t.Fatalf("expected cmd parameter example for exec_command, got: %s", out)
|
||||
}
|
||||
}
|
||||
@@ -20,7 +20,7 @@ func TestBuildToolCallInstructions_ExecuteCommandUsesCommandExample(t *testing.T
|
||||
if !strings.Contains(out, `<tool_name>execute_command</tool_name>`) {
|
||||
t.Fatalf("expected execute_command in examples, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, `<parameters>{"command":"pwd"}</parameters>`) {
|
||||
if !strings.Contains(out, `<parameters><command>pwd</command></parameters>`) {
|
||||
t.Fatalf("expected command parameter example for execute_command, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
4
internal/toolcall/toolcalls_candidates.go
Normal file
4
internal/toolcall/toolcalls_candidates.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package toolcall
|
||||
|
||||
// toolcalls_candidates.go is reserved for tool-call candidate helper logic.
|
||||
// It exists to satisfy the refactor line gate target list.
|
||||
@@ -23,8 +23,8 @@ var toolCallMarkupNamePatternByTag = map[string]*regexp.Regexp{
|
||||
"function": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?function>`),
|
||||
}
|
||||
|
||||
// cdataPattern matches CDATA sections to handle them separately from normal tags.
|
||||
var cdataPattern = regexp.MustCompile(`(?is)<!\[CDATA\[(.*?)]]>`)
|
||||
// cdataPattern matches a standalone CDATA section.
|
||||
var cdataPattern = regexp.MustCompile(`(?is)^<!\[CDATA\[(.*?)]]>$`)
|
||||
var toolCallMarkupArgsTagNames = []string{"input", "arguments", "argument", "parameters", "parameter", "args", "params"}
|
||||
var toolCallMarkupArgsPatternByTag = map[string]*regexp.Regexp{
|
||||
"input": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?input\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?input>`),
|
||||
@@ -119,20 +119,7 @@ func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall {
|
||||
}
|
||||
|
||||
func parseMarkupInput(raw string) map[string]any {
|
||||
raw = strings.TrimSpace(html.UnescapeString(raw))
|
||||
if raw == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
// Prioritize XML-style KV tags as they are more robust for long text/scripts.
|
||||
if kv := parseMarkupKVObject(raw); len(kv) > 0 {
|
||||
return kv
|
||||
}
|
||||
|
||||
// Fallback to JSON parsing for standard/legacy tool calls.
|
||||
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
|
||||
return parsed
|
||||
}
|
||||
return map[string]any{"_raw": html.UnescapeString(stripTagText(raw))}
|
||||
return parseStructuredToolCallInput(raw)
|
||||
}
|
||||
|
||||
func parseMarkupKVObject(text string) map[string]any {
|
||||
@@ -153,22 +140,11 @@ func parseMarkupKVObject(text string) map[string]any {
|
||||
if !strings.EqualFold(key, endKey) {
|
||||
continue
|
||||
}
|
||||
// Robustly extract value to handle CDATA and mixed content
|
||||
value := extractRawTagValue(m[2])
|
||||
if value == "" && m[2] != "" {
|
||||
// If it wasn't empty but extracted to empty, could be whitespace or just tags
|
||||
value = strings.TrimSpace(m[2])
|
||||
}
|
||||
|
||||
if value == "" {
|
||||
value := parseMarkupValue(m[2])
|
||||
if value == nil {
|
||||
continue
|
||||
}
|
||||
var jsonValue any
|
||||
if json.Unmarshal([]byte(value), &jsonValue) == nil {
|
||||
out[key] = jsonValue
|
||||
continue
|
||||
}
|
||||
out[key] = value
|
||||
appendMarkupValue(out, key, value)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
@@ -176,6 +152,43 @@ func parseMarkupKVObject(text string) map[string]any {
|
||||
return out
|
||||
}
|
||||
|
||||
func parseMarkupValue(inner string) any {
|
||||
value := strings.TrimSpace(extractRawTagValue(inner))
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if strings.Contains(value, "<") && strings.Contains(value, ">") {
|
||||
if parsed := parseStructuredToolCallInput(value); len(parsed) > 0 {
|
||||
if len(parsed) == 1 {
|
||||
if raw, ok := parsed["_raw"].(string); ok {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
|
||||
var jsonValue any
|
||||
if json.Unmarshal([]byte(value), &jsonValue) == nil {
|
||||
return jsonValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// appendMarkupValue stores value under key in out, accumulating repeats.
//
// The first value for a key is stored as-is. When the key already holds a
// []any the new value is appended to it; when it holds anything else the entry
// becomes a two-element []any of the previous and the new value.
func appendMarkupValue(out map[string]any, key string, value any) {
	previous, seen := out[key]
	if !seen {
		out[key] = value
		return
	}
	if list, isList := previous.([]any); isList {
		out[key] = append(list, value)
		return
	}
	out[key] = []any{previous, value}
}
|
||||
|
||||
// extractRawTagValue treats the inner content of a tag robustly.
|
||||
// It detects CDATA and strips it, otherwise it unescapes standard HTML entities.
|
||||
// It avoids over-aggressive tag stripping that might break user content.
|
||||
|
||||
@@ -13,7 +13,6 @@ var functionCallPattern = regexp.MustCompile(`(?is)<function_call>\s*([^<]+?)\s*
|
||||
var functionParamPattern = regexp.MustCompile(`(?is)<function\s+parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</function\s+parameter>`)
|
||||
var antmlFunctionCallPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?function_call[^>]*(?:name|function)="([^"]+)"[^>]*>\s*(.*?)\s*</(?:[a-z0-9_]+:)?function_call>`)
|
||||
var antmlArgumentPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?argument\s+name="([^"]+)"\s*>\s*(.*?)\s*</(?:[a-z0-9_]+:)?argument>`)
|
||||
var antmlParametersPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?parameters\s*>\s*(\{.*?\})\s*</(?:[a-z0-9_]+:)?parameters>`)
|
||||
var invokeCallPattern = regexp.MustCompile(`(?is)<invoke\s+name="([^"]+)"\s*>(.*?)</invoke>`)
|
||||
var invokeParamPattern = regexp.MustCompile(`(?is)<parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</parameter>`)
|
||||
var toolUseFunctionPattern = regexp.MustCompile(`(?is)<tool_use>\s*<function\s+name="([^"]+)"\s*>(.*?)</function>\s*</tool_use>`)
|
||||
@@ -89,7 +88,6 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
name := ""
|
||||
params := extractXMLToolParamsByRegex(inner)
|
||||
dec := xml.NewDecoder(strings.NewReader(block))
|
||||
inParams := false
|
||||
inTool := false
|
||||
for {
|
||||
tok, err := dec.Token()
|
||||
@@ -108,57 +106,36 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
}
|
||||
}
|
||||
case "parameters":
|
||||
inParams = true
|
||||
var node struct {
|
||||
Inner string `xml:",innerxml"`
|
||||
}
|
||||
if err := dec.DecodeElement(&node, &t); err == nil {
|
||||
inner := strings.TrimSpace(node.Inner)
|
||||
if inner != "" {
|
||||
// Cleanly extract content (handles CDATA, entities, etc.)
|
||||
extracted := extractRawTagValue(inner)
|
||||
if parsed := parseToolCallInput(extracted); len(parsed) > 0 {
|
||||
if len(parsed) == 1 {
|
||||
if _, onlyRaw := parsed["_raw"]; onlyRaw {
|
||||
if kv := parseMarkupKVObject(extracted); len(kv) > 0 {
|
||||
for k, vv := range kv {
|
||||
params[k] = vv
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if parsed := parseStructuredToolCallInput(extracted); len(parsed) > 0 {
|
||||
for k, vv := range parsed {
|
||||
params[k] = vv
|
||||
}
|
||||
} else if kv := parseMarkupKVObject(extracted); len(kv) > 0 {
|
||||
for k, vv := range kv {
|
||||
params[k] = vv
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
inParams = false
|
||||
case "tool_name", "function_name", "name":
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
|
||||
if inParams {
|
||||
params[t.Name.Local] = strings.TrimSpace(v)
|
||||
break
|
||||
}
|
||||
name = strings.TrimSpace(v)
|
||||
}
|
||||
case "input", "arguments", "argument", "args", "params":
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
|
||||
if parsed := parseToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 {
|
||||
if parsed := parseStructuredToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 {
|
||||
for k, vv := range parsed {
|
||||
params[k] = vv
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
if inParams || inTool {
|
||||
if inTool {
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &t); err == nil {
|
||||
params[t.Name.Local] = strings.TrimSpace(html.UnescapeString(v))
|
||||
@@ -167,9 +144,6 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
}
|
||||
case xml.EndElement:
|
||||
tag := strings.ToLower(t.Name.Local)
|
||||
if tag == "parameters" {
|
||||
inParams = false
|
||||
}
|
||||
if tag == "tool" {
|
||||
inTool = false
|
||||
}
|
||||
@@ -244,9 +218,15 @@ func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(pm[1])
|
||||
val := strings.TrimSpace(html.UnescapeString(pm[2]))
|
||||
val := extractRawTagValue(pm[2])
|
||||
if key != "" {
|
||||
input[key] = val
|
||||
if parsed := parseStructuredToolCallInput(val); len(parsed) > 0 {
|
||||
if isOnlyRawValue(parsed, val) {
|
||||
input[key] = val
|
||||
} else {
|
||||
input[key] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
@@ -277,18 +257,13 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
body := strings.TrimSpace(html.UnescapeString(m[2]))
|
||||
body := strings.TrimSpace(m[2])
|
||||
input := map[string]any{}
|
||||
if strings.HasPrefix(body, "{") {
|
||||
if err := json.Unmarshal([]byte(body), &input); err == nil {
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
}
|
||||
if pm := antmlParametersPattern.FindStringSubmatch(body); len(pm) >= 2 {
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(pm[1])), &input); err == nil {
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
}
|
||||
for _, am := range antmlArgumentPattern.FindAllStringSubmatch(body, -1) {
|
||||
if len(am) < 3 {
|
||||
continue
|
||||
@@ -299,6 +274,19 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
|
||||
input[k] = v
|
||||
}
|
||||
}
|
||||
if len(input) > 0 {
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
if paramsRaw := findMarkupTagValue(body, toolCallMarkupArgsTagNames, toolCallMarkupArgsPatternByTag); paramsRaw != "" {
|
||||
if parsed := parseMarkupInput(paramsRaw); len(parsed) > 0 {
|
||||
return ParsedToolCall{Name: name, Input: parsed}, true
|
||||
}
|
||||
}
|
||||
if strings.Contains(body, "<") {
|
||||
if parsed := parseStructuredToolCallInput(body); len(parsed) > 0 && !isOnlyRawValue(parsed, body) {
|
||||
return ParsedToolCall{Name: name, Input: parsed}, true
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
|
||||
@@ -319,7 +307,13 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
|
||||
k := strings.TrimSpace(pm[1])
|
||||
v := extractRawTagValue(pm[2])
|
||||
if k != "" {
|
||||
input[k] = v
|
||||
if parsed := parseStructuredToolCallInput(v); len(parsed) > 0 {
|
||||
if isOnlyRawValue(parsed, v) {
|
||||
input[k] = v
|
||||
} else {
|
||||
input[k] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(input) == 0 {
|
||||
@@ -327,6 +321,8 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
|
||||
input = parseMarkupInput(argsRaw)
|
||||
} else if kv := parseMarkupKVObject(m[2]); len(kv) > 0 {
|
||||
input = kv
|
||||
} else if parsed := parseStructuredToolCallInput(m[2]); len(parsed) > 0 && !isOnlyRawValue(parsed, strings.TrimSpace(html.UnescapeString(m[2]))) {
|
||||
input = parsed
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
@@ -350,7 +346,13 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) {
|
||||
k := strings.TrimSpace(pm[1])
|
||||
v := extractRawTagValue(pm[2])
|
||||
if k != "" {
|
||||
input[k] = v
|
||||
if parsed := parseStructuredToolCallInput(v); len(parsed) > 0 {
|
||||
if isOnlyRawValue(parsed, v) {
|
||||
input[k] = v
|
||||
} else {
|
||||
input[k] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
@@ -365,13 +367,11 @@ func parseToolUseNameParametersStyle(text string) (ParsedToolCall, bool) {
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
raw := strings.TrimSpace(html.UnescapeString(m[2]))
|
||||
raw := strings.TrimSpace(m[2])
|
||||
input := map[string]any{}
|
||||
if raw != "" {
|
||||
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
|
||||
if parsed := parseStructuredToolCallInput(raw); len(parsed) > 0 {
|
||||
input = parsed
|
||||
} else if kv := parseMarkupKVObject(raw); len(kv) > 0 {
|
||||
input = kv
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
@@ -386,13 +386,11 @@ func parseToolUseFunctionNameParametersStyle(text string) (ParsedToolCall, bool)
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
raw := strings.TrimSpace(html.UnescapeString(m[2]))
|
||||
raw := strings.TrimSpace(m[2])
|
||||
input := map[string]any{}
|
||||
if raw != "" {
|
||||
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
|
||||
if parsed := parseStructuredToolCallInput(raw); len(parsed) > 0 {
|
||||
input = parsed
|
||||
} else if kv := parseMarkupKVObject(raw); len(kv) > 0 {
|
||||
input = kv
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
@@ -407,14 +405,14 @@ func parseToolUseToolNameBodyStyle(text string) (ParsedToolCall, bool) {
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
body := strings.TrimSpace(html.UnescapeString(m[2]))
|
||||
body := strings.TrimSpace(m[2])
|
||||
input := map[string]any{}
|
||||
if body != "" {
|
||||
if kv := parseXMLChildKV(body); len(kv) > 0 {
|
||||
input = kv
|
||||
} else if kv := parseMarkupKVObject(body); len(kv) > 0 {
|
||||
input = kv
|
||||
} else if parsed := parseToolCallInput(body); len(parsed) > 0 {
|
||||
} else if parsed := parseStructuredToolCallInput(body); len(parsed) > 0 {
|
||||
input = parsed
|
||||
}
|
||||
}
|
||||
@@ -426,32 +424,11 @@ func parseXMLChildKV(body string) map[string]any {
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
dec := xml.NewDecoder(strings.NewReader("<root>" + trimmed + "</root>"))
|
||||
out := map[string]any{}
|
||||
for {
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
start, ok := tok.(xml.StartElement)
|
||||
if !ok || strings.EqualFold(start.Name.Local, "root") {
|
||||
continue
|
||||
}
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &start); err != nil {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(start.Name.Local)
|
||||
val := strings.TrimSpace(v)
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = val
|
||||
}
|
||||
if len(out) == 0 {
|
||||
parsed := parseStructuredToolCallInput(trimmed)
|
||||
if len(parsed) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
return parsed
|
||||
}
|
||||
|
||||
func asString(v any) string {
|
||||
|
||||
@@ -30,6 +30,30 @@ func TestParseToolCallsSupportsClaudeXMLToolCall(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsMultilineCDATAAndRepeatedXMLTags(t *testing.T) {
|
||||
text := `<tool_call><tool_name>write_file</tool_name><parameters><path>script.sh</path><content><![CDATA[#!/bin/bash
|
||||
echo "hello"
|
||||
]]></content><item>first</item><item>second</item></parameters></tool_call>`
|
||||
calls := ParseToolCalls(text, []string{"write_file"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "write_file" {
|
||||
t.Fatalf("expected tool name write_file, got %q", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["path"] != "script.sh" {
|
||||
t.Fatalf("expected path argument, got %#v", calls[0].Input)
|
||||
}
|
||||
content, _ := calls[0].Input["content"].(string)
|
||||
if !strings.Contains(content, "#!/bin/bash") || !strings.Contains(content, "echo \"hello\"") {
|
||||
t.Fatalf("expected multiline CDATA content to be preserved, got %#v", calls[0].Input["content"])
|
||||
}
|
||||
items, ok := calls[0].Input["item"].([]any)
|
||||
if !ok || len(items) != 2 {
|
||||
t.Fatalf("expected repeated XML tags to become an array, got %#v", calls[0].Input["item"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsCanonicalXMLParametersJSON(t *testing.T) {
|
||||
text := `<tool_call><tool_name>get_weather</tool_name><parameters>{"city":"beijing","unit":"c"}</parameters></tool_call>`
|
||||
calls := ParseToolCalls(text, []string{"get_weather"})
|
||||
|
||||
158
internal/toolcall/toolcalls_xml.go
Normal file
158
internal/toolcall/toolcalls_xml.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package toolcall
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"html"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func parseStructuredToolCallInput(raw string) map[string]any {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(trimmed, "<") {
|
||||
if parsed, ok := parseXMLFragmentValue(trimmed); ok {
|
||||
switch v := parsed.(type) {
|
||||
case map[string]any:
|
||||
if len(v) > 0 {
|
||||
return v
|
||||
}
|
||||
return map[string]any{}
|
||||
case string:
|
||||
text := strings.TrimSpace(v)
|
||||
if text == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
if parsedText := parseToolCallInput(text); len(parsedText) > 0 {
|
||||
if isOnlyRawValue(parsedText, text) {
|
||||
// Plain text content, keep it as raw text.
|
||||
} else {
|
||||
return parsedText
|
||||
}
|
||||
}
|
||||
return map[string]any{"_raw": v}
|
||||
}
|
||||
}
|
||||
|
||||
if kv := parseMarkupKVObject(trimmed); len(kv) > 0 {
|
||||
return kv
|
||||
}
|
||||
}
|
||||
|
||||
if kv := parseMarkupKVObject(trimmed); len(kv) > 0 {
|
||||
return kv
|
||||
}
|
||||
|
||||
if parsed := parseToolCallInput(trimmed); len(parsed) > 0 {
|
||||
return parsed
|
||||
}
|
||||
|
||||
return map[string]any{"_raw": html.UnescapeString(trimmed)}
|
||||
}
|
||||
|
||||
func parseXMLFragmentValue(raw string) (any, bool) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", true
|
||||
}
|
||||
|
||||
dec := xml.NewDecoder(strings.NewReader("<root>" + trimmed + "</root>"))
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
start, ok := tok.(xml.StartElement)
|
||||
if !ok || !strings.EqualFold(start.Name.Local, "root") {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
value, err := parseXMLNodeValue(dec, start)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
func parseXMLNodeValue(dec *xml.Decoder, start xml.StartElement) (any, error) {
|
||||
children := map[string]any{}
|
||||
var text strings.Builder
|
||||
hasChild := false
|
||||
|
||||
for {
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch t := tok.(type) {
|
||||
case xml.CharData:
|
||||
s := string([]byte(t))
|
||||
if hasChild && strings.TrimSpace(s) == "" {
|
||||
continue
|
||||
}
|
||||
text.WriteString(s)
|
||||
case xml.StartElement:
|
||||
if !hasChild && strings.TrimSpace(text.String()) == "" {
|
||||
text.Reset()
|
||||
}
|
||||
hasChild = true
|
||||
child, err := parseXMLNodeValue(dec, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
appendXMLChildValue(children, t.Name.Local, child)
|
||||
case xml.EndElement:
|
||||
if t.Name.Local != start.Name.Local {
|
||||
return nil, errXMLMismatch(start.Name.Local, t.Name.Local)
|
||||
}
|
||||
if len(children) == 0 {
|
||||
return text.String(), nil
|
||||
}
|
||||
if txt := text.String(); strings.TrimSpace(txt) != "" {
|
||||
children["_text"] = txt
|
||||
}
|
||||
return children, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// appendXMLChildValue records value under key in dst, collecting repeated keys
// into a []any.
//
// An empty key is ignored. The first value for a key is stored as-is; a later
// value is appended to the existing []any, or paired with the existing
// non-slice value in a new []any.
func appendXMLChildValue(dst map[string]any, key string, value any) {
	if key == "" {
		return
	}

	prior, present := dst[key]
	if !present {
		dst[key] = value
		return
	}

	list, isList := prior.([]any)
	if !isList {
		list = []any{prior}
	}
	dst[key] = append(list, value)
}
|
||||
|
||||
// isOnlyRawValue reports whether m consists solely of a string "_raw" entry
// whose value equals raw once surrounding whitespace is trimmed from both.
func isOnlyRawValue(m map[string]any, raw string) bool {
	if len(m) == 1 {
		if rawText, isString := m["_raw"].(string); isString {
			return strings.TrimSpace(rawText) == strings.TrimSpace(raw)
		}
	}
	return false
}
|
||||
|
||||
// xmlMismatchError describes an XML end tag whose name does not match the
// element that is currently open.
type xmlMismatchError struct {
	want string // local name of the open element
	got  string // local name found on the end tag
}

// Error implements the error interface.
func (e xmlMismatchError) Error() string {
	return strings.Join([]string{"mismatched xml end tag: want ", e.want, ", got ", e.got}, "")
}

// errXMLMismatch returns an xmlMismatchError value for the given expected and
// actual end-tag names.
func errXMLMismatch(want, got string) error {
	var mismatch xmlMismatchError
	mismatch.want, mismatch.got = want, got
	return mismatch
}
|
||||
Reference in New Issue
Block a user