refactor: rename tool XML wrapper from tool_calls to tool_batch and add schema attention blocks to tool prompts

This commit is contained in:
CJACK
2026-04-05 19:22:43 +08:00
parent b8e9ca2028
commit ade648033d
6 changed files with 165 additions and 15 deletions

View File

@@ -319,7 +319,8 @@ func TestBuildClaudeToolPromptSupportsOpenAIStyleFunctionTool(t *testing.T) {
"name": "search",
"description": "Search via function tool",
"parameters": map[string]any{
"type": "object",
"type": "object",
"required": []any{"q"},
"properties": map[string]any{
"q": map[string]any{"type": "string"},
},
@@ -334,8 +335,8 @@ func TestBuildClaudeToolPromptSupportsOpenAIStyleFunctionTool(t *testing.T) {
if !containsStr(prompt, "Search via function tool") {
t.Fatalf("expected OpenAI-style function tool description in prompt, got: %q", prompt)
}
if !containsStr(prompt, "\"q\"") {
t.Fatalf("expected parameters schema serialized in prompt, got: %q", prompt)
if !containsStr(prompt, "MUST INCLUDE: q") {
t.Fatalf("expected required-field summary in prompt, got: %q", prompt)
}
}

View File

@@ -90,8 +90,7 @@ func buildClaudeToolPrompt(tools []any) string {
continue
}
names = append(names, name)
schema, _ := json.Marshal(schemaObj)
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
toolSchemas = append(toolSchemas, util.FormatToolSchemaAttentionBlock(name, desc, schemaObj))
}
if len(toolSchemas) == 0 {
return ""

View File

@@ -2,7 +2,6 @@ package openai
import (
"encoding/json"
"fmt"
"strings"
"github.com/google/uuid"
@@ -44,11 +43,7 @@ func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolCh
continue
}
names = append(names, name)
if desc == "" {
desc = "No description available"
}
b, _ := json.Marshal(schema)
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
toolSchemas = append(toolSchemas, util.FormatToolSchemaAttentionBlock(name, desc, schema))
}
if len(toolSchemas) == 0 {
return messages, names

View File

@@ -34,7 +34,11 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes
"name": "get_weather",
"description": "Get weather",
"parameters": map[string]any{
"type": "object",
"type": "object",
"required": []any{"city"},
"properties": map[string]any{
"city": map[string]any{"type": "string"},
},
},
},
},
@@ -53,6 +57,9 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes
if !strings.Contains(finalPrompt, "<tool_name>get_weather</tool_name>") {
t.Fatalf("handler finalPrompt should include tool name history: %q", finalPrompt)
}
if !strings.Contains(finalPrompt, "MUST INCLUDE: city") {
t.Fatalf("handler finalPrompt should front-load required fields: %q", finalPrompt)
}
}
func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *testing.T) {
@@ -67,7 +74,11 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
"name": "search",
"description": "search docs",
"parameters": map[string]any{
"type": "object",
"type": "object",
"required": []any{"query"},
"properties": map[string]any{
"query": map[string]any{"type": "string"},
},
},
},
},
@@ -83,6 +94,9 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
if !strings.Contains(finalPrompt, "TOOL CALL FORMAT") {
t.Fatalf("vercel prepare finalPrompt missing xml format instruction: %q", finalPrompt)
}
if !strings.Contains(finalPrompt, "MUST INCLUDE: query") {
t.Fatalf("vercel prepare finalPrompt missing required-field summary: %q", finalPrompt)
}
if !strings.Contains(finalPrompt, "Do NOT wrap the XML in markdown code fences") {
t.Fatalf("vercel prepare finalPrompt missing no-fence xml instruction: %q", finalPrompt)
}

View File

