From ade648033d58882beba25c7a31d7c39cc2f135a8 Mon Sep 17 00:00:00 2001 From: CJACK Date: Sun, 5 Apr 2026 19:22:43 +0800 Subject: [PATCH] refactor: rename tool XML wrapper from tool_calls to tool_batch and add schema attention blocks to tool prompts --- internal/adapter/claude/handler_util_test.go | 7 +- internal/adapter/claude/handler_utils.go | 3 +- .../adapter/openai/handler_toolcall_format.go | 7 +- internal/adapter/openai/prompt_build_test.go | 18 ++- internal/util/tool_prompt.go | 120 +++++++++++++++++- internal/util/tool_prompt_test.go | 25 ++++ 6 files changed, 165 insertions(+), 15 deletions(-) diff --git a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go index 82302f0..3b45d61 100644 --- a/internal/adapter/claude/handler_util_test.go +++ b/internal/adapter/claude/handler_util_test.go @@ -319,7 +319,8 @@ func TestBuildClaudeToolPromptSupportsOpenAIStyleFunctionTool(t *testing.T) { "name": "search", "description": "Search via function tool", "parameters": map[string]any{ - "type": "object", + "type": "object", + "required": []any{"q"}, "properties": map[string]any{ "q": map[string]any{"type": "string"}, }, @@ -334,8 +335,8 @@ func TestBuildClaudeToolPromptSupportsOpenAIStyleFunctionTool(t *testing.T) { if !containsStr(prompt, "Search via function tool") { t.Fatalf("expected OpenAI-style function tool description in prompt, got: %q", prompt) } - if !containsStr(prompt, "\"q\"") { - t.Fatalf("expected parameters schema serialized in prompt, got: %q", prompt) + if !containsStr(prompt, "MUST INCLUDE: q") { + t.Fatalf("expected required-field summary in prompt, got: %q", prompt) } } diff --git a/internal/adapter/claude/handler_utils.go b/internal/adapter/claude/handler_utils.go index fef1194..6b3a601 100644 --- a/internal/adapter/claude/handler_utils.go +++ b/internal/adapter/claude/handler_utils.go @@ -90,8 +90,7 @@ func buildClaudeToolPrompt(tools []any) string { continue } names = append(names, name) - schema, _ := json.Marshal(schemaObj) - toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema)) + toolSchemas = append(toolSchemas, util.FormatToolSchemaAttentionBlock(name, desc, schemaObj)) } if len(toolSchemas) == 0 { return "" diff --git a/internal/adapter/openai/handler_toolcall_format.go b/internal/adapter/openai/handler_toolcall_format.go index c11a3c7..5942d69 100644 --- a/internal/adapter/openai/handler_toolcall_format.go +++ b/internal/adapter/openai/handler_toolcall_format.go @@ -2,7 +2,6 @@ package openai import ( "encoding/json" - "fmt" "strings" "github.com/google/uuid" @@ -44,11 +43,7 @@ func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolCh continue } names = append(names, name) - if desc == "" { - desc = "No description available" - } - b, _ := json.Marshal(schema) - toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b))) + toolSchemas = append(toolSchemas, util.FormatToolSchemaAttentionBlock(name, desc, schema)) } if len(toolSchemas) == 0 { return messages, names diff --git a/internal/adapter/openai/prompt_build_test.go b/internal/adapter/openai/prompt_build_test.go index 223689b..7539763 100644 --- a/internal/adapter/openai/prompt_build_test.go +++ b/internal/adapter/openai/prompt_build_test.go @@ -34,7 +34,11 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes "name": "get_weather", "description": "Get weather", "parameters": map[string]any{ - "type": "object", + "type": "object", + "required": []any{"city"}, + "properties": map[string]any{ + "city": map[string]any{"type": "string"}, + }, }, }, }, @@ -53,6 +57,9 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes if !strings.Contains(finalPrompt, "get_weather") { t.Fatalf("handler finalPrompt should include tool name history: %q", finalPrompt) } + if !strings.Contains(finalPrompt, "MUST INCLUDE: city") { + t.Fatalf("handler finalPrompt should front-load required fields: %q", finalPrompt) + } } func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *testing.T) { @@ -67,7 +74,11 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t * "name": "search", "description": "search docs", "parameters": map[string]any{ - "type": "object", + "type": "object", + "required": []any{"query"}, + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + }, }, }, }, @@ -83,6 +94,9 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t * if !strings.Contains(finalPrompt, "TOOL CALL FORMAT") { t.Fatalf("vercel prepare finalPrompt missing xml format instruction: %q", finalPrompt) } + if !strings.Contains(finalPrompt, "MUST INCLUDE: query") { + t.Fatalf("vercel prepare finalPrompt missing required-field summary: %q", finalPrompt) + } if !strings.Contains(finalPrompt, "Do NOT wrap the XML in markdown code fences") { t.Fatalf("vercel prepare finalPrompt missing no-fence xml instruction: %q", finalPrompt) } diff --git a/internal/util/tool_prompt.go b/internal/util/tool_prompt.go index a801286..8d9fd59 100644 --- a/internal/util/tool_prompt.go +++ b/internal/util/tool_prompt.go @@ -1,6 +1,11 @@ package util -import "strings" +import ( + "encoding/json" + "fmt" + "sort" + "strings" +) // BuildToolCallInstructions generates the unified tool-calling instruction block // used by all adapters (OpenAI, Claude, Gemini). It uses attention-optimized @@ -54,12 +59,17 @@ RULES: 6) Parameters MUST use the exact field names from the selected tool schema. 7) CRITICAL: Do NOT invent or add any extra fields (such as "_raw", "_xml"). Use ONLY the fields strictly defined in the schema. Extra fields will cause execution failure. +ATTENTION CHECKLIST BEFORE YOU EMIT A TOOL CALL: +- Read the tool block above first. +- If the tool block says MUST INCLUDE, every such field must be present. +- If any required field is missing or uncertain, ask a clarifying question instead of guessing. + ❌ WRONG — Do NOT do these: Wrong 1 — mixed text and XML: I'll read the file for you. ... Wrong 2 — describing tool calls in text: [调用 Bash] {"command": "ls"} -Wrong 3 — missing wrapper: +Wrong 3 — empty or missing required parameters: ` + ex1 + `{} Wrong 4 — extra/invented fields: {"_raw": "...", "command": "ls"} @@ -98,6 +108,40 @@ Example C — Tool with complex nested JSON parameters: Remember: Output ONLY the ... XML block when calling tools.` } +// FormatToolSchemaAttentionBlock renders a compact, attention-friendly tool +// summary for prompt injection. It front-loads required fields so the model can +// spot them before the full format rules and examples. +func FormatToolSchemaAttentionBlock(name, description string, schema any) string { + lines := make([]string, 0, 4) + + name = strings.TrimSpace(name) + if name != "" { + lines = append(lines, "Tool: "+name) + } + + description = strings.TrimSpace(description) + if description != "" { + lines = append(lines, "Description: "+description) + } + + required, optional := summarizeToolSchemaFields(schema) + switch { + case len(required) > 0: + lines = append(lines, "MUST INCLUDE: "+strings.Join(required, ", ")) + if len(optional) > 0 { + lines = append(lines, "OPTIONAL: "+strings.Join(optional, ", ")) + } + case len(optional) > 0: + lines = append(lines, "FIELDS: "+strings.Join(optional, ", ")) + case schema != nil: + if b, err := json.Marshal(schema); err == nil && len(b) > 0 { + lines = append(lines, "Schema: "+string(b)) + } + } + + return strings.TrimSpace(strings.Join(lines, "\n")) +} + func matchAny(name string, candidates ...string) bool { for _, c := range candidates { if name == c { @@ -141,3 +185,75 @@ func exampleInteractiveParams(name string) string { return `{"question":"Which approach do you prefer?","follow_up":[{"text":"Option A"},{"text":"Option B"}]}` } } + +func summarizeToolSchemaFields(schema any) (required []string, optional []string) { + obj, ok := schema.(map[string]any) + if !ok || len(obj) == 0 { + return nil, nil + } + + requiredSet := map[string]struct{}{} + for _, name := range anySliceToStrings(obj["required"]) { + requiredSet[name] = struct{}{} + } + + propNames := map[string]struct{}{} + if props, ok := obj["properties"].(map[string]any); ok { + for k := range props { + name := strings.TrimSpace(k) + if name == "" { + continue + } + propNames[name] = struct{}{} + } + } + + required = make([]string, 0, len(requiredSet)) + for name := range requiredSet { + required = append(required, name) + } + sort.Strings(required) + + if len(propNames) == 0 { + return required, nil + } + + optional = make([]string, 0, len(propNames)) + for name := range propNames { + if _, ok := requiredSet[name]; ok { + continue + } + optional = append(optional, name) + } + sort.Strings(optional) + return required, optional +} + +func anySliceToStrings(v any) []string { + switch x := v.(type) { + case []string: + out := make([]string, 0, len(x)) + for _, item := range x { + item = strings.TrimSpace(item) + if item != "" { + out = append(out, item) + } + } + return out + case []any: + out := make([]string, 0, len(x)) + for _, item := range x { + s := strings.TrimSpace(fmt.Sprintf("%v", item)) + if s != "" && s != "" { + out = append(out, s) + } + } + return out + default: + s := strings.TrimSpace(fmt.Sprintf("%v", v)) + if s == "" || s == "" { + return nil + } + return []string{s} + } +} diff --git a/internal/util/tool_prompt_test.go b/internal/util/tool_prompt_test.go index e10f176..ddb90d6 100644 --- a/internal/util/tool_prompt_test.go +++ b/internal/util/tool_prompt_test.go @@ -24,3 +24,28 @@ func TestBuildToolCallInstructions_ExecuteCommandUsesCommandExample(t *testing.T t.Fatalf("expected command parameter example for execute_command, got: %s", out) } } + +func TestFormatToolSchemaAttentionBlockPrioritizesRequiredFields(t *testing.T) { + schema := map[string]any{ + "type": "object", + "required": []any{ + "command", + }, + "properties": map[string]any{ + "command": map[string]any{"type": "string"}, + "cwd": map[string]any{"type": "string"}, + "timeout": map[string]any{"type": "integer"}, + }, + } + + out := FormatToolSchemaAttentionBlock("execute_command", "Run a command", schema) + if !strings.Contains(out, "Tool: execute_command") { + t.Fatalf("expected tool name in summary, got: %s", out) + } + if !strings.Contains(out, "MUST INCLUDE: command") { + t.Fatalf("expected required field summary, got: %s", out) + } + if !strings.Contains(out, "OPTIONAL: cwd, timeout") { + t.Fatalf("expected optional field summary, got: %s", out) + } +}