mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-02 07:25:26 +08:00
refactor: enhance XML tool call parsing to support nested structures, CDATA, and repeated tags
This commit is contained in:
@@ -96,7 +96,7 @@ func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) {
|
||||
if !containsStr(content, "<tool_calls>") || !containsStr(content, "<tool_name>search_web</tool_name>") {
|
||||
t.Fatalf("expected assistant content to include XML tool call history, got %q", content)
|
||||
}
|
||||
if !containsStr(content, `<parameters>{"query":"latest"}</parameters>`) {
|
||||
if !containsStr(content, "<parameters>\n <query>latest</query>\n </parameters>") {
|
||||
t.Fatalf("expected assistant content to include serialized parameters, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ const TOOL_CALL_MARKUP_ARGS_PATTERNS = [
|
||||
/<(?:[a-z0-9_:-]+:)?args\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?args>/i,
|
||||
/<(?:[a-z0-9_:-]+:)?params\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?params>/i,
|
||||
];
|
||||
const CDATA_PATTERN = /<!\[CDATA\[([\s\S]*?)]]>/i;
|
||||
const CDATA_PATTERN = /^<!\[CDATA\[([\s\S]*?)]]>$/i;
|
||||
const HTML_ENTITIES_PATTERN = /&[a-z0-9#]+;/gi;
|
||||
|
||||
const {
|
||||
@@ -97,6 +97,9 @@ function parseMarkupSingleToolCall(attrs, inner) {
|
||||
|
||||
function parseMarkupInput(raw) {
|
||||
const s = toStringSafe(raw).trim();
|
||||
if (!s) {
|
||||
return {};
|
||||
}
|
||||
// Prioritize XML-style KV tags (e.g., <arg>val</arg>)
|
||||
const kv = parseMarkupKVObject(s);
|
||||
if (Object.keys(kv).length > 0) {
|
||||
@@ -125,19 +128,38 @@ function parseMarkupKVObject(text) {
|
||||
if (!key) {
|
||||
continue;
|
||||
}
|
||||
const valueRaw = extractRawTagValue(m[2]);
|
||||
if (!valueRaw) {
|
||||
const value = parseMarkupValue(m[2]);
|
||||
if (value === undefined || value === null) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
out[key] = JSON.parse(valueRaw);
|
||||
} catch (_err) {
|
||||
out[key] = valueRaw;
|
||||
}
|
||||
appendMarkupValue(out, key, value);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseMarkupValue(raw) {
|
||||
const s = toStringSafe(extractRawTagValue(raw)).trim();
|
||||
if (!s) {
|
||||
return '';
|
||||
}
|
||||
|
||||
if (s.includes('<') && s.includes('>')) {
|
||||
const nested = parseMarkupInput(s);
|
||||
if (nested && typeof nested === 'object' && !Array.isArray(nested)) {
|
||||
if (isOnlyRawValue(nested)) {
|
||||
return toStringSafe(nested._raw);
|
||||
}
|
||||
return nested;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
return JSON.parse(s);
|
||||
} catch (_err) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
function extractRawTagValue(inner) {
|
||||
const s = toStringSafe(inner).trim();
|
||||
if (!s) {
|
||||
@@ -213,6 +235,27 @@ function parseToolCallInput(v) {
|
||||
return {};
|
||||
}
|
||||
|
||||
function appendMarkupValue(out, key, value) {
|
||||
if (Object.prototype.hasOwnProperty.call(out, key)) {
|
||||
const current = out[key];
|
||||
if (Array.isArray(current)) {
|
||||
current.push(value);
|
||||
return;
|
||||
}
|
||||
out[key] = [current, value];
|
||||
return;
|
||||
}
|
||||
out[key] = value;
|
||||
}
|
||||
|
||||
function isOnlyRawValue(obj) {
|
||||
if (!obj || typeof obj !== 'object' || Array.isArray(obj)) {
|
||||
return false;
|
||||
}
|
||||
const keys = Object.keys(obj);
|
||||
return keys.length === 1 && keys[0] === '_raw';
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
stripFencedCodeBlocks,
|
||||
parseMarkupToolCalls,
|
||||
|
||||
@@ -18,8 +18,6 @@ const (
|
||||
endSentenceMarker = "<|end▁of▁sentence|>"
|
||||
endToolResultsMarker = "<|end▁of▁toolresults|>"
|
||||
endInstructionsMarker = "<|end▁of▁instructions|>"
|
||||
openThinkMarker = "<think>"
|
||||
closeThinkMarker = "</think>"
|
||||
)
|
||||
|
||||
func MessagesPrepare(messages []map[string]any) string {
|
||||
@@ -55,7 +53,7 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool
|
||||
lastRole = m.Role
|
||||
switch m.Role {
|
||||
case "assistant":
|
||||
parts = append(parts, formatRoleBlock(assistantMarker, closeThinkMarker+m.Text, endSentenceMarker))
|
||||
parts = append(parts, formatRoleBlock(assistantMarker, m.Text, endSentenceMarker))
|
||||
case "tool":
|
||||
if strings.TrimSpace(m.Text) != "" {
|
||||
parts = append(parts, formatRoleBlock(toolMarker, m.Text, endToolResultsMarker))
|
||||
@@ -73,19 +71,15 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool
|
||||
}
|
||||
}
|
||||
if lastRole != "assistant" {
|
||||
thinkPrefix := closeThinkMarker
|
||||
if thinkingEnabled {
|
||||
thinkPrefix = openThinkMarker
|
||||
}
|
||||
parts = append(parts, assistantMarker+thinkPrefix)
|
||||
parts = append(parts, assistantMarker)
|
||||
}
|
||||
out := strings.Join(parts, "")
|
||||
return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`)
|
||||
}
|
||||
|
||||
// formatRoleBlock produces a single concatenated block: marker + text + endMarker.
|
||||
// No whitespace is inserted between marker and text to match the official
|
||||
// DeepSeek V3.2 chat template encoding.
|
||||
// No whitespace is inserted between marker and text so role boundaries stay
|
||||
// compact and predictable for downstream parsers.
|
||||
func formatRoleBlock(marker, text, endMarker string) string {
|
||||
out := marker + text
|
||||
if strings.TrimSpace(endMarker) != "" {
|
||||
|
||||
@@ -41,9 +41,12 @@ func TestMessagesPrepareUsesTurnSuffixes(t *testing.T) {
|
||||
if !strings.Contains(got, "<|User|>Question") {
|
||||
t.Fatalf("expected user question, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "<|Assistant|></think>Answer<|end▁of▁sentence|>") {
|
||||
if !strings.Contains(got, "<|Assistant|>Answer<|end▁of▁sentence|>") {
|
||||
t.Fatalf("expected assistant sentence suffix, got %q", got)
|
||||
}
|
||||
if strings.Contains(got, "<think>") || strings.Contains(got, "</think>") {
|
||||
t.Fatalf("did not expect think tags in prompt, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
|
||||
@@ -55,10 +58,17 @@ func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPrepareWithThinkingEndsWithOpenThink(t *testing.T) {
|
||||
func TestMessagesPrepareWithThinkingIgnoresThinkingFlag(t *testing.T) {
|
||||
messages := []map[string]any{{"role": "user", "content": "Question"}}
|
||||
got := MessagesPrepareWithThinking(messages, true)
|
||||
if !strings.HasSuffix(got, "<|Assistant|><think>") {
|
||||
t.Fatalf("expected thinking suffix, got %q", got)
|
||||
gotThinking := MessagesPrepareWithThinking(messages, true)
|
||||
gotPlain := MessagesPrepareWithThinking(messages, false)
|
||||
if gotThinking != gotPlain {
|
||||
t.Fatalf("expected thinking flag to be ignored, got %q vs %q", gotThinking, gotPlain)
|
||||
}
|
||||
if !strings.HasSuffix(gotThinking, "<|Assistant|>") {
|
||||
t.Fatalf("expected assistant suffix without think tags, got %q", gotThinking)
|
||||
}
|
||||
if strings.Contains(gotThinking, "<think>") || strings.Contains(gotThinking, "</think>") {
|
||||
t.Fatalf("did not expect think tags in prompt, got %q", gotThinking)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,9 @@ package prompt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -11,6 +14,8 @@ var promptXMLTextEscaper = strings.NewReplacer(
|
||||
">", ">",
|
||||
)
|
||||
|
||||
var promptXMLNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_.:-]*$`)
|
||||
|
||||
// FormatToolCallsForPrompt renders a tool_calls slice into the canonical
|
||||
// prompt-visible history block used across adapters.
|
||||
func FormatToolCallsForPrompt(raw any) string {
|
||||
@@ -87,12 +92,161 @@ func formatToolCallForPrompt(call map[string]any) string {
|
||||
}
|
||||
}
|
||||
|
||||
parameters := formatToolCallParametersForPrompt(argsRaw)
|
||||
|
||||
return " <tool_call>\n" +
|
||||
" <tool_name>" + escapeXMLText(name) + "</tool_name>\n" +
|
||||
" <parameters>" + escapeXMLText(StringifyToolCallArguments(argsRaw)) + "</parameters>\n" +
|
||||
parameters + "\n" +
|
||||
" </tool_call>"
|
||||
}
|
||||
|
||||
func formatToolCallParametersForPrompt(raw any) string {
|
||||
value := normalizePromptToolCallValue(raw)
|
||||
body, ok := renderPromptToolXMLBody(value, " ")
|
||||
if ok {
|
||||
if strings.TrimSpace(body) == "" {
|
||||
return " <parameters></parameters>"
|
||||
}
|
||||
return " <parameters>\n" + body + "\n </parameters>"
|
||||
}
|
||||
|
||||
fallback := StringifyToolCallArguments(raw)
|
||||
if strings.TrimSpace(fallback) == "" {
|
||||
fallback = "{}"
|
||||
}
|
||||
return " <parameters><content>" + renderPromptXMLText(fallback) + "</content></parameters>"
|
||||
}
|
||||
|
||||
func normalizePromptToolCallValue(raw any) any {
|
||||
switch x := raw.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case string:
|
||||
s := strings.TrimSpace(x)
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(s), &parsed); err == nil {
|
||||
return parsed
|
||||
}
|
||||
return x
|
||||
default:
|
||||
return x
|
||||
}
|
||||
}
|
||||
|
||||
func renderPromptToolXMLBody(value any, indent string) (string, bool) {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return "", true
|
||||
case map[string]any:
|
||||
return renderPromptToolXMLMap(v, indent)
|
||||
case []any:
|
||||
return renderPromptToolXMLArray(v, indent)
|
||||
case string:
|
||||
return indent + "<content>" + renderPromptXMLText(v) + "</content>", true
|
||||
case bool, float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return indent + "<value>" + escapeXMLText(fmt.Sprint(v)) + "</value>", true
|
||||
default:
|
||||
return indent + "<value>" + renderPromptXMLText(fmt.Sprint(v)) + "</value>", true
|
||||
}
|
||||
}
|
||||
|
||||
func renderPromptToolXMLMap(m map[string]any, indent string) (string, bool) {
|
||||
if len(m) == 0 {
|
||||
return "", true
|
||||
}
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
if !isValidPromptXMLName(k) {
|
||||
return "", false
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
lines := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
rendered, ok := renderPromptToolXMLNode(key, m[key], indent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
lines = append(lines, rendered)
|
||||
}
|
||||
return strings.Join(lines, "\n"), true
|
||||
}
|
||||
|
||||
func renderPromptToolXMLArray(items []any, indent string) (string, bool) {
|
||||
if len(items) == 0 {
|
||||
return "", true
|
||||
}
|
||||
lines := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
rendered, ok := renderPromptToolXMLNode("item", item, indent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
lines = append(lines, rendered)
|
||||
}
|
||||
return strings.Join(lines, "\n"), true
|
||||
}
|
||||
|
||||
func renderPromptToolXMLNode(name string, value any, indent string) (string, bool) {
|
||||
if !isValidPromptXMLName(name) {
|
||||
return "", false
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return indent + "<" + name + "></" + name + ">", true
|
||||
case map[string]any:
|
||||
inner, ok := renderPromptToolXMLMap(v, indent+" ")
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if strings.TrimSpace(inner) == "" {
|
||||
return indent + "<" + name + "></" + name + ">", true
|
||||
}
|
||||
return indent + "<" + name + ">\n" + inner + "\n" + indent + "</" + name + ">", true
|
||||
case []any:
|
||||
if len(v) == 0 {
|
||||
return indent + "<" + name + "></" + name + ">", true
|
||||
}
|
||||
lines := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
rendered, ok := renderPromptToolXMLNode(name, item, indent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
lines = append(lines, rendered)
|
||||
}
|
||||
return strings.Join(lines, "\n"), true
|
||||
case string:
|
||||
return indent + "<" + name + ">" + renderPromptXMLText(v) + "</" + name + ">", true
|
||||
case bool, float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return indent + "<" + name + ">" + escapeXMLText(fmt.Sprint(v)) + "</" + name + ">", true
|
||||
default:
|
||||
return indent + "<" + name + ">" + renderPromptXMLText(fmt.Sprint(v)) + "</" + name + ">", true
|
||||
}
|
||||
}
|
||||
|
||||
func renderPromptXMLText(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(text, "]]>") {
|
||||
return "<![CDATA[" + strings.ReplaceAll(text, "]]>", "]]]]><![CDATA[>") + "]]>"
|
||||
}
|
||||
if strings.ContainsAny(text, "<>&\n\r") {
|
||||
return "<![CDATA[" + text + "]]>"
|
||||
}
|
||||
return escapeXMLText(text)
|
||||
}
|
||||
|
||||
func isValidPromptXMLName(name string) bool {
|
||||
return promptXMLNamePattern.MatchString(strings.TrimSpace(name))
|
||||
}
|
||||
|
||||
func normalizeToolArgumentString(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestFormatToolCallsForPromptXML(t *testing.T) {
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty formatted tool calls")
|
||||
}
|
||||
if got != "<tool_calls>\n <tool_call>\n <tool_name>search_web</tool_name>\n <parameters>{\"query\":\"latest\"}</parameters>\n </tool_call>\n</tool_calls>" {
|
||||
if got != "<tool_calls>\n <tool_call>\n <tool_name>search_web</tool_name>\n <parameters>\n <query>latest</query>\n </parameters>\n </tool_call>\n</tool_calls>" {
|
||||
t.Fatalf("unexpected formatted tool call XML: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -34,8 +34,24 @@ func TestFormatToolCallsForPromptEscapesXMLEntities(t *testing.T) {
|
||||
"arguments": `{"q":"a < b && c > d"}`,
|
||||
},
|
||||
})
|
||||
want := "<tool_calls>\n <tool_call>\n <tool_name>search<&></tool_name>\n <parameters>{\"q\":\"a < b && c > d\"}</parameters>\n </tool_call>\n</tool_calls>"
|
||||
want := "<tool_calls>\n <tool_call>\n <tool_name>search<&></tool_name>\n <parameters>\n <q><![CDATA[a < b && c > d]]></q>\n </parameters>\n </tool_call>\n</tool_calls>"
|
||||
if got != want {
|
||||
t.Fatalf("unexpected escaped tool call XML: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatToolCallsForPromptUsesCDATAForMultilineContent(t *testing.T) {
|
||||
got := FormatToolCallsForPrompt([]any{
|
||||
map[string]any{
|
||||
"name": "write_file",
|
||||
"arguments": map[string]any{
|
||||
"path": "script.sh",
|
||||
"content": "#!/bin/bash\nprintf \"hello\"\n",
|
||||
},
|
||||
},
|
||||
})
|
||||
want := "<tool_calls>\n <tool_call>\n <tool_name>write_file</tool_name>\n <parameters>\n <content><![CDATA[#!/bin/bash\nprintf \"hello\"\n]]></content>\n <path>script.sh</path>\n </parameters>\n </tool_call>\n</tool_calls>"
|
||||
if got != want {
|
||||
t.Fatalf("unexpected multiline cdata tool call XML: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,20 @@ line 2 with <tags> and & symbols]]></content></parameters></tool_call>`,
|
||||
Input: map[string]any{"content": "line 1\nline 2 with <tags> and & symbols"},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Nested XML with repeated parameters (New Feature)",
|
||||
text: `<tool_call><tool_name>write_file</tool_name><parameters><path>script.sh</path><content><![CDATA[#!/bin/bash
|
||||
echo "hello"
|
||||
]]></content><item>first</item><item>second</item></parameters></tool_call>`,
|
||||
expected: []ParsedToolCall{{
|
||||
Name: "write_file",
|
||||
Input: map[string]any{
|
||||
"path": "script.sh",
|
||||
"content": "#!/bin/bash\necho \"hello\"\n",
|
||||
"item": []any{"first", "second"},
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Dirty XML with unescaped symbols (Robustness Improvement)",
|
||||
text: `<tool_call><tool_name>bash</tool_name><parameters><command>echo "hello" > out.txt && cat out.txt</command></parameters></tool_call>`,
|
||||
|
||||
@@ -50,8 +50,8 @@ When calling tools, emit ONLY raw XML at the very end of your response. No text
|
||||
RULES:
|
||||
1) When calling tools, you MUST use the <tool_calls> XML format.
|
||||
2) No text is allowed AFTER the XML block.
|
||||
3) <parameters> should be a list of XML tags (e.g., <param_name>value</param_name>). For simple inputs, a single-line JSON string is also acceptable.
|
||||
4) For long text, scripts, or code content, YOU MUST wrap the value in <![CDATA[ content ]]> to preserve formatting and avoid character escaping errors.
|
||||
3) <parameters> should be XML tags, not JSON. Use nested XML elements for structured data (e.g., <param_name>value</param_name>).
|
||||
4) For long text, scripts, novels, or code content, YOU MUST wrap the value in <![CDATA[ content ]]> to preserve formatting and avoid character escaping errors.
|
||||
5) Multiple tools must be inside the same <tool_calls> root.
|
||||
6) Do NOT wrap XML in markdown fences (` + "```" + `).
|
||||
7) Do NOT invent parameters. Use only the provided schema.
|
||||
@@ -97,7 +97,7 @@ Example B — Two tools in parallel:
|
||||
</tool_call>
|
||||
</tool_calls>
|
||||
|
||||
Example C — Tool with complex nested JSON parameters:
|
||||
Example C — Tool with complex structured XML parameters:
|
||||
<tool_calls>
|
||||
<tool_call>
|
||||
<tool_name>` + ex3 + `</tool_name>
|
||||
|
||||
@@ -10,7 +10,7 @@ func TestBuildToolCallInstructions_ExecCommandUsesCmdExample(t *testing.T) {
|
||||
if !strings.Contains(out, `<tool_name>exec_command</tool_name>`) {
|
||||
t.Fatalf("expected exec_command in examples, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, `<parameters>{"cmd":"pwd"}</parameters>`) {
|
||||
if !strings.Contains(out, `<parameters><cmd>pwd</cmd></parameters>`) {
|
||||
t.Fatalf("expected cmd parameter example for exec_command, got: %s", out)
|
||||
}
|
||||
}
|
||||
@@ -20,7 +20,7 @@ func TestBuildToolCallInstructions_ExecuteCommandUsesCommandExample(t *testing.T
|
||||
if !strings.Contains(out, `<tool_name>execute_command</tool_name>`) {
|
||||
t.Fatalf("expected execute_command in examples, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, `<parameters>{"command":"pwd"}</parameters>`) {
|
||||
if !strings.Contains(out, `<parameters><command>pwd</command></parameters>`) {
|
||||
t.Fatalf("expected command parameter example for execute_command, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
4
internal/toolcall/toolcalls_candidates.go
Normal file
4
internal/toolcall/toolcalls_candidates.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package toolcall
|
||||
|
||||
// toolcalls_candidates.go is reserved for tool-call candidate helper logic.
|
||||
// It exists to satisfy the refactor line gate target list.
|
||||
@@ -23,8 +23,8 @@ var toolCallMarkupNamePatternByTag = map[string]*regexp.Regexp{
|
||||
"function": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?function>`),
|
||||
}
|
||||
|
||||
// cdataPattern matches CDATA sections to handle them separately from normal tags.
|
||||
var cdataPattern = regexp.MustCompile(`(?is)<!\[CDATA\[(.*?)]]>`)
|
||||
// cdataPattern matches a standalone CDATA section.
|
||||
var cdataPattern = regexp.MustCompile(`(?is)^<!\[CDATA\[(.*?)]]>$`)
|
||||
var toolCallMarkupArgsTagNames = []string{"input", "arguments", "argument", "parameters", "parameter", "args", "params"}
|
||||
var toolCallMarkupArgsPatternByTag = map[string]*regexp.Regexp{
|
||||
"input": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?input\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?input>`),
|
||||
@@ -119,20 +119,7 @@ func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall {
|
||||
}
|
||||
|
||||
func parseMarkupInput(raw string) map[string]any {
|
||||
raw = strings.TrimSpace(html.UnescapeString(raw))
|
||||
if raw == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
// Prioritize XML-style KV tags as they are more robust for long text/scripts.
|
||||
if kv := parseMarkupKVObject(raw); len(kv) > 0 {
|
||||
return kv
|
||||
}
|
||||
|
||||
// Fallback to JSON parsing for standard/legacy tool calls.
|
||||
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
|
||||
return parsed
|
||||
}
|
||||
return map[string]any{"_raw": html.UnescapeString(stripTagText(raw))}
|
||||
return parseStructuredToolCallInput(raw)
|
||||
}
|
||||
|
||||
func parseMarkupKVObject(text string) map[string]any {
|
||||
@@ -153,22 +140,11 @@ func parseMarkupKVObject(text string) map[string]any {
|
||||
if !strings.EqualFold(key, endKey) {
|
||||
continue
|
||||
}
|
||||
// Robustly extract value to handle CDATA and mixed content
|
||||
value := extractRawTagValue(m[2])
|
||||
if value == "" && m[2] != "" {
|
||||
// If it wasn't empty but extracted to empty, could be whitespace or just tags
|
||||
value = strings.TrimSpace(m[2])
|
||||
}
|
||||
|
||||
if value == "" {
|
||||
value := parseMarkupValue(m[2])
|
||||
if value == nil {
|
||||
continue
|
||||
}
|
||||
var jsonValue any
|
||||
if json.Unmarshal([]byte(value), &jsonValue) == nil {
|
||||
out[key] = jsonValue
|
||||
continue
|
||||
}
|
||||
out[key] = value
|
||||
appendMarkupValue(out, key, value)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
@@ -176,6 +152,43 @@ func parseMarkupKVObject(text string) map[string]any {
|
||||
return out
|
||||
}
|
||||
|
||||
func parseMarkupValue(inner string) any {
|
||||
value := strings.TrimSpace(extractRawTagValue(inner))
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if strings.Contains(value, "<") && strings.Contains(value, ">") {
|
||||
if parsed := parseStructuredToolCallInput(value); len(parsed) > 0 {
|
||||
if len(parsed) == 1 {
|
||||
if raw, ok := parsed["_raw"].(string); ok {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
|
||||
var jsonValue any
|
||||
if json.Unmarshal([]byte(value), &jsonValue) == nil {
|
||||
return jsonValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func appendMarkupValue(out map[string]any, key string, value any) {
|
||||
if existing, ok := out[key]; ok {
|
||||
switch current := existing.(type) {
|
||||
case []any:
|
||||
out[key] = append(current, value)
|
||||
default:
|
||||
out[key] = []any{current, value}
|
||||
}
|
||||
return
|
||||
}
|
||||
out[key] = value
|
||||
}
|
||||
|
||||
// extractRawTagValue treats the inner content of a tag robustly.
|
||||
// It detects CDATA and strips it, otherwise it unescapes standard HTML entities.
|
||||
// It avoids over-aggressive tag stripping that might break user content.
|
||||
|
||||
@@ -13,7 +13,6 @@ var functionCallPattern = regexp.MustCompile(`(?is)<function_call>\s*([^<]+?)\s*
|
||||
var functionParamPattern = regexp.MustCompile(`(?is)<function\s+parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</function\s+parameter>`)
|
||||
var antmlFunctionCallPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?function_call[^>]*(?:name|function)="([^"]+)"[^>]*>\s*(.*?)\s*</(?:[a-z0-9_]+:)?function_call>`)
|
||||
var antmlArgumentPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?argument\s+name="([^"]+)"\s*>\s*(.*?)\s*</(?:[a-z0-9_]+:)?argument>`)
|
||||
var antmlParametersPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?parameters\s*>\s*(\{.*?\})\s*</(?:[a-z0-9_]+:)?parameters>`)
|
||||
var invokeCallPattern = regexp.MustCompile(`(?is)<invoke\s+name="([^"]+)"\s*>(.*?)</invoke>`)
|
||||
var invokeParamPattern = regexp.MustCompile(`(?is)<parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</parameter>`)
|
||||
var toolUseFunctionPattern = regexp.MustCompile(`(?is)<tool_use>\s*<function\s+name="([^"]+)"\s*>(.*?)</function>\s*</tool_use>`)
|
||||
@@ -89,7 +88,6 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
name := ""
|
||||
params := extractXMLToolParamsByRegex(inner)
|
||||
dec := xml.NewDecoder(strings.NewReader(block))
|
||||
inParams := false
|
||||
inTool := false
|
||||
for {
|
||||
tok, err := dec.Token()
|
||||
@@ -108,57 +106,36 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
}
|
||||
}
|
||||
case "parameters":
|
||||
inParams = true
|
||||
var node struct {
|
||||
Inner string `xml:",innerxml"`
|
||||
}
|
||||
if err := dec.DecodeElement(&node, &t); err == nil {
|
||||
inner := strings.TrimSpace(node.Inner)
|
||||
if inner != "" {
|
||||
// Cleanly extract content (handles CDATA, entities, etc.)
|
||||
extracted := extractRawTagValue(inner)
|
||||
if parsed := parseToolCallInput(extracted); len(parsed) > 0 {
|
||||
if len(parsed) == 1 {
|
||||
if _, onlyRaw := parsed["_raw"]; onlyRaw {
|
||||
if kv := parseMarkupKVObject(extracted); len(kv) > 0 {
|
||||
for k, vv := range kv {
|
||||
params[k] = vv
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if parsed := parseStructuredToolCallInput(extracted); len(parsed) > 0 {
|
||||
for k, vv := range parsed {
|
||||
params[k] = vv
|
||||
}
|
||||
} else if kv := parseMarkupKVObject(extracted); len(kv) > 0 {
|
||||
for k, vv := range kv {
|
||||
params[k] = vv
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
inParams = false
|
||||
case "tool_name", "function_name", "name":
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
|
||||
if inParams {
|
||||
params[t.Name.Local] = strings.TrimSpace(v)
|
||||
break
|
||||
}
|
||||
name = strings.TrimSpace(v)
|
||||
}
|
||||
case "input", "arguments", "argument", "args", "params":
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
|
||||
if parsed := parseToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 {
|
||||
if parsed := parseStructuredToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 {
|
||||
for k, vv := range parsed {
|
||||
params[k] = vv
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
if inParams || inTool {
|
||||
if inTool {
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &t); err == nil {
|
||||
params[t.Name.Local] = strings.TrimSpace(html.UnescapeString(v))
|
||||
@@ -167,9 +144,6 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
}
|
||||
case xml.EndElement:
|
||||
tag := strings.ToLower(t.Name.Local)
|
||||
if tag == "parameters" {
|
||||
inParams = false
|
||||
}
|
||||
if tag == "tool" {
|
||||
inTool = false
|
||||
}
|
||||
@@ -244,9 +218,15 @@ func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(pm[1])
|
||||
val := strings.TrimSpace(html.UnescapeString(pm[2]))
|
||||
val := extractRawTagValue(pm[2])
|
||||
if key != "" {
|
||||
input[key] = val
|
||||
if parsed := parseStructuredToolCallInput(val); len(parsed) > 0 {
|
||||
if isOnlyRawValue(parsed, val) {
|
||||
input[key] = val
|
||||
} else {
|
||||
input[key] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
@@ -277,18 +257,13 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
body := strings.TrimSpace(html.UnescapeString(m[2]))
|
||||
body := strings.TrimSpace(m[2])
|
||||
input := map[string]any{}
|
||||
if strings.HasPrefix(body, "{") {
|
||||
if err := json.Unmarshal([]byte(body), &input); err == nil {
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
}
|
||||
if pm := antmlParametersPattern.FindStringSubmatch(body); len(pm) >= 2 {
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(pm[1])), &input); err == nil {
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
}
|
||||
for _, am := range antmlArgumentPattern.FindAllStringSubmatch(body, -1) {
|
||||
if len(am) < 3 {
|
||||
continue
|
||||
@@ -299,6 +274,19 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
|
||||
input[k] = v
|
||||
}
|
||||
}
|
||||
if len(input) > 0 {
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
if paramsRaw := findMarkupTagValue(body, toolCallMarkupArgsTagNames, toolCallMarkupArgsPatternByTag); paramsRaw != "" {
|
||||
if parsed := parseMarkupInput(paramsRaw); len(parsed) > 0 {
|
||||
return ParsedToolCall{Name: name, Input: parsed}, true
|
||||
}
|
||||
}
|
||||
if strings.Contains(body, "<") {
|
||||
if parsed := parseStructuredToolCallInput(body); len(parsed) > 0 && !isOnlyRawValue(parsed, body) {
|
||||
return ParsedToolCall{Name: name, Input: parsed}, true
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
|
||||
@@ -319,7 +307,13 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
|
||||
k := strings.TrimSpace(pm[1])
|
||||
v := extractRawTagValue(pm[2])
|
||||
if k != "" {
|
||||
input[k] = v
|
||||
if parsed := parseStructuredToolCallInput(v); len(parsed) > 0 {
|
||||
if isOnlyRawValue(parsed, v) {
|
||||
input[k] = v
|
||||
} else {
|
||||
input[k] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(input) == 0 {
|
||||
@@ -327,6 +321,8 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
|
||||
input = parseMarkupInput(argsRaw)
|
||||
} else if kv := parseMarkupKVObject(m[2]); len(kv) > 0 {
|
||||
input = kv
|
||||
} else if parsed := parseStructuredToolCallInput(m[2]); len(parsed) > 0 && !isOnlyRawValue(parsed, strings.TrimSpace(html.UnescapeString(m[2]))) {
|
||||
input = parsed
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
@@ -350,7 +346,13 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) {
|
||||
k := strings.TrimSpace(pm[1])
|
||||
v := extractRawTagValue(pm[2])
|
||||
if k != "" {
|
||||
input[k] = v
|
||||
if parsed := parseStructuredToolCallInput(v); len(parsed) > 0 {
|
||||
if isOnlyRawValue(parsed, v) {
|
||||
input[k] = v
|
||||
} else {
|
||||
input[k] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
@@ -365,13 +367,11 @@ func parseToolUseNameParametersStyle(text string) (ParsedToolCall, bool) {
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
raw := strings.TrimSpace(html.UnescapeString(m[2]))
|
||||
raw := strings.TrimSpace(m[2])
|
||||
input := map[string]any{}
|
||||
if raw != "" {
|
||||
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
|
||||
if parsed := parseStructuredToolCallInput(raw); len(parsed) > 0 {
|
||||
input = parsed
|
||||
} else if kv := parseMarkupKVObject(raw); len(kv) > 0 {
|
||||
input = kv
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
@@ -386,13 +386,11 @@ func parseToolUseFunctionNameParametersStyle(text string) (ParsedToolCall, bool)
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
raw := strings.TrimSpace(html.UnescapeString(m[2]))
|
||||
raw := strings.TrimSpace(m[2])
|
||||
input := map[string]any{}
|
||||
if raw != "" {
|
||||
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
|
||||
if parsed := parseStructuredToolCallInput(raw); len(parsed) > 0 {
|
||||
input = parsed
|
||||
} else if kv := parseMarkupKVObject(raw); len(kv) > 0 {
|
||||
input = kv
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
@@ -407,14 +405,14 @@ func parseToolUseToolNameBodyStyle(text string) (ParsedToolCall, bool) {
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
body := strings.TrimSpace(html.UnescapeString(m[2]))
|
||||
body := strings.TrimSpace(m[2])
|
||||
input := map[string]any{}
|
||||
if body != "" {
|
||||
if kv := parseXMLChildKV(body); len(kv) > 0 {
|
||||
input = kv
|
||||
} else if kv := parseMarkupKVObject(body); len(kv) > 0 {
|
||||
input = kv
|
||||
} else if parsed := parseToolCallInput(body); len(parsed) > 0 {
|
||||
} else if parsed := parseStructuredToolCallInput(body); len(parsed) > 0 {
|
||||
input = parsed
|
||||
}
|
||||
}
|
||||
@@ -426,32 +424,11 @@ func parseXMLChildKV(body string) map[string]any {
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
dec := xml.NewDecoder(strings.NewReader("<root>" + trimmed + "</root>"))
|
||||
out := map[string]any{}
|
||||
for {
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
start, ok := tok.(xml.StartElement)
|
||||
if !ok || strings.EqualFold(start.Name.Local, "root") {
|
||||
continue
|
||||
}
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &start); err != nil {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(start.Name.Local)
|
||||
val := strings.TrimSpace(v)
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = val
|
||||
}
|
||||
if len(out) == 0 {
|
||||
parsed := parseStructuredToolCallInput(trimmed)
|
||||
if len(parsed) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
return parsed
|
||||
}
|
||||
|
||||
func asString(v any) string {
|
||||
|
||||
@@ -30,6 +30,30 @@ func TestParseToolCallsSupportsClaudeXMLToolCall(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsMultilineCDATAAndRepeatedXMLTags(t *testing.T) {
|
||||
text := `<tool_call><tool_name>write_file</tool_name><parameters><path>script.sh</path><content><![CDATA[#!/bin/bash
|
||||
echo "hello"
|
||||
]]></content><item>first</item><item>second</item></parameters></tool_call>`
|
||||
calls := ParseToolCalls(text, []string{"write_file"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "write_file" {
|
||||
t.Fatalf("expected tool name write_file, got %q", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["path"] != "script.sh" {
|
||||
t.Fatalf("expected path argument, got %#v", calls[0].Input)
|
||||
}
|
||||
content, _ := calls[0].Input["content"].(string)
|
||||
if !strings.Contains(content, "#!/bin/bash") || !strings.Contains(content, "echo \"hello\"") {
|
||||
t.Fatalf("expected multiline CDATA content to be preserved, got %#v", calls[0].Input["content"])
|
||||
}
|
||||
items, ok := calls[0].Input["item"].([]any)
|
||||
if !ok || len(items) != 2 {
|
||||
t.Fatalf("expected repeated XML tags to become an array, got %#v", calls[0].Input["item"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsCanonicalXMLParametersJSON(t *testing.T) {
|
||||
text := `<tool_call><tool_name>get_weather</tool_name><parameters>{"city":"beijing","unit":"c"}</parameters></tool_call>`
|
||||
calls := ParseToolCalls(text, []string{"get_weather"})
|
||||
|
||||
158
internal/toolcall/toolcalls_xml.go
Normal file
158
internal/toolcall/toolcalls_xml.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package toolcall
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"html"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func parseStructuredToolCallInput(raw string) map[string]any {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(trimmed, "<") {
|
||||
if parsed, ok := parseXMLFragmentValue(trimmed); ok {
|
||||
switch v := parsed.(type) {
|
||||
case map[string]any:
|
||||
if len(v) > 0 {
|
||||
return v
|
||||
}
|
||||
return map[string]any{}
|
||||
case string:
|
||||
text := strings.TrimSpace(v)
|
||||
if text == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
if parsedText := parseToolCallInput(text); len(parsedText) > 0 {
|
||||
if isOnlyRawValue(parsedText, text) {
|
||||
// Plain text content, keep it as raw text.
|
||||
} else {
|
||||
return parsedText
|
||||
}
|
||||
}
|
||||
return map[string]any{"_raw": v}
|
||||
}
|
||||
}
|
||||
|
||||
if kv := parseMarkupKVObject(trimmed); len(kv) > 0 {
|
||||
return kv
|
||||
}
|
||||
}
|
||||
|
||||
if kv := parseMarkupKVObject(trimmed); len(kv) > 0 {
|
||||
return kv
|
||||
}
|
||||
|
||||
if parsed := parseToolCallInput(trimmed); len(parsed) > 0 {
|
||||
return parsed
|
||||
}
|
||||
|
||||
return map[string]any{"_raw": html.UnescapeString(trimmed)}
|
||||
}
|
||||
|
||||
func parseXMLFragmentValue(raw string) (any, bool) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", true
|
||||
}
|
||||
|
||||
dec := xml.NewDecoder(strings.NewReader("<root>" + trimmed + "</root>"))
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
start, ok := tok.(xml.StartElement)
|
||||
if !ok || !strings.EqualFold(start.Name.Local, "root") {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
value, err := parseXMLNodeValue(dec, start)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
func parseXMLNodeValue(dec *xml.Decoder, start xml.StartElement) (any, error) {
|
||||
children := map[string]any{}
|
||||
var text strings.Builder
|
||||
hasChild := false
|
||||
|
||||
for {
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch t := tok.(type) {
|
||||
case xml.CharData:
|
||||
s := string([]byte(t))
|
||||
if hasChild && strings.TrimSpace(s) == "" {
|
||||
continue
|
||||
}
|
||||
text.WriteString(s)
|
||||
case xml.StartElement:
|
||||
if !hasChild && strings.TrimSpace(text.String()) == "" {
|
||||
text.Reset()
|
||||
}
|
||||
hasChild = true
|
||||
child, err := parseXMLNodeValue(dec, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
appendXMLChildValue(children, t.Name.Local, child)
|
||||
case xml.EndElement:
|
||||
if t.Name.Local != start.Name.Local {
|
||||
return nil, errXMLMismatch(start.Name.Local, t.Name.Local)
|
||||
}
|
||||
if len(children) == 0 {
|
||||
return text.String(), nil
|
||||
}
|
||||
if txt := text.String(); strings.TrimSpace(txt) != "" {
|
||||
children["_text"] = txt
|
||||
}
|
||||
return children, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func appendXMLChildValue(dst map[string]any, key string, value any) {
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if existing, ok := dst[key]; ok {
|
||||
switch current := existing.(type) {
|
||||
case []any:
|
||||
dst[key] = append(current, value)
|
||||
default:
|
||||
dst[key] = []any{current, value}
|
||||
}
|
||||
return
|
||||
}
|
||||
dst[key] = value
|
||||
}
|
||||
|
||||
func isOnlyRawValue(m map[string]any, raw string) bool {
|
||||
if len(m) != 1 {
|
||||
return false
|
||||
}
|
||||
v, ok := m["_raw"].(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(v) == strings.TrimSpace(raw)
|
||||
}
|
||||
|
||||
type xmlMismatchError struct {
|
||||
want string
|
||||
got string
|
||||
}
|
||||
|
||||
func (e xmlMismatchError) Error() string {
|
||||
return "mismatched xml end tag: want " + e.want + ", got " + e.got
|
||||
}
|
||||
|
||||
func errXMLMismatch(want, got string) error {
|
||||
return xmlMismatchError{want: want, got: got}
|
||||
}
|
||||
@@ -12,7 +12,7 @@ func TestMessagesPrepareBasic(t *testing.T) {
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty prompt")
|
||||
}
|
||||
if got != "<|begin▁of▁sentence|><|User|>Hello<|Assistant|></think>" {
|
||||
if got != "<|begin▁of▁sentence|><|User|>Hello<|Assistant|>" {
|
||||
t.Fatalf("unexpected prompt: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -32,10 +32,10 @@ func TestMessagesPrepareRoles(t *testing.T) {
|
||||
if !contains(got, "<|begin▁of▁sentence|>") {
|
||||
t.Fatalf("expected begin marker in %q", got)
|
||||
}
|
||||
if !contains(got, "<|User|>Hi<|Assistant|></think>Hello<|end▁of▁sentence|>") {
|
||||
if !contains(got, "<|User|>Hi<|Assistant|>Hello<|end▁of▁sentence|>") {
|
||||
t.Fatalf("expected user/assistant separation in %q", got)
|
||||
}
|
||||
if !contains(got, "<|Assistant|></think>Hello<|end▁of▁sentence|><|Tool|>Search results<|end▁of▁toolresults|>") {
|
||||
if !contains(got, "<|Assistant|>Hello<|end▁of▁sentence|><|Tool|>Search results<|end▁of▁toolresults|>") {
|
||||
t.Fatalf("expected assistant/tool separation in %q", got)
|
||||
}
|
||||
if !contains(got, "<|Tool|>Search results<|end▁of▁toolresults|><|User|>How are you") {
|
||||
@@ -77,7 +77,7 @@ func TestMessagesPrepareArrayTextVariants(t *testing.T) {
|
||||
},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if got != "<|begin▁of▁sentence|><|User|>line1\nline2<|Assistant|></think>" {
|
||||
if got != "<|begin▁of▁sentence|><|User|>line1\nline2<|Assistant|>" {
|
||||
t.Fatalf("unexpected content from text variants: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,9 +195,12 @@ func TestMessagesPrepareAssistantMarkers(t *testing.T) {
|
||||
if strings.Count(got, "<|end▁of▁sentence|>") != 1 {
|
||||
t.Fatalf("expected one end_of_sentence (assistant only), got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "<|Assistant|></think>Hello!<|end▁of▁sentence|>") {
|
||||
if !strings.Contains(got, "<|Assistant|>Hello!<|end▁of▁sentence|>") {
|
||||
t.Fatalf("expected assistant EOS suffix, got %q", got)
|
||||
}
|
||||
if strings.Contains(got, "<think>") || strings.Contains(got, "</think>") {
|
||||
t.Fatalf("did not expect think tags in prompt, got %q", got)
|
||||
}
|
||||
if strings.Contains(got, "<system_instructions>") {
|
||||
t.Fatalf("did not expect legacy system marker, got %q", got)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user