@@ -1,6 +1,11 @@
package util
import "strings"
import (
"encoding/json"
"fmt"
"sort"
"strings"
)
// BuildToolCallInstructions generates the unified tool-calling instruction block
// used by all adapters (OpenAI, Claude, Gemini). It uses attention-optimized
@@ -54,12 +59,17 @@ RULES:
6) Parameters MUST use the exact field names from the selected tool schema.
7) CRITICAL: Do NOT invent or add any extra fields (such as "_raw", "_xml"). Use ONLY the fields strictly defined in the schema. Extra fields will cause execution failure.
ATTENTION CHECKLIST BEFORE YOU EMIT A TOOL CALL:
- Read the tool block above first.
- If the tool block says MUST INCLUDE, every such field must be present.
- If any required field is missing or uncertain, ask a clarifying question instead of guessing.
❌ WRONG — Do NOT do these:
Wrong 1 — mixed text and XML:
I'll read the file for you. <tool_calls><tool_call>...
Wrong 2 — describing tool calls in text:
[调用 Bash] {"command": "ls"}
Wrong 3 — missing <tool_calls> wrapper:
Wrong 3 — empty or missing required parameters:
<tool_call><tool_name>` + ex1 + `</tool_name><parameters>{}</parameters></tool_call>
Wrong 4 — extra/invented fields:
<parameters>{"_raw": "...", "command": "ls"}</parameters>
@@ -98,6 +108,40 @@ Example C — Tool with complex nested JSON parameters:
Remember: Output ONLY the <tool_calls>...</tool_calls> XML block when calling tools.`
}
// FormatToolSchemaAttentionBlock renders a compact, attention-friendly tool
// summary for prompt injection. It front-loads required fields so the model can
// spot them before the full format rules and examples.
func FormatToolSchemaAttentionBlock(name, description string, schema any) string {
lines := make([]string, 0, 4)
name = strings.TrimSpace(name)
if name != "" {
lines = append(lines, "Tool: "+name)
}
description = strings.TrimSpace(description)
if description != "" {
lines = append(lines, "Description: "+description)
}
required, optional := summarizeToolSchemaFields(schema)
switch {
case len(required) > 0:
lines = append(lines, "MUST INCLUDE: "+strings.Join(required, ", "))
if len(optional) > 0 {
lines = append(lines, "OPTIONAL: "+strings.Join(optional, ", "))
}
case len(optional) > 0:
lines = append(lines, "FIELDS: "+strings.Join(optional, ", "))
case schema != nil:
if b, err := json.Marshal(schema); err == nil && len(b) > 0 {
lines = append(lines, "Schema: "+string(b))
}
}
return strings.TrimSpace(strings.Join(lines, "\n"))
}
func matchAny(name string, candidates ...string) bool {
for _, c := range candidates {
if name == c {
@@ -141,3 +185,75 @@ func exampleInteractiveParams(name string) string {
return `{"question":"Which approach do you prefer?","follow_up":[{"text":"Option A"},{"text":"Option B"}]}`
}
}
func summarizeToolSchemaFields(schema any) (required []string, optional []string) {
obj, ok := schema.(map[string]any)
if !ok || len(obj) == 0 {
return nil, nil
}
requiredSet := map[string]struct{}{}
for _, name := range anySliceToStrings(obj["required"]) {
requiredSet[name] = struct{}{}
}
propNames := map[string]struct{}{}
if props, ok := obj["properties"].(map[string]any); ok {
for k := range props {
name := strings.TrimSpace(k)
if name == "" {
continue
}
propNames[name] = struct{}{}
}
}
required = make([]string, 0, len(requiredSet))
for name := range requiredSet {
required = append(required, name)
}
sort.Strings(required)
if len(propNames) == 0 {
return required, nil
}
optional = make([]string, 0, len(propNames))
for name := range propNames {
if _, ok := requiredSet[name]; ok {
continue
}
optional = append(optional, name)
}
sort.Strings(optional)
return required, optional
}
func anySliceToStrings(v any) []string {
switch x := v.(type) {
case []string:
out := make([]string, 0, len(x))
for _, item := range x {
item = strings.TrimSpace(item)
if item != "" {
out = append(out, item)
}
}
return out
case []any:
out := make([]string, 0, len(x))
for _, item := range x {
s := strings.TrimSpace(fmt.Sprintf("%v", item))
if s != "" && s != "<nil>" {
out = append(out, s)
}
}
return out
default:
s := strings.TrimSpace(fmt.Sprintf("%v", v))
if s == "" || s == "<nil>" {
return nil
}
return []string{s}
}
}

View File

@@ -24,3 +24,28 @@ func TestBuildToolCallInstructions_ExecuteCommandUsesCommandExample(t *testing.T
t.Fatalf("expected command parameter example for execute_command, got: %s", out)
}
}
func TestFormatToolSchemaAttentionBlockPrioritizesRequiredFields(t *testing.T) {
schema := map[string]any{
"type": "object",
"required": []any{
"command",
},
"properties": map[string]any{
"command": map[string]any{"type": "string"},
"cwd": map[string]any{"type": "string"},
"timeout": map[string]any{"type": "integer"},
},
}
out := FormatToolSchemaAttentionBlock("execute_command", "Run a command", schema)
if !strings.Contains(out, "Tool: execute_command") {
t.Fatalf("expected tool name in summary, got: %s", out)
}
if !strings.Contains(out, "MUST INCLUDE: command") {
t.Fatalf("expected required field summary, got: %s", out)
}
if !strings.Contains(out, "OPTIONAL: cwd, timeout") {
t.Fatalf("expected optional field summary, got: %s", out)
}
}