mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-10 19:27:41 +08:00
fix: align tool call protocol and thinking controls
This commit is contained in:
@@ -52,6 +52,7 @@ func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store C
|
||||
}
|
||||
}
|
||||
translatedReq := translatorcliproxy.ToOpenAI(sdktranslator.FormatClaude, translateModel, raw, stream)
|
||||
translatedReq = applyExplicitThinkingOverrideToOpenAIRequest(translatedReq, req)
|
||||
|
||||
isVercelPrepare := strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1"
|
||||
isVercelRelease := strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"
|
||||
@@ -123,6 +124,27 @@ func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store C
|
||||
return true
|
||||
}
|
||||
|
||||
func applyExplicitThinkingOverrideToOpenAIRequest(translated []byte, original map[string]any) []byte {
|
||||
enabled, ok := util.ResolveThinkingOverride(original)
|
||||
if !ok {
|
||||
return translated
|
||||
}
|
||||
req := map[string]any{}
|
||||
if err := json.Unmarshal(translated, &req); err != nil {
|
||||
return translated
|
||||
}
|
||||
typ := "disabled"
|
||||
if enabled {
|
||||
typ = "enabled"
|
||||
}
|
||||
req["thinking"] = map[string]any{"type": typ}
|
||||
out, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return translated
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
|
||||
@@ -251,14 +251,14 @@ func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.
|
||||
payload string
|
||||
wantToolUse bool
|
||||
}{
|
||||
{name: "canonical_tools_wrapper", payload: `<tools><tool_call><tool_name>Bash</tool_name><param><command>pwd</command></param></tool_call></tools>`, wantToolUse: true},
|
||||
{name: "invoke_parameter_wrapper", payload: `<tool_calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></tool_calls>`, wantToolUse: true},
|
||||
{name: "legacy_single_tool_root", payload: `<tool><tool_name>Bash</tool_name><param><command>pwd</command></param></tool>`, wantToolUse: false},
|
||||
{name: "legacy_tool_call_json", payload: `<tool>{"tool":"Bash","params":{"command":"pwd"}}</tool>`, wantToolUse: false},
|
||||
{name: "legacy_nested_tool_tag_style", payload: `<tool><tool name="Bash"><command>pwd</command></tool_call></tool>`, wantToolUse: false},
|
||||
{name: "legacy_function_tag_style", payload: `<function_call>Bash</function_call><function parameter name="command">pwd</function parameter>`, wantToolUse: false},
|
||||
{name: "legacy_antml_argument_style", payload: `<antml:function_calls><antml:function_call id="1" name="Bash"><antml:argument name="command">pwd</antml:argument></antml:function_call></antml:function_calls>`, wantToolUse: false},
|
||||
{name: "legacy_antml_function_attr_parameters", payload: `<antml:function_calls><antml:function_call id="1" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call></antml:function_calls>`, wantToolUse: false},
|
||||
{name: "legacy_invoke_parameter_style", payload: `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></function_calls>`, wantToolUse: false},
|
||||
{name: "legacy_function_calls_wrapper", payload: `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></function_calls>`, wantToolUse: false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -291,7 +291,7 @@ func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.
|
||||
|
||||
func TestHandleClaudeStreamRealtimeDetectsToolUseWithLeadingProse(t *testing.T) {
|
||||
h := &Handler{}
|
||||
payload := "I'll call a tool now.\\n<tools><tool_call><tool_name>write_file</tool_name><param>{\\\"path\\\":\\\"/tmp/a.txt\\\",\\\"content\\\":\\\"abc\\\"}</param></tool_call></tools>"
|
||||
payload := "I'll call a tool now.\\n<tool_calls><invoke name=\\\"write_file\\\"><parameter name=\\\"path\\\">/tmp/a.txt</parameter><parameter name=\\\"content\\\">abc</parameter></invoke></tool_calls>"
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"`+payload+`"}`,
|
||||
`data: [DONE]`,
|
||||
|
||||
@@ -93,10 +93,10 @@ func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) {
|
||||
t.Fatalf("expected call id preserved, got %#v", call)
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, "<tools>") || !containsStr(content, "<tool_name>search_web</tool_name>") {
|
||||
if !containsStr(content, "<tool_calls>") || !containsStr(content, `<invoke name="search_web">`) {
|
||||
t.Fatalf("expected assistant content to include XML tool call history, got %q", content)
|
||||
}
|
||||
if !containsStr(content, "<param>\n <query><![CDATA[latest]]></query>\n </param>") {
|
||||
if !containsStr(content, `<parameter name="query"><![CDATA[latest]]></parameter>`) {
|
||||
t.Fatalf("expected assistant content to include serialized parameters, got %q", content)
|
||||
}
|
||||
}
|
||||
@@ -292,7 +292,7 @@ func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
|
||||
if !containsStr(prompt, "Search the web") {
|
||||
t.Fatalf("expected description in prompt")
|
||||
}
|
||||
if !containsStr(prompt, "<tools>") {
|
||||
if !containsStr(prompt, "<tool_calls>") {
|
||||
t.Fatalf("expected XML tool_calls format in prompt")
|
||||
}
|
||||
if !containsStr(prompt, "TOOL CALL FORMAT") {
|
||||
|
||||
@@ -106,6 +106,26 @@ func TestClaudeProxyViaOpenAIUsesGlobalAliasMapping(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProxyViaOpenAIPreservesThinkingOverride(t *testing.T) {
|
||||
openAI := &openAIProxyCaptureStub{}
|
||||
h := &Handler{
|
||||
Store: claudeProxyStoreStub{aliases: map[string]string{"claude-sonnet-4-6": "deepseek-v4-flash"}},
|
||||
OpenAI: openAI,
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"disabled"},"stream":false}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.Messages(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
thinking, _ := openAI.seenReq["thinking"].(map[string]any)
|
||||
if thinking["type"] != "disabled" {
|
||||
t.Fatalf("expected translated OpenAI request to preserve disabled thinking, got %#v", openAI.seenReq)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProxyTranslatesInlineImageToOpenAIDataURL(t *testing.T) {
|
||||
openAI := &openAIProxyCaptureStub{}
|
||||
h := &Handler{OpenAI: openAI}
|
||||
|
||||
@@ -217,8 +217,8 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
||||
func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"前置文本\n<tools>\n <tool_call>\n <tool_name>read_file</tool_name>\n <param>{\"path\":\"README.MD\"}</param>\n </tool_call>\n</tools>"}`,
|
||||
`data: {"p":"response/content","v":"中间文本\n<tools>\n <tool_call>\n <tool_name>search</tool_name>\n <param>{\"q\":\"golang\"}</param>\n </tool_call>\n</tools>"}`,
|
||||
`data: {"p":"response/content","v":"前置文本\n<tool_calls>\n <invoke name=\"read_file\">\n <parameter name=\"path\">README.MD</parameter>\n </invoke>\n</tool_calls>"}`,
|
||||
`data: {"p":"response/content","v":"中间文本\n<tool_calls>\n <invoke name=\"search\">\n <parameter name=\"q\">golang</parameter>\n </invoke>\n</tool_calls>"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@@ -12,9 +12,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
historySplitFilename = "IGNORE"
|
||||
historySplitContentType = "text/plain; charset=utf-8"
|
||||
historySplitPurpose = "assistants"
|
||||
historySplitFilename = "HISTORY.txt"
|
||||
historySplitInjectedFilename = "IGNORE"
|
||||
historySplitContentType = "text/plain; charset=utf-8"
|
||||
historySplitPurpose = "assistants"
|
||||
)
|
||||
|
||||
func (h *Handler) applyHistorySplit(ctx context.Context, a *auth.RequestAuth, stdReq util.StandardRequest) (util.StandardRequest, error) {
|
||||
@@ -114,7 +115,7 @@ func buildOpenAIHistoryTranscript(messages []any) string {
|
||||
if transcript == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("[file content end]\n\n%s\n\n[file name]: %s\n[file content begin]\n", transcript, historySplitFilename)
|
||||
return fmt.Sprintf("[file content end]\n\n%s\n\n[file name]: %s\n[file content begin]\n", transcript, historySplitInjectedFilename)
|
||||
}
|
||||
|
||||
func prependUniqueRefFileID(existing []string, fileID string) []string {
|
||||
|
||||
@@ -76,7 +76,7 @@ func TestBuildOpenAIHistoryTranscriptUsesInjectedFileWrapper(t *testing.T) {
|
||||
if !strings.Contains(transcript, "[reasoning_content]") || !strings.Contains(transcript, "hidden reasoning") {
|
||||
t.Fatalf("expected reasoning block preserved, got %q", transcript)
|
||||
}
|
||||
if !strings.Contains(transcript, "<tools>") {
|
||||
if !strings.Contains(transcript, "<tool_calls>") {
|
||||
t.Fatalf("expected tool calls preserved, got %q", transcript)
|
||||
}
|
||||
if !strings.HasSuffix(transcript, "\n[file name]: IGNORE\n[file content begin]\n") {
|
||||
@@ -180,7 +180,7 @@ func TestApplyHistorySplitCarriesHistoryText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsHistorySplitUploadsIgnoreFileAndKeepsLatestPrompt(t *testing.T) {
|
||||
func TestChatCompletionsHistorySplitUploadsHistoryFileAndKeepsLatestPrompt(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{
|
||||
@@ -210,7 +210,7 @@ func TestChatCompletionsHistorySplitUploadsIgnoreFileAndKeepsLatestPrompt(t *tes
|
||||
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
upload := ds.uploadCalls[0]
|
||||
if upload.Filename != "IGNORE" {
|
||||
if upload.Filename != "HISTORY.txt" {
|
||||
t.Fatalf("unexpected upload filename: %q", upload.Filename)
|
||||
}
|
||||
if upload.Purpose != "assistants" {
|
||||
|
||||
@@ -38,10 +38,10 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes
|
||||
t.Fatalf("expected 4 normalized messages with assistant tool history preserved, got %d", len(normalized))
|
||||
}
|
||||
assistantContent, _ := normalized[2]["content"].(string)
|
||||
if !strings.Contains(assistantContent, "<tools>") {
|
||||
if !strings.Contains(assistantContent, "<tool_calls>") {
|
||||
t.Fatalf("assistant tool history should be preserved in XML form, got %q", assistantContent)
|
||||
}
|
||||
if !strings.Contains(assistantContent, "<tool_name>get_weather</tool_name>") {
|
||||
if !strings.Contains(assistantContent, `<invoke name="get_weather">`) {
|
||||
t.Fatalf("expected tool name in preserved history, got %q", assistantContent)
|
||||
}
|
||||
if !strings.Contains(normalized[3]["content"].(string), `"temp":18`) {
|
||||
@@ -49,7 +49,7 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes
|
||||
}
|
||||
|
||||
prompt := util.MessagesPrepare(normalized)
|
||||
if !strings.Contains(prompt, "<tools>") {
|
||||
if !strings.Contains(prompt, "<tool_calls>") {
|
||||
t.Fatalf("expected preserved assistant tool history in prompt: %q", prompt)
|
||||
}
|
||||
}
|
||||
@@ -177,10 +177,10 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara
|
||||
t.Fatalf("expected assistant tool_call-only message preserved, got %#v", normalized)
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if strings.Count(content, "<tool_call>") != 2 {
|
||||
if strings.Count(content, "<invoke name=") != 2 {
|
||||
t.Fatalf("expected two preserved tool call blocks, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "<tool_name>search_web</tool_name>") || !strings.Contains(content, "<tool_name>eval_javascript</tool_name>") {
|
||||
if !strings.Contains(content, `<invoke name="search_web">`) || !strings.Contains(content, `<invoke name="eval_javascript">`) {
|
||||
t.Fatalf("expected both tool names in preserved history, got %q", content)
|
||||
}
|
||||
}
|
||||
@@ -258,7 +258,7 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantNilContentDoesNotInjectNullLi
|
||||
if strings.Contains(content, "null") {
|
||||
t.Fatalf("expected no null literal injection, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "<tools>") {
|
||||
if !strings.Contains(content, "<tool_calls>") {
|
||||
t.Fatalf("expected assistant tool history in normalized content, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,10 +47,10 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes
|
||||
if !strings.Contains(finalPrompt, `"condition":"sunny"`) {
|
||||
t.Fatalf("handler finalPrompt should preserve tool output content: %q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "<tools>") {
|
||||
if !strings.Contains(finalPrompt, "<tool_calls>") {
|
||||
t.Fatalf("handler finalPrompt should preserve assistant tool history: %q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "<tool_name>get_weather</tool_name>") {
|
||||
if !strings.Contains(finalPrompt, `<invoke name="get_weather">`) {
|
||||
t.Fatalf("handler finalPrompt should include tool name history: %q", finalPrompt)
|
||||
}
|
||||
}
|
||||
@@ -74,7 +74,7 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
|
||||
}
|
||||
|
||||
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "", false)
|
||||
if !strings.Contains(finalPrompt, "Remember: The ONLY valid way to use tools is the <tools>...</tools> XML block at the end of your response.") {
|
||||
if !strings.Contains(finalPrompt, "Remember: The ONLY valid way to use tools is the <tool_calls>...</tool_calls> XML block at the end of your response.") {
|
||||
t.Fatalf("vercel prepare finalPrompt missing final tool-call anchor instruction: %q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "TOOL CALL FORMAT") {
|
||||
|
||||
@@ -122,8 +122,8 @@ func TestHandleResponsesStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("前置文本\n<tools>\n <tool_call>\n <tool_name>read_file</tool_name>\n <param>{\"path\":\"README.MD\"}</param>\n </tool_call>\n</tools>") +
|
||||
sseLine("中间文本\n<tools>\n <tool_call>\n <tool_name>search</tool_name>\n <param>{\"q\":\"golang\"}</param>\n </tool_call>\n</tools>") +
|
||||
streamBody := sseLine("前置文本\n<tool_calls>\n <invoke name=\"read_file\">\n <parameter name=\"path\">README.MD</parameter>\n </invoke>\n</tool_calls>") +
|
||||
sseLine("中间文本\n<tool_calls>\n <invoke name=\"search\">\n <parameter name=\"q\">golang</parameter>\n </invoke>\n</tool_calls>") +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
|
||||
@@ -136,6 +136,22 @@ func TestNormalizeOpenAIResponsesRequestThinkingExtraBodyFallback(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestReasoningDisablesThinking(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"reasoning": map[string]any{"effort": "none"},
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.Thinking {
|
||||
t.Fatalf("expected reasoning.effort=none to disable thinking")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceRequired(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
|
||||
@@ -9,42 +9,27 @@ import (
|
||||
// --- XML tool call support for the streaming sieve ---
|
||||
|
||||
//nolint:unused // kept as explicit tag inventory for future XML sieve refinements.
|
||||
var xmlToolCallClosingTags = []string{"</tools>", "</tool_call>",
|
||||
// Agent-style XML tags (Roo Code, Cline, etc.)
|
||||
"</attempt_completion>", "</ask_followup_question>", "</new_task>", "</result>"}
|
||||
var xmlToolCallOpeningTags = []string{"<tools", "<tool_call",
|
||||
// Agent-style XML tags
|
||||
"<attempt_completion", "<ask_followup_question", "<new_task", "<result"}
|
||||
var xmlToolCallClosingTags = []string{"</tool_calls>"}
|
||||
var xmlToolCallOpeningTags = []string{"<tool_calls"}
|
||||
|
||||
// xmlToolCallTagPairs maps each opening tag to its expected closing tag.
|
||||
// Order matters: longer/wrapper tags must be checked first.
|
||||
var xmlToolCallTagPairs = []struct{ open, close string }{
|
||||
{"<tools", "</tools>"},
|
||||
{"<tool_call", "</tool_call>"},
|
||||
// Agent-style: these are XML "tool call" patterns from coding agents.
|
||||
// They get captured → parsed. If parsing fails, the raw XML is preserved
|
||||
// so the caller can still see the original text.
|
||||
{"<attempt_completion", "</attempt_completion>"},
|
||||
{"<ask_followup_question", "</ask_followup_question>"},
|
||||
{"<new_task", "</new_task>"},
|
||||
{"<tool_calls", "</tool_calls>"},
|
||||
}
|
||||
|
||||
// xmlToolCallBlockPattern matches a complete XML tool call block (wrapper or standalone).
|
||||
// xmlToolCallBlockPattern matches a complete canonical XML tool call block.
|
||||
//
|
||||
//nolint:unused // reserved for future fast-path XML block detection.
|
||||
var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)(<tools\b[^>]*>\s*(?:.*?)\s*</tools>|<tool_call\b[^>]*>(?:.*?)</tool_call>|<attempt_completion>(?:.*?)</attempt_completion>|<ask_followup_question>(?:.*?)</ask_followup_question>|<new_task>(?:.*?)</new_task>)`)
|
||||
var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)(<tool_calls\b[^>]*>\s*(?:.*?)\s*</tool_calls>)`)
|
||||
|
||||
// xmlToolTagsToDetect is the set of XML tag prefixes used by findToolSegmentStart.
|
||||
var xmlToolTagsToDetect = []string{"<tools>", "<tools\n", "<tools ", "<tool_call>", "<tool_call\n", "<tool_call ",
|
||||
// Agent-style tags
|
||||
"<attempt_completion>", "<ask_followup_question>", "<new_task>"}
|
||||
var xmlToolTagsToDetect = []string{"<tool_calls>", "<tool_calls\n", "<tool_calls "}
|
||||
|
||||
// consumeXMLToolCapture tries to extract complete XML tool call blocks from captured text.
|
||||
func consumeXMLToolCapture(captured string, toolNames []string) (prefix string, calls []toolcall.ParsedToolCall, suffix string, ready bool) {
|
||||
lower := strings.ToLower(captured)
|
||||
// Find the FIRST matching open/close pair, preferring wrapper tags.
|
||||
// Tag pairs are ordered longest-first (e.g. <tool_calls before <tool_call)
|
||||
// so wrapper tags are checked before inner tags.
|
||||
// Find the FIRST matching open/close pair for the canonical wrapper.
|
||||
for _, pair := range xmlToolCallTagPairs {
|
||||
openIdx := strings.Index(lower, pair.open)
|
||||
if openIdx < 0 {
|
||||
@@ -54,8 +39,7 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
|
||||
closeIdx := strings.LastIndex(lower, pair.close)
|
||||
if closeIdx < openIdx {
|
||||
// Opening tag is present but its specific closing tag hasn't arrived.
|
||||
// Return not-ready so we keep buffering — do NOT fall through to
|
||||
// try inner pairs (e.g. <tool_call inside <tool_calls).
|
||||
// Return not-ready so we keep buffering until the canonical wrapper closes.
|
||||
return "", nil, "", false
|
||||
}
|
||||
closeEnd := closeIdx + len(pair.close)
|
||||
@@ -88,8 +72,8 @@ func hasOpenXMLToolTag(captured string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// findPartialXMLToolTagStart checks if the string ends with a partial XML tool tag
|
||||
// (e.g., "<tool_ca" or "<inv") and returns the position of the '<'.
|
||||
// findPartialXMLToolTagStart checks if the string ends with a partial canonical
|
||||
// XML wrapper tag (e.g., "<too") and returns the position of the '<'.
|
||||
func findPartialXMLToolTagStart(s string) int {
|
||||
lastLT := strings.LastIndex(s, "<")
|
||||
if lastLT < 0 {
|
||||
|
||||
@@ -9,12 +9,11 @@ func TestProcessToolSieveInterceptsXMLToolCallWithoutLeak(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Simulate a model producing XML tool call output chunk by chunk.
|
||||
chunks := []string{
|
||||
"<tools>\n",
|
||||
" <tool_call>\n",
|
||||
" <tool_name>read_file</tool_name>\n",
|
||||
` <param>{"path":"README.MD"}</param>` + "\n",
|
||||
" </tool_call>\n",
|
||||
"</tools>",
|
||||
"<tool_calls>\n",
|
||||
` <invoke name="read_file">` + "\n",
|
||||
` <parameter name="path">README.MD</parameter>` + "\n",
|
||||
" </invoke>\n",
|
||||
"</tool_calls>",
|
||||
}
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
@@ -31,7 +30,7 @@ func TestProcessToolSieveInterceptsXMLToolCallWithoutLeak(t *testing.T) {
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if strings.Contains(textContent, "<tool_call") {
|
||||
if strings.Contains(textContent, "<invoke ") {
|
||||
t.Fatalf("XML tool call content leaked to text: %q", textContent)
|
||||
}
|
||||
if strings.Contains(textContent, "read_file") {
|
||||
@@ -48,10 +47,10 @@ func TestProcessToolSieveHandlesLongXMLToolCall(t *testing.T) {
|
||||
payload := strings.Repeat("x", 4096)
|
||||
splitAt := len(payload) / 2
|
||||
chunks := []string{
|
||||
"<tools>\n <tool_call>\n <tool_name>" + toolName + "</tool_name>\n <param>\n <content><![CDATA[",
|
||||
"<tool_calls>\n <invoke name=\"" + toolName + "\">\n <parameter name=\"content\"><![CDATA[",
|
||||
payload[:splitAt],
|
||||
payload[splitAt:],
|
||||
"]]></content>\n </param>\n </tool_call>\n</tools>",
|
||||
"]]></parameter>\n </invoke>\n</tool_calls>",
|
||||
}
|
||||
|
||||
var events []toolStreamEvent
|
||||
@@ -90,8 +89,8 @@ func TestProcessToolSieveXMLWithLeadingText(t *testing.T) {
|
||||
// Model outputs some prose then an XML tool call.
|
||||
chunks := []string{
|
||||
"Let me check the file.\n",
|
||||
"<tools>\n <tool_call>\n <tool_name>read_file</tool_name>\n",
|
||||
` <param>{"path":"go.mod"}</param>` + "\n </tool_call>\n</tools>",
|
||||
"<tool_calls>\n <invoke name=\"read_file\">\n",
|
||||
` <parameter name="path">go.mod</parameter>` + "\n </invoke>\n</tool_calls>",
|
||||
}
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
@@ -113,7 +112,7 @@ func TestProcessToolSieveXMLWithLeadingText(t *testing.T) {
|
||||
t.Fatalf("expected leading text to be emitted, got %q", textContent)
|
||||
}
|
||||
// The XML itself should NOT leak.
|
||||
if strings.Contains(textContent, "<tool_call") {
|
||||
if strings.Contains(textContent, "<invoke ") {
|
||||
t.Fatalf("XML tool call content leaked to text: %q", textContent)
|
||||
}
|
||||
if toolCalls == 0 {
|
||||
@@ -143,7 +142,7 @@ func TestProcessToolSievePassesThroughNonToolXMLBlock(t *testing.T) {
|
||||
|
||||
func TestProcessToolSieveNonToolXMLKeepsSuffixForToolParsing(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
chunk := `<tool><title>plain xml</title></tool><tools><tool_call><tool_name>read_file</tool_name><param>{"path":"README.MD"}</param></tool_call></tools>`
|
||||
chunk := `<tool><title>plain xml</title></tool><tool_calls><invoke name="read_file"><parameter name="path">README.MD</parameter></invoke></tool_calls>`
|
||||
events := processToolSieveChunk(&state, chunk, []string{"read_file"})
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
@@ -156,8 +155,8 @@ func TestProcessToolSieveNonToolXMLKeepsSuffixForToolParsing(t *testing.T) {
|
||||
if !strings.Contains(textContent.String(), `<tool><title>plain xml</title></tool>`) {
|
||||
t.Fatalf("expected leading non-tool XML to be preserved, got %q", textContent.String())
|
||||
}
|
||||
if strings.Contains(textContent.String(), `<tools><tool_call>`) {
|
||||
t.Fatalf("expected canonical tool XML to be intercepted, got %q", textContent.String())
|
||||
if strings.Contains(textContent.String(), `<tool_calls><invoke`) {
|
||||
t.Fatalf("expected invoke tool XML to be intercepted, got %q", textContent.String())
|
||||
}
|
||||
if toolCalls != 1 {
|
||||
t.Fatalf("expected exactly one parsed tool call from suffix, got %d events=%#v", toolCalls, events)
|
||||
@@ -166,7 +165,7 @@ func TestProcessToolSieveNonToolXMLKeepsSuffixForToolParsing(t *testing.T) {
|
||||
|
||||
func TestProcessToolSievePassesThroughMalformedExecutableXMLBlock(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
chunk := `<tools><tool_call><param>{"path":"README.md"}</param></tool_call></tools>`
|
||||
chunk := `<tool_calls><invoke name="read_file"><param>{"path":"README.md"}</param></invoke></tool_calls>`
|
||||
events := processToolSieveChunk(&state, chunk, []string{"read_file"})
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
@@ -189,17 +188,17 @@ func TestProcessToolSievePassesThroughFencedXMLToolCallExamples(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
input := strings.Join([]string{
|
||||
"Before first example.\n```",
|
||||
"xml\n<tools><tool_call><tool_name>read_file</tool_name><param>{\"path\":\"README.md\"}</param></tool_call></tools>\n```\n",
|
||||
"xml\n<tool_calls><invoke name=\"read_file\"><parameter name=\"path\">README.md</parameter></invoke></tool_calls>\n```\n",
|
||||
"Between examples.\n```xml\n",
|
||||
"<tools><tool_call><tool_name>search</tool_name><param>{\"q\":\"golang\"}</param></tool_call></tools>\n",
|
||||
"<tool_calls><invoke name=\"search\"><parameter name=\"q\">golang</parameter></invoke></tool_calls>\n",
|
||||
"```\nAfter examples.",
|
||||
}, "")
|
||||
|
||||
chunks := []string{
|
||||
"Before first example.\n```",
|
||||
"xml\n<tools><tool_call><tool_name>read_file</tool_name><param>{\"path\":\"README.md\"}</param></tool_call></tools>\n```\n",
|
||||
"xml\n<tool_calls><invoke name=\"read_file\"><parameter name=\"path\">README.md</parameter></invoke></tool_calls>\n```\n",
|
||||
"Between examples.\n```xml\n",
|
||||
"<tools><tool_call><tool_name>search</tool_name><param>{\"q\":\"golang\"}</param></tool_call></tools>\n",
|
||||
"<tool_calls><invoke name=\"search\"><parameter name=\"q\">golang</parameter></invoke></tool_calls>\n",
|
||||
"```\nAfter examples.",
|
||||
}
|
||||
|
||||
@@ -230,13 +229,13 @@ func TestProcessToolSieveKeepsPartialXMLTagInsideFencedExample(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
input := strings.Join([]string{
|
||||
"Example:\n```xml\n<tool_ca",
|
||||
"ll><tool_name>read_file</tool_name><param>{\"path\":\"README.md\"}</param></tool_call></tools>\n```\n",
|
||||
"lls><invoke name=\"read_file\"><parameter name=\"path\">README.md</parameter></invoke></tool_calls>\n```\n",
|
||||
"Done.",
|
||||
}, "")
|
||||
|
||||
chunks := []string{
|
||||
"Example:\n```xml\n<tool_ca",
|
||||
"ll><tool_name>read_file</tool_name><param>{\"path\":\"README.md\"}</param></tool_call></tools>\n```\n",
|
||||
"lls><invoke name=\"read_file\"><parameter name=\"path\">README.md</parameter></invoke></tool_calls>\n```\n",
|
||||
"Done.",
|
||||
}
|
||||
|
||||
@@ -266,15 +265,15 @@ func TestProcessToolSieveKeepsPartialXMLTagInsideFencedExample(t *testing.T) {
|
||||
func TestProcessToolSievePartialXMLTagHeldBack(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Chunk ends with a partial XML tool tag.
|
||||
events := processToolSieveChunk(&state, "Hello <tool_ca", []string{"read_file"})
|
||||
events := processToolSieveChunk(&state, "Hello <too", []string{"read_file"})
|
||||
|
||||
var textContent string
|
||||
for _, evt := range events {
|
||||
textContent += evt.Content
|
||||
}
|
||||
|
||||
// "Hello " should be emitted, but "<tool_ca" should be held back.
|
||||
if strings.Contains(textContent, "<tool_ca") {
|
||||
// "Hello " should be emitted, but "<too" should be held back.
|
||||
if strings.Contains(textContent, "<too") {
|
||||
t.Fatalf("partial XML tag should not be emitted, got %q", textContent)
|
||||
}
|
||||
if !strings.Contains(textContent, "Hello") {
|
||||
@@ -288,9 +287,9 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
|
||||
input string
|
||||
want int
|
||||
}{
|
||||
{"tools_tag", "some text <tools>\n", 10},
|
||||
{"tool_call_tag", "prefix <tool_call>\n", 7},
|
||||
{"xml_inside_code_fence", "```xml\n<tools><tool_call><tool_name>read_file</tool_name></tool_call></tools>\n```", -1},
|
||||
{"tool_calls_tag", "some text <tool_calls>\n", 10},
|
||||
{"bare_tool_call_text", "prefix <tool_call>\n", -1},
|
||||
{"xml_inside_code_fence", "```xml\n<tool_calls><invoke name=\"read_file\"></invoke></tool_calls>\n```", -1},
|
||||
{"no_xml", "just plain text", -1},
|
||||
{"gemini_json_no_detect", `some text {"functionCall":{"name":"search"}}`, -1},
|
||||
}
|
||||
@@ -310,10 +309,10 @@ func TestFindPartialXMLToolTagStart(t *testing.T) {
|
||||
input string
|
||||
want int
|
||||
}{
|
||||
{"partial_tools", "Hello <too", 6},
|
||||
{"partial_tool_call", "Hello <tool_ca", 6},
|
||||
{"partial_tool_calls", "Hello <tool_ca", 6},
|
||||
{"bare_tool_call_not_held", "Hello <tool_name", -1},
|
||||
{"partial_lt_only", "Text <", 5},
|
||||
{"complete_tag", "Text <tools>done", -1},
|
||||
{"complete_tag", "Text <tool_calls>done", -1},
|
||||
{"no_lt", "plain text", -1},
|
||||
{"closed_lt", "a < b > c", -1},
|
||||
}
|
||||
@@ -328,10 +327,10 @@ func TestFindPartialXMLToolTagStart(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHasOpenXMLToolTag(t *testing.T) {
|
||||
if !hasOpenXMLToolTag("<tools>\n<tool_call>\n<tool_name>foo</tool_name>") {
|
||||
if !hasOpenXMLToolTag("<tool_calls>\n<invoke name=\"foo\">") {
|
||||
t.Fatal("should detect open XML tool tag without closing tag")
|
||||
}
|
||||
if hasOpenXMLToolTag("<tools>\n<tool_call>\n<tool_name>foo</tool_name></tool_call>\n</tools>") {
|
||||
if hasOpenXMLToolTag("<tool_calls>\n<invoke name=\"foo\"></invoke>\n</tool_calls>") {
|
||||
t.Fatal("should return false when closing tag is present")
|
||||
}
|
||||
if hasOpenXMLToolTag("plain text without any XML") {
|
||||
@@ -340,44 +339,29 @@ func TestHasOpenXMLToolTag(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test the EXACT scenario the user reports: token-by-token streaming where
|
||||
// <tools> tag arrives in small pieces.
|
||||
// <tool_calls> tag arrives in small pieces.
|
||||
func TestProcessToolSieveTokenByTokenXMLNoLeak(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Simulate DeepSeek model generating tokens one at a time.
|
||||
chunks := []string{
|
||||
"<",
|
||||
"tool",
|
||||
"s",
|
||||
"_ca",
|
||||
"lls",
|
||||
">\n",
|
||||
" <",
|
||||
"tool",
|
||||
"_call",
|
||||
">\n",
|
||||
" <",
|
||||
"tool",
|
||||
"_name",
|
||||
">",
|
||||
" <in",
|
||||
"voke",
|
||||
` name="`,
|
||||
"read",
|
||||
"_file",
|
||||
`">` + "\n",
|
||||
" <para",
|
||||
`meter name="path">`,
|
||||
"README.MD",
|
||||
"</parameter>\n",
|
||||
" </invoke>\n",
|
||||
"</",
|
||||
"tool",
|
||||
"_name",
|
||||
">\n",
|
||||
" <",
|
||||
"param",
|
||||
">",
|
||||
`{"path"`,
|
||||
`: "README.MD"`,
|
||||
`}`,
|
||||
"</",
|
||||
"param",
|
||||
">\n",
|
||||
" </",
|
||||
"tool",
|
||||
"_call",
|
||||
">\n",
|
||||
"</",
|
||||
"tools",
|
||||
"tool_calls",
|
||||
">",
|
||||
}
|
||||
var events []toolStreamEvent
|
||||
@@ -395,10 +379,10 @@ func TestProcessToolSieveTokenByTokenXMLNoLeak(t *testing.T) {
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if strings.Contains(textContent, "<tool_call") {
|
||||
if strings.Contains(textContent, "<invoke ") {
|
||||
t.Fatalf("XML tool call content leaked to text in token-by-token mode: %q", textContent)
|
||||
}
|
||||
if strings.Contains(textContent, "tools>") {
|
||||
if strings.Contains(textContent, "tool_calls>") {
|
||||
t.Fatalf("closing tag fragment leaked to text: %q", textContent)
|
||||
}
|
||||
if strings.Contains(textContent, "read_file") {
|
||||
@@ -414,9 +398,8 @@ func TestFlushToolSieveIncompleteXMLFallsBackToText(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// XML block starts but stream ends before completion.
|
||||
chunks := []string{
|
||||
"<tools>\n",
|
||||
" <tool_call>\n",
|
||||
" <tool_name>read_file</tool_name>\n",
|
||||
"<tool_calls>\n",
|
||||
" <invoke name=\"read_file\">\n",
|
||||
}
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
@@ -437,19 +420,19 @@ func TestFlushToolSieveIncompleteXMLFallsBackToText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the opening tag "<tools>\n " is NOT emitted as text content.
|
||||
// Test that the opening tag "<tool_calls>\n " is NOT emitted as text content.
|
||||
func TestOpeningXMLTagNotLeakedAsContent(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// First chunk is the opening tag - should be held, not emitted.
|
||||
evts1 := processToolSieveChunk(&state, "<tools>\n ", []string{"read_file"})
|
||||
evts1 := processToolSieveChunk(&state, "<tool_calls>\n ", []string{"read_file"})
|
||||
for _, evt := range evts1 {
|
||||
if strings.Contains(evt.Content, "<tools>") {
|
||||
if strings.Contains(evt.Content, "<tool_calls>") {
|
||||
t.Fatalf("opening tag leaked on first chunk: %q", evt.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Remaining content arrives.
|
||||
evts2 := processToolSieveChunk(&state, "<tool_call>\n <tool_name>read_file</tool_name>\n <param>{\"path\":\"README.MD\"}</param>\n </tool_call>\n</tools>", []string{"read_file"})
|
||||
evts2 := processToolSieveChunk(&state, "<invoke name=\"read_file\">\n <parameter name=\"path\">README.MD</parameter>\n </invoke>\n</tool_calls>", []string{"read_file"})
|
||||
evts2 = append(evts2, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent string
|
||||
@@ -462,7 +445,7 @@ func TestOpeningXMLTagNotLeakedAsContent(t *testing.T) {
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if strings.Contains(textContent, "<tool_call") {
|
||||
if strings.Contains(textContent, "<invoke ") {
|
||||
t.Fatalf("XML content leaked: %q", textContent)
|
||||
}
|
||||
if toolCalls == 0 {
|
||||
@@ -501,3 +484,24 @@ func TestProcessToolSieveFallsBackToRawAttemptCompletion(t *testing.T) {
|
||||
t.Fatalf("expected agent XML to fall back to raw text, got %q", textContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSievePassesThroughBareToolCallAsText(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
chunk := `<invoke name="read_file"><parameter name="path">README.md</parameter></invoke>`
|
||||
events := processToolSieveChunk(&state, chunk, []string{"read_file"})
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
for _, evt := range events {
|
||||
textContent.WriteString(evt.Content)
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if toolCalls != 0 {
|
||||
t.Fatalf("expected bare invoke to remain text, got %d events=%#v", toolCalls, events)
|
||||
}
|
||||
if textContent.String() != chunk {
|
||||
t.Fatalf("expected bare invoke to pass through unchanged, got %q", textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ const {
|
||||
stripFencedCodeBlocks,
|
||||
} = require('./parse_payload');
|
||||
|
||||
const TOOL_MARKUP_PREFIXES = ['<tools', '<tool_call'];
|
||||
const TOOL_MARKUP_PREFIXES = ['<tool_calls'];
|
||||
|
||||
function extractToolNames(tools) {
|
||||
if (!Array.isArray(tools) || tools.length === 0) {
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
'use strict';
|
||||
|
||||
const TOOLS_WRAPPER_PATTERN = /<tools\b[^>]*>([\s\S]*?)<\/tools>/gi;
|
||||
const TOOL_CALL_MARKUP_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?tool_call\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?tool_call>/gi;
|
||||
const TOOL_CALL_CANONICAL_BODY_PATTERN = /^\s*<(?:[a-z0-9_:-]+:)?tool_name\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?tool_name>\s*<(?:[a-z0-9_:-]+:)?param\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?param>\s*$/i;
|
||||
const TOOLS_WRAPPER_PATTERN = /<tool_calls\b[^>]*>([\s\S]*?)<\/tool_calls>/gi;
|
||||
const TOOL_CALL_MARKUP_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?invoke>/gi;
|
||||
const PARAMETER_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?parameter\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameter>/gi;
|
||||
const TOOL_CALL_MARKUP_KV_PATTERN = /<(?:[a-z0-9_:-]+:)?([a-z0-9_.-]+)\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
|
||||
const CDATA_PATTERN = /^<!\[CDATA\[([\s\S]*?)]]>$/i;
|
||||
const XML_ATTR_PATTERN = /\b([a-z0-9_:-]+)\s*=\s*("([^"]*)"|'([^']*)')/gi;
|
||||
|
||||
const {
|
||||
toStringSafe,
|
||||
@@ -27,7 +28,7 @@ function parseMarkupToolCalls(text) {
|
||||
for (const wrapper of raw.matchAll(TOOLS_WRAPPER_PATTERN)) {
|
||||
const body = toStringSafe(wrapper[1]);
|
||||
for (const block of body.matchAll(TOOL_CALL_MARKUP_BLOCK_PATTERN)) {
|
||||
const parsed = parseMarkupSingleToolCall(toStringSafe(block[1]).trim());
|
||||
const parsed = parseMarkupSingleToolCall(block);
|
||||
if (parsed) {
|
||||
out.push(parsed);
|
||||
}
|
||||
@@ -36,33 +37,43 @@ function parseMarkupToolCalls(text) {
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseMarkupSingleToolCall(inner) {
|
||||
// Try inline JSON parse for the inner content.
|
||||
function parseMarkupSingleToolCall(block) {
|
||||
const attrs = parseTagAttributes(block[1]);
|
||||
const name = toStringSafe(attrs.name).trim();
|
||||
if (!name) {
|
||||
return null;
|
||||
}
|
||||
const inner = toStringSafe(block[2]).trim();
|
||||
|
||||
if (inner) {
|
||||
try {
|
||||
const decoded = JSON.parse(inner);
|
||||
if (decoded && typeof decoded === 'object' && !Array.isArray(decoded) && decoded.name) {
|
||||
if (decoded && typeof decoded === 'object' && !Array.isArray(decoded)) {
|
||||
return {
|
||||
name: toStringSafe(decoded.name),
|
||||
input: decoded.input && typeof decoded.input === 'object' && !Array.isArray(decoded.input) ? decoded.input : {},
|
||||
name,
|
||||
input: decoded.input && typeof decoded.input === 'object' && !Array.isArray(decoded.input)
|
||||
? decoded.input
|
||||
: decoded.parameters && typeof decoded.parameters === 'object' && !Array.isArray(decoded.parameters)
|
||||
? decoded.parameters
|
||||
: {},
|
||||
};
|
||||
}
|
||||
} catch (_err) {
|
||||
// Not JSON, continue with markup parsing.
|
||||
}
|
||||
}
|
||||
|
||||
const match = inner.match(TOOL_CALL_CANONICAL_BODY_PATTERN);
|
||||
if (!match || match.length < 3) {
|
||||
const input = {};
|
||||
for (const match of inner.matchAll(PARAMETER_BLOCK_PATTERN)) {
|
||||
const parameterAttrs = parseTagAttributes(match[1]);
|
||||
const paramName = toStringSafe(parameterAttrs.name).trim();
|
||||
if (!paramName) {
|
||||
continue;
|
||||
}
|
||||
appendMarkupValue(input, paramName, parseMarkupValue(match[2]));
|
||||
}
|
||||
if (Object.keys(input).length === 0 && inner.trim() !== '') {
|
||||
return null;
|
||||
}
|
||||
|
||||
const name = extractRawTagValue(match[1]).trim();
|
||||
if (!name) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const input = parseMarkupInput(match[2]);
|
||||
return { name, input };
|
||||
}
|
||||
|
||||
@@ -124,11 +135,14 @@ function parseMarkupValue(raw) {
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
return JSON.parse(s);
|
||||
} catch (_err) {
|
||||
return s;
|
||||
if (s.startsWith('{') || s.startsWith('[')) {
|
||||
try {
|
||||
return JSON.parse(s);
|
||||
} catch (_err) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
function extractRawTagValue(inner) {
|
||||
@@ -158,6 +172,22 @@ function unescapeHtml(safe) {
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
|
||||
/**
 * Parse the attribute portion of an XML-style opening tag into a plain object.
 *
 * Attribute names are trimmed and lower-cased; values are taken from either
 * the double-quoted or the single-quoted capture of XML_ATTR_PATTERN, and
 * default to '' when both are empty. A later duplicate name overwrites an
 * earlier one.
 *
 * @param {*} raw Attribute text (anything toStringSafe accepts).
 * @returns {Object<string, string>} Map of lower-cased attribute name to value.
 */
function parseTagAttributes(raw) {
  const attrs = {};
  const text = toStringSafe(raw);
  if (text) {
    for (const [, rawKey, , doubleQuoted, singleQuoted] of text.matchAll(XML_ATTR_PATTERN)) {
      const key = toStringSafe(rawKey).trim().toLowerCase();
      if (key) {
        attrs[key] = doubleQuoted || singleQuoted || '';
      }
    }
  }
  return attrs;
}
|
||||
|
||||
function parseToolCallInput(v) {
|
||||
if (v == null) {
|
||||
return {};
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
'use strict';
|
||||
const { parseToolCalls } = require('./parse');
|
||||
|
||||
// Tag pairs ordered longest-first: wrapper tags checked before inner tags.
|
||||
// XML wrapper tag pair used by the streaming sieve.
|
||||
const XML_TOOL_TAG_PAIRS = [
|
||||
{ open: '<tools', close: '</tools>' },
|
||||
{ open: '<tool_call', close: '</tool_call>' },
|
||||
{ open: '<tool_calls', close: '</tool_calls>' },
|
||||
];
|
||||
|
||||
const XML_TOOL_OPENING_TAGS = XML_TOOL_TAG_PAIRS.map(p => p.open);
|
||||
|
||||
function consumeXMLToolCapture(captured, toolNames, trimWrappingJSONFence) {
|
||||
const lower = captured.toLowerCase();
|
||||
// Find the FIRST matching open/close pair, preferring wrapper tags.
|
||||
// Find the FIRST matching open/close pair for the canonical wrapper.
|
||||
for (const pair of XML_TOOL_TAG_PAIRS) {
|
||||
const openIdx = lower.indexOf(pair.open);
|
||||
if (openIdx < 0) {
|
||||
@@ -21,7 +20,7 @@ function consumeXMLToolCapture(captured, toolNames, trimWrappingJSONFence) {
|
||||
const closeIdx = lower.lastIndexOf(pair.close);
|
||||
if (closeIdx < openIdx) {
|
||||
// Opening tag present but specific closing tag hasn't arrived.
|
||||
// Return not-ready — do NOT fall through to inner pairs.
|
||||
// Return not-ready so buffering continues until the wrapper closes.
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
const closeEnd = closeIdx + pair.close.length;
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
'use strict';
|
||||
|
||||
const XML_TOOL_SEGMENT_TAGS = [
|
||||
'<tools>', '<tools\n', '<tools ', '<tool_call>', '<tool_call\n', '<tool_call ',
|
||||
'<tool_calls>', '<tool_calls\n', '<tool_calls ',
|
||||
];
|
||||
|
||||
const XML_TOOL_OPENING_TAGS = [
|
||||
'<tools', '<tool_call',
|
||||
'<tool_calls',
|
||||
];
|
||||
|
||||
const XML_TOOL_CLOSING_TAGS = [
|
||||
'</tools>', '</tool_call>',
|
||||
'</tool_calls>',
|
||||
];
|
||||
|
||||
module.exports = {
|
||||
|
||||
@@ -16,8 +16,8 @@ var promptXMLTextEscaper = strings.NewReplacer(
|
||||
|
||||
var promptXMLNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_.:-]*$`)
|
||||
|
||||
// FormatToolCallsForPrompt renders a tool_calls slice into the canonical
|
||||
// prompt-visible history block used across adapters.
|
||||
// FormatToolCallsForPrompt renders a tool_calls slice into the prompt-visible
|
||||
// invoke/parameter history block used across adapters.
|
||||
func FormatToolCallsForPrompt(raw any) string {
|
||||
calls, ok := raw.([]any)
|
||||
if !ok || len(calls) == 0 {
|
||||
@@ -38,7 +38,7 @@ func FormatToolCallsForPrompt(raw any) string {
|
||||
if len(blocks) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "<tools>\n" + strings.Join(blocks, "\n") + "\n</tools>"
|
||||
return "<tool_calls>\n" + strings.Join(blocks, "\n") + "\n</tool_calls>"
|
||||
}
|
||||
|
||||
// StringifyToolCallArguments normalizes tool arguments into a compact string
|
||||
@@ -93,28 +93,99 @@ func formatToolCallForPrompt(call map[string]any) string {
|
||||
}
|
||||
|
||||
parameters := formatToolCallParametersForPrompt(argsRaw)
|
||||
if parameters == "" {
|
||||
return ` <invoke name="` + escapeXMLAttribute(name) + `"></invoke>`
|
||||
}
|
||||
|
||||
return " <tool_call>\n" +
|
||||
" <tool_name>" + escapeXMLText(name) + "</tool_name>\n" +
|
||||
return " <invoke name=\"" + escapeXMLAttribute(name) + "\">\n" +
|
||||
parameters + "\n" +
|
||||
" </tool_call>"
|
||||
" </invoke>"
|
||||
}
|
||||
|
||||
func formatToolCallParametersForPrompt(raw any) string {
|
||||
value := normalizePromptToolCallValue(raw)
|
||||
body, ok := renderPromptToolXMLBody(value, " ")
|
||||
if ok {
|
||||
if strings.TrimSpace(body) == "" {
|
||||
return " <param></param>"
|
||||
}
|
||||
return " <param>\n" + body + "\n </param>"
|
||||
body, ok := renderPromptToolParameters(value, " ")
|
||||
if ok && strings.TrimSpace(body) != "" {
|
||||
return body
|
||||
}
|
||||
|
||||
fallback := StringifyToolCallArguments(raw)
|
||||
if strings.TrimSpace(fallback) == "" {
|
||||
fallback = "{}"
|
||||
return ""
|
||||
}
|
||||
return " <parameter name=\"content\">" + renderPromptXMLText(fallback) + "</parameter>"
|
||||
}
|
||||
|
||||
func renderPromptToolParameters(value any, indent string) (string, bool) {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return "", true
|
||||
case map[string]any:
|
||||
if len(v) == 0 {
|
||||
return "", true
|
||||
}
|
||||
keys := make([]string, 0, len(v))
|
||||
for k := range v {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
lines := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
rendered, ok := renderPromptParameterNode(key, v[key], indent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
lines = append(lines, rendered)
|
||||
}
|
||||
return strings.Join(lines, "\n"), true
|
||||
case []any:
|
||||
lines := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
rendered, ok := renderPromptParameterNode("item", item, indent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
lines = append(lines, rendered)
|
||||
}
|
||||
return strings.Join(lines, "\n"), true
|
||||
case string:
|
||||
return indent + `<parameter name="content">` + renderPromptXMLText(v) + `</parameter>`, true
|
||||
default:
|
||||
return indent + `<parameter name="value">` + renderPromptXMLText(fmt.Sprint(v)) + `</parameter>`, true
|
||||
}
|
||||
}
|
||||
|
||||
func renderPromptParameterNode(name string, value any, indent string) (string, bool) {
|
||||
trimmedName := strings.TrimSpace(name)
|
||||
if trimmedName == "" {
|
||||
return "", false
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return indent + `<parameter name="` + escapeXMLAttribute(trimmedName) + `"></parameter>`, true
|
||||
case map[string]any:
|
||||
body, ok := renderPromptToolXMLBody(v, indent+" ")
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if strings.TrimSpace(body) == "" {
|
||||
return indent + `<parameter name="` + escapeXMLAttribute(trimmedName) + `"></parameter>`, true
|
||||
}
|
||||
return indent + `<parameter name="` + escapeXMLAttribute(trimmedName) + "\">\n" + body + "\n" + indent + `</parameter>`, true
|
||||
case []any:
|
||||
body, ok := renderPromptToolXMLArray(v, indent+" ")
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if strings.TrimSpace(body) == "" {
|
||||
return indent + `<parameter name="` + escapeXMLAttribute(trimmedName) + `"></parameter>`, true
|
||||
}
|
||||
return indent + `<parameter name="` + escapeXMLAttribute(trimmedName) + "\">\n" + body + "\n" + indent + `</parameter>`, true
|
||||
case string:
|
||||
return indent + `<parameter name="` + escapeXMLAttribute(trimmedName) + `">` + renderPromptXMLText(v) + `</parameter>`, true
|
||||
default:
|
||||
return indent + `<parameter name="` + escapeXMLAttribute(trimmedName) + `">` + renderPromptXMLText(fmt.Sprint(v)) + `</parameter>`, true
|
||||
}
|
||||
return " <param><content>" + renderPromptXMLText(fallback) + "</content></param>"
|
||||
}
|
||||
|
||||
func normalizePromptToolCallValue(raw any) any {
|
||||
@@ -246,6 +317,18 @@ func isValidPromptXMLName(name string) bool {
|
||||
return promptXMLNamePattern.MatchString(strings.TrimSpace(name))
|
||||
}
|
||||
|
||||
// promptXMLAttributeEscaper replaces the characters that are unsafe inside a
// double-quoted XML attribute value with their predefined entities. It is
// built once at package scope instead of on every call.
var promptXMLAttributeEscaper = strings.NewReplacer(
	"&", "&amp;",
	`"`, "&quot;",
	"<", "&lt;",
	">", "&gt;",
)

// escapeXMLAttribute returns text with &, ", < and > replaced by &amp;,
// &quot;, &lt; and &gt; so it can be embedded in a double-quoted XML
// attribute (for example <invoke name="...">). The empty string is returned
// unchanged.
//
// strings.Replacer scans the input once without re-examining its own output,
// so the ampersands introduced by the entities are not escaped a second time.
func escapeXMLAttribute(text string) string {
	if text == "" {
		return ""
	}
	return promptXMLAttributeEscaper.Replace(text)
}
|
||||
|
||||
func normalizeToolArgumentString(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestFormatToolCallsForPromptXML(t *testing.T) {
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty formatted tool calls")
|
||||
}
|
||||
if got != "<tools>\n <tool_call>\n <tool_name>search_web</tool_name>\n <param>\n <query><![CDATA[latest]]></query>\n </param>\n </tool_call>\n</tools>" {
|
||||
if got != "<tool_calls>\n <invoke name=\"search_web\">\n <parameter name=\"query\"><![CDATA[latest]]></parameter>\n </invoke>\n</tool_calls>" {
|
||||
t.Fatalf("unexpected formatted tool call XML: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -34,7 +34,7 @@ func TestFormatToolCallsForPromptEscapesXMLEntities(t *testing.T) {
|
||||
"arguments": `{"q":"a < b && c > d"}`,
|
||||
},
|
||||
})
|
||||
want := "<tools>\n <tool_call>\n <tool_name>search<&></tool_name>\n <param>\n <q><![CDATA[a < b && c > d]]></q>\n </param>\n </tool_call>\n</tools>"
|
||||
want := "<tool_calls>\n <invoke name=\"search<&>\">\n <parameter name=\"q\"><![CDATA[a < b && c > d]]></parameter>\n </invoke>\n</tool_calls>"
|
||||
if got != want {
|
||||
t.Fatalf("unexpected escaped tool call XML: %q", got)
|
||||
}
|
||||
@@ -50,7 +50,7 @@ func TestFormatToolCallsForPromptUsesCDATAForMultilineContent(t *testing.T) {
|
||||
},
|
||||
},
|
||||
})
|
||||
want := "<tools>\n <tool_call>\n <tool_name>write_file</tool_name>\n <param>\n <content><![CDATA[#!/bin/bash\nprintf \"hello\"\n]]></content>\n <path><![CDATA[script.sh]]></path>\n </param>\n </tool_call>\n</tools>"
|
||||
want := "<tool_calls>\n <invoke name=\"write_file\">\n <parameter name=\"content\"><![CDATA[#!/bin/bash\nprintf \"hello\"\n]]></parameter>\n <parameter name=\"path\"><![CDATA[script.sh]]></parameter>\n </invoke>\n</tool_calls>"
|
||||
if got != want {
|
||||
t.Fatalf("unexpected multiline cdata tool call XML: %q", got)
|
||||
}
|
||||
|
||||
@@ -56,6 +56,21 @@ func TestCollectStreamThinkingAndText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamDropsThinkingWhenDisabled(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/thinking_content\",\"v\":\"Thinking...\"}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"Answer\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
result := CollectStream(resp, false, true)
|
||||
if result.Thinking != "" {
|
||||
t.Fatalf("expected disabled thinking to be dropped, got %q", result.Thinking)
|
||||
}
|
||||
if result.Text != "Answer" {
|
||||
t.Fatalf("expected only visible answer, got %q", result.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamOnlyThinking(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/thinking_content\",\"v\":\"Only thinking\"}\n" +
|
||||
|
||||
@@ -99,6 +99,10 @@ func ParseSSEChunkForContent(chunk map[string]any, thinkingEnabled bool, current
|
||||
if transitioned {
|
||||
newType = "text"
|
||||
}
|
||||
if !thinkingEnabled {
|
||||
parts = dropThinkingParts(parts)
|
||||
newType = "text"
|
||||
}
|
||||
return parts, false, newType
|
||||
}
|
||||
|
||||
@@ -172,6 +176,9 @@ func updateTypeFromNestedResponse(path string, v any, newType *string) {
|
||||
func resolvePartType(path string, thinkingEnabled bool, newType string) string {
|
||||
switch {
|
||||
case path == "response/thinking_content":
|
||||
if !thinkingEnabled {
|
||||
return "thinking"
|
||||
}
|
||||
if newType == "text" {
|
||||
return "text"
|
||||
}
|
||||
@@ -187,6 +194,20 @@ func resolvePartType(path string, thinkingEnabled bool, newType string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func dropThinkingParts(parts []ContentPart) []ContentPart {
|
||||
if len(parts) == 0 {
|
||||
return parts
|
||||
}
|
||||
out := parts[:0]
|
||||
for _, p := range parts {
|
||||
if p.Type == "thinking" {
|
||||
continue
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func appendChunkValueContent(v any, partType string, newType *string, parts *[]ContentPart, path string) bool {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
|
||||
@@ -13,18 +13,18 @@ func TestRegression_RobustXMLAndCDATA(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "Standard JSON parameters (Regression)",
|
||||
text: `<tools><tool_call><tool_name>foo</tool_name><param>{"a": 1}</param></tool_call></tools>`,
|
||||
expected: []ParsedToolCall{{Name: "foo", Input: map[string]any{"a": float64(1)}}},
|
||||
text: `<tool_calls><invoke name="foo"><parameter name="a">1</parameter></invoke></tool_calls>`,
|
||||
expected: []ParsedToolCall{{Name: "foo", Input: map[string]any{"a": "1"}}},
|
||||
},
|
||||
{
|
||||
name: "XML tags parameters (Regression)",
|
||||
text: `<tools><tool_call><tool_name>foo</tool_name><param><arg1>hello</arg1></param></tool_call></tools>`,
|
||||
text: `<tool_calls><invoke name="foo"><parameter name="arg1">hello</parameter></invoke></tool_calls>`,
|
||||
expected: []ParsedToolCall{{Name: "foo", Input: map[string]any{"arg1": "hello"}}},
|
||||
},
|
||||
{
|
||||
name: "CDATA parameters (New Feature)",
|
||||
text: `<tools><tool_call><tool_name>write_file</tool_name><param><content><![CDATA[line 1
|
||||
line 2 with <tags> and & symbols]]></content></param></tool_call></tools>`,
|
||||
text: `<tool_calls><invoke name="write_file"><parameter name="content"><![CDATA[line 1
|
||||
line 2 with <tags> and & symbols]]></parameter></invoke></tool_calls>`,
|
||||
expected: []ParsedToolCall{{
|
||||
Name: "write_file",
|
||||
Input: map[string]any{"content": "line 1\nline 2 with <tags> and & symbols"},
|
||||
@@ -32,9 +32,9 @@ line 2 with <tags> and & symbols]]></content></param></tool_call></tools>`,
|
||||
},
|
||||
{
|
||||
name: "Nested XML with repeated parameters (New Feature)",
|
||||
text: `<tools><tool_call><tool_name>write_file</tool_name><param><path>script.sh</path><content><![CDATA[#!/bin/bash
|
||||
text: `<tool_calls><invoke name="write_file"><parameter name="path">script.sh</parameter><parameter name="content"><![CDATA[#!/bin/bash
|
||||
echo "hello"
|
||||
]]></content><item>first</item><item>second</item></param></tool_call></tools>`,
|
||||
]]></parameter><parameter name="item">first</parameter><parameter name="item">second</parameter></invoke></tool_calls>`,
|
||||
expected: []ParsedToolCall{{
|
||||
Name: "write_file",
|
||||
Input: map[string]any{
|
||||
@@ -46,7 +46,7 @@ echo "hello"
|
||||
},
|
||||
{
|
||||
name: "Dirty XML with unescaped symbols (Robustness Improvement)",
|
||||
text: `<tools><tool_call><tool_name>bash</tool_name><param><command>echo "hello" > out.txt && cat out.txt</command></param></tool_call></tools>`,
|
||||
text: `<tool_calls><invoke name="bash"><parameter name="command">echo "hello" > out.txt && cat out.txt</parameter></invoke></tool_calls>`,
|
||||
expected: []ParsedToolCall{{
|
||||
Name: "bash",
|
||||
Input: map[string]any{"command": "echo \"hello\" > out.txt && cat out.txt"},
|
||||
@@ -54,7 +54,7 @@ echo "hello"
|
||||
},
|
||||
{
|
||||
name: "Mixed JSON inside CDATA (New Hybrid Case)",
|
||||
text: `<tools><tool_call><tool_name>foo</tool_name><param><![CDATA[{"json_param": "works"}]]></param></tool_call></tools>`,
|
||||
text: `<tool_calls><invoke name="foo"><parameter name="json_param"><![CDATA[works]]></parameter></invoke></tool_calls>`,
|
||||
expected: []ParsedToolCall{{
|
||||
Name: "foo",
|
||||
Input: map[string]any{"json_param": "works"},
|
||||
|
||||
@@ -36,93 +36,139 @@ func BuildToolCallInstructions(toolNames []string) string {
|
||||
|
||||
return `TOOL CALL FORMAT — FOLLOW EXACTLY:
|
||||
|
||||
<tools>
|
||||
<tool_call>
|
||||
<tool_name>TOOL_NAME_HERE</tool_name>
|
||||
<param>
|
||||
<PARAMETER_NAME><![CDATA[PARAMETER_VALUE]]></PARAMETER_NAME>
|
||||
</param>
|
||||
</tool_call>
|
||||
</tools>
|
||||
<tool_calls>
|
||||
<invoke name="TOOL_NAME_HERE">
|
||||
<parameter name="PARAMETER_NAME"><![CDATA[PARAMETER_VALUE]]></parameter>
|
||||
</invoke>
|
||||
</tool_calls>
|
||||
|
||||
RULES:
|
||||
1) Use the <tools> XML wrapper format only.
|
||||
2) Put one or more <tool_call> entries under a single <tools> root.
|
||||
3) Use <tool_name> for the tool name and <param> for the argument container.
|
||||
1) Use the <tool_calls> XML wrapper format only.
|
||||
2) Put one or more <invoke> entries under a single <tool_calls> root.
|
||||
3) Put the tool name in the invoke name attribute: <invoke name="TOOL_NAME">.
|
||||
4) All string values must use <![CDATA[...]]>, even short ones. This includes code, scripts, file contents, prompts, paths, names, and queries.
|
||||
5) Objects use nested XML elements. Arrays may repeat the same tag or use <item> children.
|
||||
6) Numbers, booleans, and null stay plain text.
|
||||
7) Use only the parameter names in the tool schema. Do not invent fields.
|
||||
8) Do NOT wrap XML in markdown fences. Do NOT output explanations, role markers, or internal monologue.
|
||||
5) Every top-level argument must be a <parameter name="ARG_NAME">...</parameter> node.
|
||||
6) Objects use nested XML elements inside the parameter body. Arrays may repeat <item> children.
|
||||
7) Numbers, booleans, and null stay plain text.
|
||||
8) Use only the parameter names in the tool schema. Do not invent fields.
|
||||
9) Do NOT wrap XML in markdown fences. Do NOT output explanations, role markers, or internal monologue.
|
||||
|
||||
PARAMETER SHAPES:
|
||||
- string => <name><![CDATA[value]]></name>
|
||||
- object => nested XML elements
|
||||
- array => repeated tags or <item> children
|
||||
- number/bool/null => plain text
|
||||
- string => <parameter name="x"><![CDATA[value]]></parameter>
|
||||
- object => <parameter name="x"><field>...</field></parameter>
|
||||
- array => <parameter name="x"><item>...</item><item>...</item></parameter>
|
||||
- number/bool/null => <parameter name="x">plain_text</parameter>
|
||||
|
||||
【WRONG — Do NOT do these】:
|
||||
|
||||
Wrong 1 — mixed text after XML:
|
||||
<tools>...</tools> I hope this helps.
|
||||
Wrong 2 — JSON payload inside <param>:
|
||||
<tool_call><tool_name>` + ex1 + `</tool_name><param>{"path":"x"}</param></tool_call>
|
||||
Wrong 3 — Markdown code fences:
|
||||
<tool_calls>...</tool_calls> I hope this helps.
|
||||
Wrong 2 — Markdown code fences:
|
||||
` + "```xml" + `
|
||||
<tools>...</tools>
|
||||
<tool_calls>...</tool_calls>
|
||||
` + "```" + `
|
||||
|
||||
Remember: The ONLY valid way to use tools is the <tools>...</tools> XML block at the end of your response.
|
||||
Remember: The ONLY valid way to use tools is the <tool_calls>...</tool_calls> XML block at the end of your response.
|
||||
|
||||
【CORRECT EXAMPLES】:
|
||||
|
||||
Example A — Single tool:
|
||||
<tools>
|
||||
<tool_call>
|
||||
<tool_name>` + ex1 + `</tool_name>
|
||||
<param>` + ex1Params + `</param>
|
||||
</tool_call>
|
||||
</tools>
|
||||
<tool_calls>
|
||||
<invoke name="` + ex1 + `">
|
||||
` + indentPromptParameters(ex1Params, " ") + `
|
||||
</invoke>
|
||||
</tool_calls>
|
||||
|
||||
Example B — Two tools in parallel:
|
||||
<tools>
|
||||
<tool_call>
|
||||
<tool_name>` + ex1 + `</tool_name>
|
||||
<param>` + ex1Params + `</param>
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
<tool_name>` + ex2 + `</tool_name>
|
||||
<param>` + ex2Params + `</param>
|
||||
</tool_call>
|
||||
</tools>
|
||||
<tool_calls>
|
||||
<invoke name="` + ex1 + `">
|
||||
` + indentPromptParameters(ex1Params, " ") + `
|
||||
</invoke>
|
||||
<invoke name="` + ex2 + `">
|
||||
` + indentPromptParameters(ex2Params, " ") + `
|
||||
</invoke>
|
||||
<invoke name="Read">
|
||||
<parameter name="file_path">` + promptCDATA("/abs/path/to/another-file.txt") + `</parameter>
|
||||
</invoke>
|
||||
</tool_calls>
|
||||
|
||||
Example C — Tool with nested XML parameters:
|
||||
<tools>
|
||||
<tool_call>
|
||||
<tool_name>` + ex3 + `</tool_name>
|
||||
<param>` + ex3Params + `</param>
|
||||
</tool_call>
|
||||
</tools>
|
||||
|
||||
<tool_calls>
|
||||
<invoke name="` + ex3 + `">
|
||||
` + indentPromptParameters(ex3Params, " ") + `
|
||||
</invoke>
|
||||
</tool_calls>
|
||||
|
||||
Example D — Tool with long script using CDATA (RELIABLE FOR CODE/SCRIPTS):
|
||||
<tools>
|
||||
<tool_call>
|
||||
<tool_name>` + ex2 + `</tool_name>
|
||||
<param>
|
||||
<path>` + promptCDATA("script.sh") + `</path>
|
||||
<content><![CDATA[
|
||||
<tool_calls>
|
||||
<invoke name="` + ex2 + `">
|
||||
<parameter name="path">` + promptCDATA("script.sh") + `</parameter>
|
||||
<parameter name="content"><![CDATA[
|
||||
#!/bin/bash
|
||||
if [ "$1" == "test" ]; then
|
||||
echo "Success!"
|
||||
fi
|
||||
]]></content>
|
||||
</param>
|
||||
</tool_call>
|
||||
</tools>
|
||||
]]></parameter>
|
||||
</invoke>
|
||||
</tool_calls>
|
||||
|
||||
`
|
||||
}
|
||||
|
||||
// indentPromptParameters prefixes every non-blank line of body with indent.
// Lines that are empty or whitespace-only are passed through untouched. When
// body itself is blank, a single empty placeholder
// <parameter name="content"></parameter> (prefixed with indent) is returned.
func indentPromptParameters(body, indent string) string {
	if strings.TrimSpace(body) == "" {
		return indent + `<parameter name="content"></parameter>`
	}

	var b strings.Builder
	for n, row := range strings.Split(body, "\n") {
		if n > 0 {
			b.WriteByte('\n')
		}
		if strings.TrimSpace(row) != "" {
			b.WriteString(indent)
		}
		b.WriteString(row)
	}
	return b.String()
}
|
||||
|
||||
// wrapParameter wraps inner in a <parameter name="NAME">...</parameter>
// element. Neither name nor inner is escaped here; both are written verbatim.
func wrapParameter(name, inner string) string {
	var b strings.Builder
	b.WriteString(`<parameter name="`)
	b.WriteString(name)
	b.WriteString(`">`)
	b.WriteString(inner)
	b.WriteString(`</parameter>`)
	return b.String()
}
|
||||
|
||||
func exampleReadParams(name string) string {
|
||||
switch strings.TrimSpace(name) {
|
||||
case "Read":
|
||||
return wrapParameter("file_path", promptCDATA("README.md"))
|
||||
case "Glob":
|
||||
return wrapParameter("pattern", promptCDATA("**/*.go")) + "\n" + wrapParameter("path", promptCDATA("."))
|
||||
default:
|
||||
return wrapParameter("path", promptCDATA("src/main.go"))
|
||||
}
|
||||
}
|
||||
|
||||
func exampleWriteOrExecParams(name string) string {
|
||||
switch strings.TrimSpace(name) {
|
||||
case "Bash", "execute_command":
|
||||
return wrapParameter("command", promptCDATA("pwd"))
|
||||
case "exec_command":
|
||||
return wrapParameter("cmd", promptCDATA("pwd"))
|
||||
case "Edit":
|
||||
return wrapParameter("file_path", promptCDATA("README.md")) + "\n" + wrapParameter("old_string", promptCDATA("foo")) + "\n" + wrapParameter("new_string", promptCDATA("bar"))
|
||||
case "MultiEdit":
|
||||
return wrapParameter("file_path", promptCDATA("README.md")) + "\n" + `<parameter name="edits"><item><old_string>` + promptCDATA("foo") + `</old_string><new_string>` + promptCDATA("bar") + `</new_string></item></parameter>`
|
||||
default:
|
||||
return wrapParameter("path", promptCDATA("output.txt")) + "\n" + wrapParameter("content", promptCDATA("Hello world"))
|
||||
}
|
||||
}
|
||||
|
||||
func exampleInteractiveParams(name string) string {
|
||||
switch strings.TrimSpace(name) {
|
||||
case "Task":
|
||||
return wrapParameter("description", promptCDATA("Investigate flaky tests")) + "\n" + wrapParameter("prompt", promptCDATA("Run targeted tests and summarize failures"))
|
||||
default:
|
||||
return wrapParameter("question", promptCDATA("Which approach do you prefer?")) + "\n" + `<parameter name="follow_up"><item><text>` + promptCDATA("Option A") + `</text></item><item><text>` + promptCDATA("Option B") + `</text></item></parameter>`
|
||||
}
|
||||
}
|
||||
|
||||
func matchAny(name string, candidates ...string) bool {
|
||||
for _, c := range candidates {
|
||||
if name == c {
|
||||
@@ -132,41 +178,6 @@ func matchAny(name string, candidates ...string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func exampleReadParams(name string) string {
|
||||
switch strings.TrimSpace(name) {
|
||||
case "Read":
|
||||
return `<file_path>` + promptCDATA("README.md") + `</file_path>`
|
||||
case "Glob":
|
||||
return `<pattern>` + promptCDATA("**/*.go") + `</pattern><path>` + promptCDATA(".") + `</path>`
|
||||
default:
|
||||
return `<path>` + promptCDATA("src/main.go") + `</path>`
|
||||
}
|
||||
}
|
||||
|
||||
func exampleWriteOrExecParams(name string) string {
|
||||
switch strings.TrimSpace(name) {
|
||||
case "Bash", "execute_command":
|
||||
return `<command>` + promptCDATA("pwd") + `</command>`
|
||||
case "exec_command":
|
||||
return `<cmd>` + promptCDATA("pwd") + `</cmd>`
|
||||
case "Edit":
|
||||
return `<file_path>` + promptCDATA("README.md") + `</file_path><old_string>` + promptCDATA("foo") + `</old_string><new_string>` + promptCDATA("bar") + `</new_string>`
|
||||
case "MultiEdit":
|
||||
return `<file_path>` + promptCDATA("README.md") + `</file_path><edits><old_string>` + promptCDATA("foo") + `</old_string><new_string>` + promptCDATA("bar") + `</new_string></edits>`
|
||||
default:
|
||||
return `<path>` + promptCDATA("output.txt") + `</path><content>` + promptCDATA("Hello world") + `</content>`
|
||||
}
|
||||
}
|
||||
|
||||
func exampleInteractiveParams(name string) string {
|
||||
switch strings.TrimSpace(name) {
|
||||
case "Task":
|
||||
return `<description>` + promptCDATA("Investigate flaky tests") + `</description><prompt>` + promptCDATA("Run targeted tests and summarize failures") + `</prompt>`
|
||||
default:
|
||||
return `<question>` + promptCDATA("Which approach do you prefer?") + `</question><follow_up><text>` + promptCDATA("Option A") + `</text></follow_up><follow_up><text>` + promptCDATA("Option B") + `</text></follow_up>`
|
||||
}
|
||||
}
|
||||
|
||||
func promptCDATA(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
|
||||
@@ -7,20 +7,20 @@ import (
|
||||
|
||||
func TestBuildToolCallInstructions_ExecCommandUsesCmdExample(t *testing.T) {
|
||||
out := BuildToolCallInstructions([]string{"exec_command"})
|
||||
if !strings.Contains(out, `<tool_name>exec_command</tool_name>`) {
|
||||
if !strings.Contains(out, `<invoke name="exec_command">`) {
|
||||
t.Fatalf("expected exec_command in examples, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, `<param><cmd><![CDATA[pwd]]></cmd></param>`) {
|
||||
if !strings.Contains(out, `<parameter name="cmd"><![CDATA[pwd]]></parameter>`) {
|
||||
t.Fatalf("expected cmd parameter example for exec_command, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolCallInstructions_ExecuteCommandUsesCommandExample(t *testing.T) {
|
||||
out := BuildToolCallInstructions([]string{"execute_command"})
|
||||
if !strings.Contains(out, `<tool_name>execute_command</tool_name>`) {
|
||||
if !strings.Contains(out, `<invoke name="execute_command">`) {
|
||||
t.Fatalf("expected execute_command in examples, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, `<param><command><![CDATA[pwd]]></command></param>`) {
|
||||
if !strings.Contains(out, `<parameter name="command"><![CDATA[pwd]]></parameter>`) {
|
||||
t.Fatalf("expected command parameter example for execute_command, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,12 +74,7 @@ func filterToolCallsDetailed(parsed []ParsedToolCall) ([]ParsedToolCall, []strin
|
||||
|
||||
func looksLikeToolCallSyntax(text string) bool {
|
||||
lower := strings.ToLower(text)
|
||||
return strings.Contains(lower, "<tools") ||
|
||||
strings.Contains(lower, "<tool_call") ||
|
||||
strings.Contains(lower, "<attempt_completion") ||
|
||||
strings.Contains(lower, "<ask_followup_question") ||
|
||||
strings.Contains(lower, "<new_task") ||
|
||||
strings.Contains(lower, "<result")
|
||||
return strings.Contains(lower, "<tool_calls")
|
||||
}
|
||||
|
||||
func stripFencedCodeBlocks(text string) string {
|
||||
|
||||
@@ -7,12 +7,13 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var xmlToolsWrapperPattern = regexp.MustCompile(`(?is)<tools\b[^>]*>\s*(.*?)\s*</tools>`)
|
||||
var xmlToolCallPattern = regexp.MustCompile(`(?is)<tool_call\b[^>]*>\s*(.*?)\s*</tool_call>`)
|
||||
var xmlCanonicalToolCallBodyPattern = regexp.MustCompile(`(?is)^\s*<(?:[a-z0-9_:-]+:)?tool_name\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?tool_name>\s*<(?:[a-z0-9_:-]+:)?param\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?param>\s*$`)
|
||||
var xmlToolCallsWrapperPattern = regexp.MustCompile(`(?is)<tool_calls\b[^>]*>\s*(.*?)\s*</tool_calls>`)
|
||||
var xmlInvokePattern = regexp.MustCompile(`(?is)<invoke\b([^>]*)>\s*(.*?)\s*</invoke>`)
|
||||
var xmlParameterPattern = regexp.MustCompile(`(?is)<parameter\b([^>]*)>\s*(.*?)\s*</parameter>`)
|
||||
var xmlAttrPattern = regexp.MustCompile(`(?is)\b([a-z0-9_:-]+)\s*=\s*("([^"]*)"|'([^']*)')`)
|
||||
|
||||
func parseXMLToolCalls(text string) []ParsedToolCall {
|
||||
wrappers := xmlToolsWrapperPattern.FindAllStringSubmatch(text, -1)
|
||||
wrappers := xmlToolCallsWrapperPattern.FindAllStringSubmatch(text, -1)
|
||||
if len(wrappers) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -21,7 +22,7 @@ func parseXMLToolCalls(text string) []ParsedToolCall {
|
||||
if len(wrapper) < 2 {
|
||||
continue
|
||||
}
|
||||
for _, block := range xmlToolCallPattern.FindAllString(wrapper[1], -1) {
|
||||
for _, block := range xmlInvokePattern.FindAllStringSubmatch(wrapper[1], -1) {
|
||||
call, ok := parseSingleXMLToolCall(block)
|
||||
if !ok {
|
||||
continue
|
||||
@@ -35,37 +36,90 @@ func parseXMLToolCalls(text string) []ParsedToolCall {
|
||||
return out
|
||||
}
|
||||
|
||||
func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
inner := strings.TrimSpace(block)
|
||||
inner = strings.TrimPrefix(inner, "<tool_call>")
|
||||
inner = strings.TrimSuffix(inner, "</tool_call>")
|
||||
inner = strings.TrimSpace(inner)
|
||||
func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) {
|
||||
if len(block) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
attrs := parseXMLTagAttributes(block[1])
|
||||
name := strings.TrimSpace(html.UnescapeString(attrs["name"]))
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
|
||||
inner := strings.TrimSpace(block[2])
|
||||
if strings.HasPrefix(inner, "{") {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(inner), &payload); err == nil {
|
||||
name := strings.TrimSpace(asString(payload["name"]))
|
||||
if name != "" {
|
||||
input := map[string]any{}
|
||||
if params, ok := payload["input"].(map[string]any); ok {
|
||||
input := map[string]any{}
|
||||
if params, ok := payload["input"].(map[string]any); ok {
|
||||
input = params
|
||||
}
|
||||
if len(input) == 0 {
|
||||
if params, ok := payload["parameters"].(map[string]any); ok {
|
||||
input = params
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
}
|
||||
|
||||
m := xmlCanonicalToolCallBodyPattern.FindStringSubmatch(inner)
|
||||
if len(m) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
input := map[string]any{}
|
||||
for _, paramMatch := range xmlParameterPattern.FindAllStringSubmatch(inner, -1) {
|
||||
if len(paramMatch) < 3 {
|
||||
continue
|
||||
}
|
||||
paramAttrs := parseXMLTagAttributes(paramMatch[1])
|
||||
paramName := strings.TrimSpace(html.UnescapeString(paramAttrs["name"]))
|
||||
if paramName == "" {
|
||||
continue
|
||||
}
|
||||
value := parseInvokeParameterValue(paramMatch[2])
|
||||
appendMarkupValue(input, paramName, value)
|
||||
}
|
||||
name := strings.TrimSpace(html.UnescapeString(extractRawTagValue(m[1])))
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return ParsedToolCall{}, false
|
||||
|
||||
if len(input) == 0 {
|
||||
if strings.TrimSpace(inner) != "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: map[string]any{}}, true
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: parseStructuredToolCallInput(m[2])}, true
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
|
||||
func asString(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
func parseXMLTagAttributes(raw string) map[string]string {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return map[string]string{}
|
||||
}
|
||||
out := map[string]string{}
|
||||
for _, m := range xmlAttrPattern.FindAllStringSubmatch(raw, -1) {
|
||||
if len(m) < 5 {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(m[1]))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
value := m[3]
|
||||
if value == "" {
|
||||
value = m[4]
|
||||
}
|
||||
out[key] = value
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseInvokeParameterValue(raw string) any {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if parsed := parseStructuredToolCallInput(trimmed); len(parsed) > 0 {
|
||||
if len(parsed) == 1 {
|
||||
if rawValue, ok := parsed["_raw"].(string); ok {
|
||||
return rawValue
|
||||
}
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
return html.UnescapeString(extractRawTagValue(trimmed))
|
||||
}
|
||||
|
||||
@@ -16,8 +16,8 @@ func TestFormatOpenAIToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsToolsWrapper(t *testing.T) {
|
||||
text := `<tools><tool_call><tool_name>Bash</tool_name><param><command>pwd</command><description>show cwd</description></param></tool_call></tools>`
|
||||
func TestParseToolCallsSupportsToolCallsWrapper(t *testing.T) {
|
||||
text := `<tool_calls><invoke name="Bash"><parameter name="command">pwd</parameter><parameter name="description">show cwd</parameter></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
@@ -31,9 +31,9 @@ func TestParseToolCallsSupportsToolsWrapper(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsStandaloneToolWithMultilineCDATAAndRepeatedXMLTags(t *testing.T) {
|
||||
text := `<tools><tool_call><tool_name>write_file</tool_name><param><path>script.sh</path><content><![CDATA[#!/bin/bash
|
||||
text := `<tool_calls><invoke name="write_file"><parameter name="path">script.sh</parameter><parameter name="content"><![CDATA[#!/bin/bash
|
||||
echo "hello"
|
||||
]]></content><item>first</item><item>second</item></param></tool_call></tools>`
|
||||
]]></parameter><parameter name="item">first</parameter><parameter name="item">second</parameter></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"write_file"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
@@ -54,8 +54,8 @@ echo "hello"
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsCanonicalParamsJSON(t *testing.T) {
|
||||
text := `<tools><tool_call><tool_name>get_weather</tool_name><param>{"city":"beijing","unit":"c"}</param></tool_call></tools>`
|
||||
func TestParseToolCallsSupportsInvokeParameters(t *testing.T) {
|
||||
text := `<tool_calls><invoke name="get_weather"><parameter name="city">beijing</parameter><parameter name="unit">c</parameter></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"get_weather"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
@@ -69,7 +69,7 @@ func TestParseToolCallsSupportsCanonicalParamsJSON(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolCallsPreservesRawMalformedParams(t *testing.T) {
|
||||
text := `<tools><tool_call><tool_name>execute_command</tool_name><param>cd /root && git status</param></tool_call></tools>`
|
||||
text := `<tool_calls><invoke name="execute_command"><parameter name="command">cd /root && git status</parameter></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"execute_command"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
@@ -77,9 +77,9 @@ func TestParseToolCallsPreservesRawMalformedParams(t *testing.T) {
|
||||
if calls[0].Name != "execute_command" {
|
||||
t.Fatalf("expected tool name execute_command, got %q", calls[0].Name)
|
||||
}
|
||||
raw, ok := calls[0].Input["_raw"].(string)
|
||||
raw, ok := calls[0].Input["command"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected raw argument tracking, got %#v", calls[0].Input)
|
||||
t.Fatalf("expected raw command tracking, got %#v", calls[0].Input)
|
||||
}
|
||||
if raw != "cd /root && git status" {
|
||||
t.Fatalf("expected raw arguments to be preserved, got %q", raw)
|
||||
@@ -87,7 +87,7 @@ func TestParseToolCallsPreservesRawMalformedParams(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsParamsJSONWithAmpersandCommand(t *testing.T) {
|
||||
text := `<tools><tool_call><tool_name>execute_command</tool_name><param>{"command":"sshpass -p 'xxx' ssh -o StrictHostKeyChecking=no -p 1111 root@111.111.111.111 'cd /root && git clone https://github.com/ericc-ch/copilot-api.git'","cwd":null,"timeout":null}</param></tool_call></tools>`
|
||||
text := `<tool_calls><invoke name="execute_command"><parameter name="command">sshpass -p 'xxx' ssh -o StrictHostKeyChecking=no -p 1111 root@111.111.111.111 'cd /root && git clone https://github.com/ericc-ch/copilot-api.git'</parameter><parameter name="cwd"></parameter><parameter name="timeout"></parameter></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"execute_command"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
@@ -102,7 +102,7 @@ func TestParseToolCallsSupportsParamsJSONWithAmpersandCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolCallsDoesNotTreatParamsNameTagAsToolName(t *testing.T) {
|
||||
text := `<tools><tool_call><tool_name>execute_command</tool_name><param><tool_name>file.txt</tool_name><command>pwd</command></param></tool_call></tools>`
|
||||
text := `<tool_calls><invoke name="execute_command"><parameter name="tool_name">file.txt</parameter><parameter name="command">pwd</parameter></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"execute_command"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
@@ -115,8 +115,8 @@ func TestParseToolCallsDoesNotTreatParamsNameTagAsToolName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsDetailedMarksToolsSyntax(t *testing.T) {
|
||||
text := `<tools><tool_call><tool_name>Bash</tool_name><param><command>pwd</command></param></tool_call></tools>`
|
||||
func TestParseToolCallsDetailedMarksToolCallsSyntax(t *testing.T) {
|
||||
text := `<tool_calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></tool_calls>`
|
||||
res := ParseToolCallsDetailed(text, []string{"bash"})
|
||||
if !res.SawToolCallSyntax {
|
||||
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
|
||||
@@ -127,7 +127,7 @@ func TestParseToolCallsDetailedMarksToolsSyntax(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsInlineJSONToolObject(t *testing.T) {
|
||||
text := `<tools><tool_call>{"name":"Bash","input":{"command":"pwd","description":"show cwd"}}</tool_call></tools>`
|
||||
text := `<tool_calls><invoke name="Bash">{"input":{"command":"pwd","description":"show cwd"}}</invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
@@ -141,7 +141,7 @@ func TestParseToolCallsSupportsInlineJSONToolObject(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolCallsDoesNotAcceptMismatchedMarkupTags(t *testing.T) {
|
||||
text := `<tools><tool_call><tool_name>read_file</function><param>{"path":"README.md"}</param></tool_call></tools>`
|
||||
text := `<tool_calls><invoke name="read_file"><parameter name="path">README.md</function></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"read_file"})
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected mismatched tags to be rejected, got %#v", calls)
|
||||
@@ -149,26 +149,37 @@ func TestParseToolCallsDoesNotAcceptMismatchedMarkupTags(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolCallsDoesNotTreatNameInsideParamsAsToolName(t *testing.T) {
|
||||
text := `<tools><tool_call><param><tool_name>data_only</tool_name><path>README.md</path></param></tool_call></tools>`
|
||||
text := `<tool_calls><invoke><parameter name="path">README.md</parameter></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"read_file"})
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool call when name appears only under params, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsRejectsLegacyToolCallsRoot(t *testing.T) {
|
||||
text := `<tool_calls><tool_call><tool_name>read_file</tool_name><param>{"path":"README.md"}</param></tool_call></tool_calls>`
|
||||
func TestParseToolCallsRejectsLegacyToolsWrapper(t *testing.T) {
|
||||
text := `<tools><tool_call><tool_name>read_file</tool_name><param>{"path":"README.md"}</param></tool_call></tools>`
|
||||
calls := ParseToolCalls(text, []string{"read_file"})
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected legacy tool_calls root to be rejected, got %#v", calls)
|
||||
t.Fatalf("expected legacy tools wrapper to be rejected, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsRejectsLegacyParametersTag(t *testing.T) {
|
||||
text := `<tools><tool_call><tool_name>read_file</tool_name><parameters>{"path":"README.md"}</parameters></tool_call></tools>`
|
||||
func TestParseToolCallsRejectsBareInvokeWithoutToolCallsWrapper(t *testing.T) {
|
||||
text := `<invoke name="read_file"><parameter name="path">README.md</parameter></invoke>`
|
||||
res := ParseToolCallsDetailed(text, []string{"read_file"})
|
||||
if len(res.Calls) != 0 {
|
||||
t.Fatalf("expected bare invoke to be rejected, got %#v", res.Calls)
|
||||
}
|
||||
if res.SawToolCallSyntax {
|
||||
t.Fatalf("expected bare invoke to no longer count as supported syntax, got %#v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsRejectsLegacyCanonicalBody(t *testing.T) {
|
||||
text := `<tool_calls><invoke name="read_file"><tool_name>read_file</tool_name><param>{"path":"README.md"}</param></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"read_file"})
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected legacy parameters tag to be rejected, got %#v", calls)
|
||||
t.Fatalf("expected legacy canonical body to be rejected, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,7 +321,7 @@ func TestRepairLooseJSONWithNestedObjects(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolCallsUnescapesHTMLEntityArguments(t *testing.T) {
|
||||
text := `<tools><tool_call><tool_name>Bash</tool_name><param>{"command":"echo a > out.txt"}</param></tool_call></tools>`
|
||||
text := `<tool_calls><invoke name="Bash"><parameter name="command">echo a > out.txt</parameter></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected one call, got %#v", calls)
|
||||
@@ -322,7 +333,7 @@ func TestParseToolCallsUnescapesHTMLEntityArguments(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolCallsIgnoresXMLInsideFencedCodeBlock(t *testing.T) {
|
||||
text := "Here is an example:\n```xml\n<tools><tool_call><tool_name>read_file</tool_name><param>{\"path\":\"README.md\"}</param></tool_call></tools>\n```\nDo not execute it."
|
||||
text := "Here is an example:\n```xml\n<tool_calls><invoke name=\"read_file\"><parameter name=\"path\">README.md</parameter></invoke></tool_calls>\n```\nDo not execute it."
|
||||
res := ParseToolCallsDetailed(text, []string{"read_file"})
|
||||
if len(res.Calls) != 0 {
|
||||
t.Fatalf("expected no parsed calls for fenced example, got %#v", res.Calls)
|
||||
@@ -330,7 +341,7 @@ func TestParseToolCallsIgnoresXMLInsideFencedCodeBlock(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolCallsParsesOnlyNonFencedXMLToolCall(t *testing.T) {
|
||||
text := "```xml\n<tools><tool_call><tool_name>read_file</tool_name><param>{\"path\":\"README.md\"}</param></tool_call></tools>\n```\n<tools><tool_call><tool_name>search</tool_name><param>{\"q\":\"golang\"}</param></tool_call></tools>"
|
||||
text := "```xml\n<tool_calls><invoke name=\"read_file\"><parameter name=\"path\">README.md</parameter></invoke></tool_calls>\n```\n<tool_calls><invoke name=\"search\"><parameter name=\"q\">golang</parameter></invoke></tool_calls>"
|
||||
res := ParseToolCallsDetailed(text, []string{"read_file", "search"})
|
||||
if len(res.Calls) != 1 {
|
||||
t.Fatalf("expected exactly one parsed call outside fence, got %#v", res.Calls)
|
||||
@@ -341,7 +352,7 @@ func TestParseToolCallsParsesOnlyNonFencedXMLToolCall(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolCallsParsesAfterFourBacktickFence(t *testing.T) {
|
||||
text := "````markdown\n```xml\n<tools><tool_call><tool_name>read_file</tool_name><param>{\"path\":\"README.md\"}</param></tool_call></tools>\n```\n````\n<tools><tool_call><tool_name>search</tool_name><param>{\"q\":\"outside\"}</param></tool_call></tools>"
|
||||
text := "````markdown\n```xml\n<tool_calls><invoke name=\"read_file\"><parameter name=\"path\">README.md</parameter></invoke></tool_calls>\n```\n````\n<tool_calls><invoke name=\"search\"><parameter name=\"q\">outside</parameter></invoke></tool_calls>"
|
||||
res := ParseToolCallsDetailed(text, []string{"read_file", "search"})
|
||||
if len(res.Calls) != 1 {
|
||||
t.Fatalf("expected exactly one parsed call outside four-backtick fence, got %#v", res.Calls)
|
||||
|
||||
@@ -3,27 +3,48 @@ package util
|
||||
import "strings"
|
||||
|
||||
func ResolveThinkingEnabled(req map[string]any, defaultEnabled bool) bool {
|
||||
if enabled, ok := parseThinkingSetting(req["thinking"]); ok {
|
||||
return enabled
|
||||
}
|
||||
if extraBody, ok := req["extra_body"].(map[string]any); ok {
|
||||
if enabled, ok := parseThinkingSetting(extraBody["thinking"]); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
if enabled, ok := parseReasoningEffort(req["reasoning_effort"]); ok {
|
||||
if enabled, ok := ResolveThinkingOverride(req); ok {
|
||||
return enabled
|
||||
}
|
||||
return defaultEnabled
|
||||
}
|
||||
|
||||
func ResolveThinkingOverride(req map[string]any) (bool, bool) {
|
||||
if req == nil {
|
||||
return false, false
|
||||
}
|
||||
if enabled, ok := parseThinkingSetting(req["thinking"]); ok {
|
||||
return enabled, true
|
||||
}
|
||||
if enabled, ok := parseReasoningSetting(req["reasoning"]); ok {
|
||||
return enabled, true
|
||||
}
|
||||
if extraBody, ok := req["extra_body"].(map[string]any); ok {
|
||||
if enabled, ok := parseThinkingSetting(extraBody["thinking"]); ok {
|
||||
return enabled, true
|
||||
}
|
||||
if enabled, ok := parseReasoningSetting(extraBody["reasoning"]); ok {
|
||||
return enabled, true
|
||||
}
|
||||
if enabled, ok := parseReasoningEffort(extraBody["reasoning_effort"]); ok {
|
||||
return enabled, true
|
||||
}
|
||||
}
|
||||
if enabled, ok := parseReasoningEffort(req["reasoning_effort"]); ok {
|
||||
return enabled, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func parseThinkingSetting(raw any) (bool, bool) {
|
||||
switch v := raw.(type) {
|
||||
case bool:
|
||||
return v, true
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "enabled":
|
||||
case "enabled", "enable", "on", "true":
|
||||
return true, true
|
||||
case "disabled":
|
||||
case "disabled", "disable", "off", "false", "none":
|
||||
return false, true
|
||||
default:
|
||||
return false, false
|
||||
@@ -36,10 +57,28 @@ func parseThinkingSetting(raw any) (bool, bool) {
|
||||
return false, false
|
||||
}
|
||||
|
||||
func parseReasoningSetting(raw any) (bool, bool) {
|
||||
switch v := raw.(type) {
|
||||
case bool:
|
||||
return v, true
|
||||
case string:
|
||||
return parseReasoningEffort(v)
|
||||
case map[string]any:
|
||||
for _, key := range []string{"effort", "type", "enabled"} {
|
||||
if enabled, ok := parseReasoningSetting(v[key]); ok {
|
||||
return enabled, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func parseReasoningEffort(raw any) (bool, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(toString(raw))) {
|
||||
case "low", "medium", "high", "xhigh":
|
||||
case "minimal", "low", "medium", "high", "xhigh":
|
||||
return true, true
|
||||
case "none", "disabled", "disable", "off", "false":
|
||||
return false, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
|
||||
@@ -27,13 +27,24 @@ func TestResolveThinkingEnabledUsesExtraBodyFallback(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolveThinkingEnabledMapsReasoningEffortToEnabled(t *testing.T) {
|
||||
for _, effort := range []string{"low", "medium", "high", "xhigh"} {
|
||||
for _, effort := range []string{"minimal", "low", "medium", "high", "xhigh"} {
|
||||
if got := ResolveThinkingEnabled(map[string]any{"reasoning_effort": effort}, false); !got {
|
||||
t.Fatalf("expected reasoning_effort=%s to enable thinking", effort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveThinkingEnabledMapsReasoningObject(t *testing.T) {
|
||||
req := map[string]any{"reasoning": map[string]any{"effort": "none"}}
|
||||
if got := ResolveThinkingEnabled(req, true); got {
|
||||
t.Fatalf("expected reasoning.effort=none to disable thinking")
|
||||
}
|
||||
req = map[string]any{"reasoning": map[string]any{"effort": "medium"}}
|
||||
if got := ResolveThinkingEnabled(req, false); !got {
|
||||
t.Fatalf("expected reasoning.effort=medium to enable thinking")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveThinkingEnabledDefaultsWhenUnset(t *testing.T) {
|
||||
if !ResolveThinkingEnabled(nil, true) {
|
||||
t.Fatal("expected default thinking=true when unset")
|
||||
|
||||
Reference in New Issue
Block a user