refactor: enhance XML tool call parsing to support nested structures, CDATA, and repeated tags

This commit is contained in:
CJACK
2026-04-19 19:58:45 +08:00
parent 26d195f2a6
commit 0f2b5fee23
16 changed files with 550 additions and 140 deletions

View File

@@ -18,8 +18,6 @@ const (
endSentenceMarker = "<end▁of▁sentence>"
endToolResultsMarker = "<end▁of▁toolresults>"
endInstructionsMarker = "<end▁of▁instructions>"
openThinkMarker = "<think>"
closeThinkMarker = "</think>"
)
func MessagesPrepare(messages []map[string]any) string {
@@ -55,7 +53,7 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool
lastRole = m.Role
switch m.Role {
case "assistant":
parts = append(parts, formatRoleBlock(assistantMarker, closeThinkMarker+m.Text, endSentenceMarker))
parts = append(parts, formatRoleBlock(assistantMarker, m.Text, endSentenceMarker))
case "tool":
if strings.TrimSpace(m.Text) != "" {
parts = append(parts, formatRoleBlock(toolMarker, m.Text, endToolResultsMarker))
@@ -73,19 +71,15 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool
}
}
if lastRole != "assistant" {
thinkPrefix := closeThinkMarker
if thinkingEnabled {
thinkPrefix = openThinkMarker
}
parts = append(parts, assistantMarker+thinkPrefix)
parts = append(parts, assistantMarker)
}
out := strings.Join(parts, "")
return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`)
}
// formatRoleBlock produces a single concatenated block: marker + text + endMarker.
// No whitespace is inserted between marker and text to match the official
// DeepSeek V3.2 chat template encoding.
// No whitespace is inserted between marker and text so role boundaries stay
// compact and predictable for downstream parsers.
func formatRoleBlock(marker, text, endMarker string) string {
out := marker + text
if strings.TrimSpace(endMarker) != "" {

View File

@@ -41,9 +41,12 @@ func TestMessagesPrepareUsesTurnSuffixes(t *testing.T) {
if !strings.Contains(got, "<User>Question") {
t.Fatalf("expected user question, got %q", got)
}
if !strings.Contains(got, "<Assistant></think>Answer<end▁of▁sentence>") {
if !strings.Contains(got, "<Assistant>Answer<end▁of▁sentence>") {
t.Fatalf("expected assistant sentence suffix, got %q", got)
}
if strings.Contains(got, "<think>") || strings.Contains(got, "</think>") {
t.Fatalf("did not expect think tags in prompt, got %q", got)
}
}
func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
@@ -55,10 +58,17 @@ func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
}
}
func TestMessagesPrepareWithThinkingEndsWithOpenThink(t *testing.T) {
func TestMessagesPrepareWithThinkingIgnoresThinkingFlag(t *testing.T) {
messages := []map[string]any{{"role": "user", "content": "Question"}}
got := MessagesPrepareWithThinking(messages, true)
if !strings.HasSuffix(got, "<Assistant><think>") {
t.Fatalf("expected thinking suffix, got %q", got)
gotThinking := MessagesPrepareWithThinking(messages, true)
gotPlain := MessagesPrepareWithThinking(messages, false)
if gotThinking != gotPlain {
t.Fatalf("expected thinking flag to be ignored, got %q vs %q", gotThinking, gotPlain)
}
if !strings.HasSuffix(gotThinking, "<Assistant>") {
t.Fatalf("expected assistant suffix without think tags, got %q", gotThinking)
}
if strings.Contains(gotThinking, "<think>") || strings.Contains(gotThinking, "</think>") {
t.Fatalf("did not expect think tags in prompt, got %q", gotThinking)
}
}

View File

@@ -2,6 +2,9 @@ package prompt
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
)
@@ -11,6 +14,8 @@ var promptXMLTextEscaper = strings.NewReplacer(
">", "&gt;",
)
var promptXMLNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_.:-]*$`)
// FormatToolCallsForPrompt renders a tool_calls slice into the canonical
// prompt-visible history block used across adapters.
func FormatToolCallsForPrompt(raw any) string {
@@ -87,12 +92,161 @@ func formatToolCallForPrompt(call map[string]any) string {
}
}
parameters := formatToolCallParametersForPrompt(argsRaw)
return " <tool_call>\n" +
" <tool_name>" + escapeXMLText(name) + "</tool_name>\n" +
" <parameters>" + escapeXMLText(StringifyToolCallArguments(argsRaw)) + "</parameters>\n" +
parameters + "\n" +
" </tool_call>"
}
func formatToolCallParametersForPrompt(raw any) string {
value := normalizePromptToolCallValue(raw)
body, ok := renderPromptToolXMLBody(value, " ")
if ok {
if strings.TrimSpace(body) == "" {
return " <parameters></parameters>"
}
return " <parameters>\n" + body + "\n </parameters>"
}
fallback := StringifyToolCallArguments(raw)
if strings.TrimSpace(fallback) == "" {
fallback = "{}"
}
return " <parameters><content>" + renderPromptXMLText(fallback) + "</content></parameters>"
}
func normalizePromptToolCallValue(raw any) any {
switch x := raw.(type) {
case nil:
return nil
case string:
s := strings.TrimSpace(x)
if s == "" {
return ""
}
var parsed any
if err := json.Unmarshal([]byte(s), &parsed); err == nil {
return parsed
}
return x
default:
return x
}
}
func renderPromptToolXMLBody(value any, indent string) (string, bool) {
switch v := value.(type) {
case nil:
return "", true
case map[string]any:
return renderPromptToolXMLMap(v, indent)
case []any:
return renderPromptToolXMLArray(v, indent)
case string:
return indent + "<content>" + renderPromptXMLText(v) + "</content>", true
case bool, float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return indent + "<value>" + escapeXMLText(fmt.Sprint(v)) + "</value>", true
default:
return indent + "<value>" + renderPromptXMLText(fmt.Sprint(v)) + "</value>", true
}
}
func renderPromptToolXMLMap(m map[string]any, indent string) (string, bool) {
if len(m) == 0 {
return "", true
}
keys := make([]string, 0, len(m))
for k := range m {
if !isValidPromptXMLName(k) {
return "", false
}
keys = append(keys, k)
}
sort.Strings(keys)
lines := make([]string, 0, len(keys))
for _, key := range keys {
rendered, ok := renderPromptToolXMLNode(key, m[key], indent)
if !ok {
return "", false
}
lines = append(lines, rendered)
}
return strings.Join(lines, "\n"), true
}
func renderPromptToolXMLArray(items []any, indent string) (string, bool) {
if len(items) == 0 {
return "", true
}
lines := make([]string, 0, len(items))
for _, item := range items {
rendered, ok := renderPromptToolXMLNode("item", item, indent)
if !ok {
return "", false
}
lines = append(lines, rendered)
}
return strings.Join(lines, "\n"), true
}
func renderPromptToolXMLNode(name string, value any, indent string) (string, bool) {
if !isValidPromptXMLName(name) {
return "", false
}
switch v := value.(type) {
case nil:
return indent + "<" + name + "></" + name + ">", true
case map[string]any:
inner, ok := renderPromptToolXMLMap(v, indent+" ")
if !ok {
return "", false
}
if strings.TrimSpace(inner) == "" {
return indent + "<" + name + "></" + name + ">", true
}
return indent + "<" + name + ">\n" + inner + "\n" + indent + "</" + name + ">", true
case []any:
if len(v) == 0 {
return indent + "<" + name + "></" + name + ">", true
}
lines := make([]string, 0, len(v))
for _, item := range v {
rendered, ok := renderPromptToolXMLNode(name, item, indent)
if !ok {
return "", false
}
lines = append(lines, rendered)
}
return strings.Join(lines, "\n"), true
case string:
return indent + "<" + name + ">" + renderPromptXMLText(v) + "</" + name + ">", true
case bool, float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return indent + "<" + name + ">" + escapeXMLText(fmt.Sprint(v)) + "</" + name + ">", true
default:
return indent + "<" + name + ">" + renderPromptXMLText(fmt.Sprint(v)) + "</" + name + ">", true
}
}
func renderPromptXMLText(text string) string {
if text == "" {
return ""
}
if strings.Contains(text, "]]>") {
return "<![CDATA[" + strings.ReplaceAll(text, "]]>", "]]]]><![CDATA[>") + "]]>"
}
if strings.ContainsAny(text, "<>&\n\r") {
return "<![CDATA[" + text + "]]>"
}
return escapeXMLText(text)
}
func isValidPromptXMLName(name string) bool {
return promptXMLNamePattern.MatchString(strings.TrimSpace(name))
}
func normalizeToolArgumentString(raw string) string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {

View File

@@ -22,7 +22,7 @@ func TestFormatToolCallsForPromptXML(t *testing.T) {
if got == "" {
t.Fatal("expected non-empty formatted tool calls")
}
if got != "<tool_calls>\n <tool_call>\n <tool_name>search_web</tool_name>\n <parameters>{\"query\":\"latest\"}</parameters>\n </tool_call>\n</tool_calls>" {
if got != "<tool_calls>\n <tool_call>\n <tool_name>search_web</tool_name>\n <parameters>\n <query>latest</query>\n </parameters>\n </tool_call>\n</tool_calls>" {
t.Fatalf("unexpected formatted tool call XML: %q", got)
}
}
@@ -34,8 +34,24 @@ func TestFormatToolCallsForPromptEscapesXMLEntities(t *testing.T) {
"arguments": `{"q":"a < b && c > d"}`,
},
})
want := "<tool_calls>\n <tool_call>\n <tool_name>search&lt;&amp;&gt;</tool_name>\n <parameters>{\"q\":\"a &lt; b &amp;&amp; c &gt; d\"}</parameters>\n </tool_call>\n</tool_calls>"
want := "<tool_calls>\n <tool_call>\n <tool_name>search&lt;&amp;&gt;</tool_name>\n <parameters>\n <q><![CDATA[a < b && c > d]]></q>\n </parameters>\n </tool_call>\n</tool_calls>"
if got != want {
t.Fatalf("unexpected escaped tool call XML: %q", got)
}
}
func TestFormatToolCallsForPromptUsesCDATAForMultilineContent(t *testing.T) {
got := FormatToolCallsForPrompt([]any{
map[string]any{
"name": "write_file",
"arguments": map[string]any{
"path": "script.sh",
"content": "#!/bin/bash\nprintf \"hello\"\n",
},
},
})
want := "<tool_calls>\n <tool_call>\n <tool_name>write_file</tool_name>\n <parameters>\n <content><![CDATA[#!/bin/bash\nprintf \"hello\"\n]]></content>\n <path>script.sh</path>\n </parameters>\n </tool_call>\n</tool_calls>"
if got != want {
t.Fatalf("unexpected multiline cdata tool call XML: %q", got)
}
}