From a550de30af77d54d36b77bed7a30ab3232fea318 Mon Sep 17 00:00:00 2001 From: shern-point Date: Wed, 29 Apr 2026 01:59:05 +0800 Subject: [PATCH] fix: expand shared tool schema extraction --- internal/promptcompat/tool_prompt.go | 8 +-- .../toolcall/toolcalls_schema_normalize.go | 40 ++++++++++----- .../toolcalls_schema_normalize_test.go | 49 +++++++++++++++++++ 3 files changed, 78 insertions(+), 19 deletions(-) diff --git a/internal/promptcompat/tool_prompt.go b/internal/promptcompat/tool_prompt.go index ba5f2cf..95d2f8b 100644 --- a/internal/promptcompat/tool_prompt.go +++ b/internal/promptcompat/tool_prompt.go @@ -30,13 +30,7 @@ func injectToolPrompt(messages []map[string]any, tools []any, policy ToolChoiceP if !ok { continue } - fn, _ := tool["function"].(map[string]any) - if len(fn) == 0 { - fn = tool - } - name, _ := fn["name"].(string) - desc, _ := fn["description"].(string) - schema, _ := fn["parameters"].(map[string]any) + name, desc, schema := toolcall.ExtractToolMeta(tool) name = strings.TrimSpace(name) if !isAllowed(name) { continue diff --git a/internal/toolcall/toolcalls_schema_normalize.go b/internal/toolcall/toolcalls_schema_normalize.go index 44c772a..65a27c2 100644 --- a/internal/toolcall/toolcalls_schema_normalize.go +++ b/internal/toolcall/toolcalls_schema_normalize.go @@ -48,7 +48,7 @@ func buildToolSchemaIndex(toolsRaw any) map[string]any { if !ok { continue } - name, schema := extractToolNameAndSchema(tool) + name, _, schema := ExtractToolMeta(tool) if name == "" || schema == nil { continue } @@ -60,24 +60,31 @@ func buildToolSchemaIndex(toolsRaw any) map[string]any { return out } -func extractToolNameAndSchema(tool map[string]any) (string, any) { +func ExtractToolMeta(tool map[string]any) (string, string, any) { name := strings.TrimSpace(asStringValue(tool["name"])) - schema := tool["parameters"] - if schema == nil { - schema = tool["input_schema"] - } + desc := strings.TrimSpace(asStringValue(tool["description"])) + schema := firstNonNil( + tool["parameters"], + tool["input_schema"], + tool["inputSchema"], + tool["schema"], + ) if fn, ok := tool["function"].(map[string]any); ok { if name == "" { name = strings.TrimSpace(asStringValue(fn["name"])) } - if schema == nil { - schema = fn["parameters"] - } - if schema == nil { - schema = fn["input_schema"] + if desc == "" { + desc = strings.TrimSpace(asStringValue(fn["description"])) } + schema = firstNonNil( + schema, + fn["parameters"], + fn["input_schema"], + fn["inputSchema"], + fn["schema"], + ) } - return name, schema + return name, desc, schema } func normalizeToolValueWithSchema(value any, schema any) (any, bool) { @@ -264,3 +271,12 @@ func asStringValue(v any) string { } return "" } + +func firstNonNil(values ...any) any { + for _, value := range values { + if value != nil { + return value + } + } + return nil +} diff --git a/internal/toolcall/toolcalls_schema_normalize_test.go b/internal/toolcall/toolcalls_schema_normalize_test.go index 7807c3f..7dac106 100644 --- a/internal/toolcall/toolcalls_schema_normalize_test.go +++ b/internal/toolcall/toolcalls_schema_normalize_test.go @@ -110,3 +110,52 @@ func TestNormalizeParsedToolCallsForSchemasLeavesAmbiguousUnionUnchanged(t *test t.Fatalf("expected ambiguous union to stay unchanged, got %#v", got[0].Input["taskId"]) } } + +func TestNormalizeParsedToolCallsForSchemasSupportsCamelCaseInputSchema(t *testing.T) { + toolsRaw := []any{ + map[string]any{ + "name": "Write", + "inputSchema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "content": map[string]any{"type": "string"}, + }, + }, + }, + } + calls := []ParsedToolCall{{Name: "Write", Input: map[string]any{"content": map[string]any{"message": "hi"}}}} + got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw) + if got[0].Input["content"] != `{"message":"hi"}` { + t.Fatalf("expected camelCase inputSchema content coercion, got %#v", got[0].Input["content"]) + } +} + +func TestNormalizeParsedToolCallsForSchemasPreservesArrayWhenSchemaSaysArray(t *testing.T) { + toolsRaw := []any{ + map[string]any{ + "name": "todowrite", + "inputSchema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "todos": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "content": map[string]any{"type": "string"}, + "status": map[string]any{"type": "string"}, + "priority": map[string]any{"type": "string"}, + }, + }, + }, + }, + }, + }, + } + todos := []any{map[string]any{"content": "x", "status": "pending", "priority": "high"}} + calls := []ParsedToolCall{{Name: "todowrite", Input: map[string]any{"todos": todos}}} + got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw) + if !reflect.DeepEqual(got[0].Input["todos"], todos) { + t.Fatalf("expected todos array preserved, got %#v want %#v", got[0].Input["todos"], todos) + } +}