mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-21 00:17:44 +08:00
refactor: enhance XML tool call parsing to support nested structures, CDATA, and repeated tags
This commit is contained in:
@@ -18,8 +18,6 @@ const (
|
||||
endSentenceMarker = "<|end▁of▁sentence|>"
|
||||
endToolResultsMarker = "<|end▁of▁toolresults|>"
|
||||
endInstructionsMarker = "<|end▁of▁instructions|>"
|
||||
openThinkMarker = "<think>"
|
||||
closeThinkMarker = "</think>"
|
||||
)
|
||||
|
||||
func MessagesPrepare(messages []map[string]any) string {
|
||||
@@ -55,7 +53,7 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool
|
||||
lastRole = m.Role
|
||||
switch m.Role {
|
||||
case "assistant":
|
||||
parts = append(parts, formatRoleBlock(assistantMarker, closeThinkMarker+m.Text, endSentenceMarker))
|
||||
parts = append(parts, formatRoleBlock(assistantMarker, m.Text, endSentenceMarker))
|
||||
case "tool":
|
||||
if strings.TrimSpace(m.Text) != "" {
|
||||
parts = append(parts, formatRoleBlock(toolMarker, m.Text, endToolResultsMarker))
|
||||
@@ -73,19 +71,15 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool
|
||||
}
|
||||
}
|
||||
if lastRole != "assistant" {
|
||||
thinkPrefix := closeThinkMarker
|
||||
if thinkingEnabled {
|
||||
thinkPrefix = openThinkMarker
|
||||
}
|
||||
parts = append(parts, assistantMarker+thinkPrefix)
|
||||
parts = append(parts, assistantMarker)
|
||||
}
|
||||
out := strings.Join(parts, "")
|
||||
return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`)
|
||||
}
|
||||
|
||||
// formatRoleBlock produces a single concatenated block: marker + text + endMarker.
|
||||
// No whitespace is inserted between marker and text to match the official
|
||||
// DeepSeek V3.2 chat template encoding.
|
||||
// No whitespace is inserted between marker and text so role boundaries stay
|
||||
// compact and predictable for downstream parsers.
|
||||
func formatRoleBlock(marker, text, endMarker string) string {
|
||||
out := marker + text
|
||||
if strings.TrimSpace(endMarker) != "" {
|
||||
|
||||
@@ -41,9 +41,12 @@ func TestMessagesPrepareUsesTurnSuffixes(t *testing.T) {
|
||||
if !strings.Contains(got, "<|User|>Question") {
|
||||
t.Fatalf("expected user question, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "<|Assistant|></think>Answer<|end▁of▁sentence|>") {
|
||||
if !strings.Contains(got, "<|Assistant|>Answer<|end▁of▁sentence|>") {
|
||||
t.Fatalf("expected assistant sentence suffix, got %q", got)
|
||||
}
|
||||
if strings.Contains(got, "<think>") || strings.Contains(got, "</think>") {
|
||||
t.Fatalf("did not expect think tags in prompt, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
|
||||
@@ -55,10 +58,17 @@ func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPrepareWithThinkingEndsWithOpenThink(t *testing.T) {
|
||||
func TestMessagesPrepareWithThinkingIgnoresThinkingFlag(t *testing.T) {
|
||||
messages := []map[string]any{{"role": "user", "content": "Question"}}
|
||||
got := MessagesPrepareWithThinking(messages, true)
|
||||
if !strings.HasSuffix(got, "<|Assistant|><think>") {
|
||||
t.Fatalf("expected thinking suffix, got %q", got)
|
||||
gotThinking := MessagesPrepareWithThinking(messages, true)
|
||||
gotPlain := MessagesPrepareWithThinking(messages, false)
|
||||
if gotThinking != gotPlain {
|
||||
t.Fatalf("expected thinking flag to be ignored, got %q vs %q", gotThinking, gotPlain)
|
||||
}
|
||||
if !strings.HasSuffix(gotThinking, "<|Assistant|>") {
|
||||
t.Fatalf("expected assistant suffix without think tags, got %q", gotThinking)
|
||||
}
|
||||
if strings.Contains(gotThinking, "<think>") || strings.Contains(gotThinking, "</think>") {
|
||||
t.Fatalf("did not expect think tags in prompt, got %q", gotThinking)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,9 @@ package prompt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -11,6 +14,8 @@ var promptXMLTextEscaper = strings.NewReplacer(
|
||||
">", ">",
|
||||
)
|
||||
|
||||
var promptXMLNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_.:-]*$`)
|
||||
|
||||
// FormatToolCallsForPrompt renders a tool_calls slice into the canonical
|
||||
// prompt-visible history block used across adapters.
|
||||
func FormatToolCallsForPrompt(raw any) string {
|
||||
@@ -87,12 +92,161 @@ func formatToolCallForPrompt(call map[string]any) string {
|
||||
}
|
||||
}
|
||||
|
||||
parameters := formatToolCallParametersForPrompt(argsRaw)
|
||||
|
||||
return " <tool_call>\n" +
|
||||
" <tool_name>" + escapeXMLText(name) + "</tool_name>\n" +
|
||||
" <parameters>" + escapeXMLText(StringifyToolCallArguments(argsRaw)) + "</parameters>\n" +
|
||||
parameters + "\n" +
|
||||
" </tool_call>"
|
||||
}
|
||||
|
||||
func formatToolCallParametersForPrompt(raw any) string {
|
||||
value := normalizePromptToolCallValue(raw)
|
||||
body, ok := renderPromptToolXMLBody(value, " ")
|
||||
if ok {
|
||||
if strings.TrimSpace(body) == "" {
|
||||
return " <parameters></parameters>"
|
||||
}
|
||||
return " <parameters>\n" + body + "\n </parameters>"
|
||||
}
|
||||
|
||||
fallback := StringifyToolCallArguments(raw)
|
||||
if strings.TrimSpace(fallback) == "" {
|
||||
fallback = "{}"
|
||||
}
|
||||
return " <parameters><content>" + renderPromptXMLText(fallback) + "</content></parameters>"
|
||||
}
|
||||
|
||||
// normalizePromptToolCallValue prepares raw tool-call arguments for rendering.
//
//   - nil is returned as nil.
//   - A string that is blank after trimming becomes "".
//   - A string whose trimmed form is one complete JSON document is decoded and
//     the decoded value is returned.
//   - Any other string is returned unchanged (not trimmed).
//   - Values of every other type pass through untouched.
//
// Numbers inside a decoded JSON document are kept as json.Number instead of
// float64. Decoding to float64 loses integer precision above 2^53 and makes
// later fmt-based formatting print values such as 1700000000 as "1.7e+09";
// json.Number preserves the literal exactly as it was written.
func normalizePromptToolCallValue(raw any) any {
	text, isString := raw.(string)
	if !isString {
		// Covers nil as well: a nil interface is returned as nil.
		return raw
	}
	trimmed := strings.TrimSpace(text)
	if trimmed == "" {
		return ""
	}
	// Decoder.Decode stops after the first value, so validate the whole input
	// first to keep rejecting strings with trailing garbage.
	if !json.Valid([]byte(trimmed)) {
		return text
	}
	dec := json.NewDecoder(strings.NewReader(trimmed))
	dec.UseNumber()
	var parsed any
	if err := dec.Decode(&parsed); err != nil {
		return text
	}
	return parsed
}
|
||||
|
||||
func renderPromptToolXMLBody(value any, indent string) (string, bool) {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return "", true
|
||||
case map[string]any:
|
||||
return renderPromptToolXMLMap(v, indent)
|
||||
case []any:
|
||||
return renderPromptToolXMLArray(v, indent)
|
||||
case string:
|
||||
return indent + "<content>" + renderPromptXMLText(v) + "</content>", true
|
||||
case bool, float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return indent + "<value>" + escapeXMLText(fmt.Sprint(v)) + "</value>", true
|
||||
default:
|
||||
return indent + "<value>" + renderPromptXMLText(fmt.Sprint(v)) + "</value>", true
|
||||
}
|
||||
}
|
||||
|
||||
func renderPromptToolXMLMap(m map[string]any, indent string) (string, bool) {
|
||||
if len(m) == 0 {
|
||||
return "", true
|
||||
}
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
if !isValidPromptXMLName(k) {
|
||||
return "", false
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
lines := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
rendered, ok := renderPromptToolXMLNode(key, m[key], indent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
lines = append(lines, rendered)
|
||||
}
|
||||
return strings.Join(lines, "\n"), true
|
||||
}
|
||||
|
||||
func renderPromptToolXMLArray(items []any, indent string) (string, bool) {
|
||||
if len(items) == 0 {
|
||||
return "", true
|
||||
}
|
||||
lines := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
rendered, ok := renderPromptToolXMLNode("item", item, indent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
lines = append(lines, rendered)
|
||||
}
|
||||
return strings.Join(lines, "\n"), true
|
||||
}
|
||||
|
||||
func renderPromptToolXMLNode(name string, value any, indent string) (string, bool) {
|
||||
if !isValidPromptXMLName(name) {
|
||||
return "", false
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return indent + "<" + name + "></" + name + ">", true
|
||||
case map[string]any:
|
||||
inner, ok := renderPromptToolXMLMap(v, indent+" ")
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if strings.TrimSpace(inner) == "" {
|
||||
return indent + "<" + name + "></" + name + ">", true
|
||||
}
|
||||
return indent + "<" + name + ">\n" + inner + "\n" + indent + "</" + name + ">", true
|
||||
case []any:
|
||||
if len(v) == 0 {
|
||||
return indent + "<" + name + "></" + name + ">", true
|
||||
}
|
||||
lines := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
rendered, ok := renderPromptToolXMLNode(name, item, indent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
lines = append(lines, rendered)
|
||||
}
|
||||
return strings.Join(lines, "\n"), true
|
||||
case string:
|
||||
return indent + "<" + name + ">" + renderPromptXMLText(v) + "</" + name + ">", true
|
||||
case bool, float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return indent + "<" + name + ">" + escapeXMLText(fmt.Sprint(v)) + "</" + name + ">", true
|
||||
default:
|
||||
return indent + "<" + name + ">" + renderPromptXMLText(fmt.Sprint(v)) + "</" + name + ">", true
|
||||
}
|
||||
}
|
||||
|
||||
func renderPromptXMLText(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(text, "]]>") {
|
||||
return "<![CDATA[" + strings.ReplaceAll(text, "]]>", "]]]]><![CDATA[>") + "]]>"
|
||||
}
|
||||
if strings.ContainsAny(text, "<>&\n\r") {
|
||||
return "<![CDATA[" + text + "]]>"
|
||||
}
|
||||
return escapeXMLText(text)
|
||||
}
|
||||
|
||||
func isValidPromptXMLName(name string) bool {
|
||||
return promptXMLNamePattern.MatchString(strings.TrimSpace(name))
|
||||
}
|
||||
|
||||
func normalizeToolArgumentString(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestFormatToolCallsForPromptXML(t *testing.T) {
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty formatted tool calls")
|
||||
}
|
||||
if got != "<tool_calls>\n <tool_call>\n <tool_name>search_web</tool_name>\n <parameters>{\"query\":\"latest\"}</parameters>\n </tool_call>\n</tool_calls>" {
|
||||
if got != "<tool_calls>\n <tool_call>\n <tool_name>search_web</tool_name>\n <parameters>\n <query>latest</query>\n </parameters>\n </tool_call>\n</tool_calls>" {
|
||||
t.Fatalf("unexpected formatted tool call XML: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -34,8 +34,24 @@ func TestFormatToolCallsForPromptEscapesXMLEntities(t *testing.T) {
|
||||
"arguments": `{"q":"a < b && c > d"}`,
|
||||
},
|
||||
})
|
||||
want := "<tool_calls>\n <tool_call>\n <tool_name>search<&></tool_name>\n <parameters>{\"q\":\"a < b && c > d\"}</parameters>\n </tool_call>\n</tool_calls>"
|
||||
want := "<tool_calls>\n <tool_call>\n <tool_name>search<&></tool_name>\n <parameters>\n <q><![CDATA[a < b && c > d]]></q>\n </parameters>\n </tool_call>\n</tool_calls>"
|
||||
if got != want {
|
||||
t.Fatalf("unexpected escaped tool call XML: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatToolCallsForPromptUsesCDATAForMultilineContent(t *testing.T) {
|
||||
got := FormatToolCallsForPrompt([]any{
|
||||
map[string]any{
|
||||
"name": "write_file",
|
||||
"arguments": map[string]any{
|
||||
"path": "script.sh",
|
||||
"content": "#!/bin/bash\nprintf \"hello\"\n",
|
||||
},
|
||||
},
|
||||
})
|
||||
want := "<tool_calls>\n <tool_call>\n <tool_name>write_file</tool_name>\n <parameters>\n <content><![CDATA[#!/bin/bash\nprintf \"hello\"\n]]></content>\n <path>script.sh</path>\n </parameters>\n </tool_call>\n</tool_calls>"
|
||||
if got != want {
|
||||
t.Fatalf("unexpected multiline cdata tool call XML: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user