diff --git a/internal/toolcall/toolcalls_schema_normalize.go b/internal/toolcall/toolcalls_schema_normalize.go new file mode 100644 index 0000000..44c772a --- /dev/null +++ b/internal/toolcall/toolcalls_schema_normalize.go @@ -0,0 +1,266 @@ +package toolcall + +import ( + "encoding/json" + "strings" +) + +func NormalizeParsedToolCallsForSchemas(calls []ParsedToolCall, toolsRaw any) []ParsedToolCall { + if len(calls) == 0 { + return calls + } + schemas := buildToolSchemaIndex(toolsRaw) + if len(schemas) == 0 { + return calls + } + + var changedAny bool + out := make([]ParsedToolCall, len(calls)) + for i, call := range calls { + out[i] = call + schema, ok := schemas[strings.ToLower(strings.TrimSpace(call.Name))] + if !ok || call.Input == nil { + continue + } + normalized, changed := normalizeToolValueWithSchema(call.Input, schema) + if !changed { + continue + } + changedAny = true + if input, ok := normalized.(map[string]any); ok { + out[i].Input = input + } + } + if !changedAny { + return calls + } + return out +} + +func buildToolSchemaIndex(toolsRaw any) map[string]any { + tools, ok := toolsRaw.([]any) + if !ok || len(tools) == 0 { + return nil + } + out := make(map[string]any, len(tools)) + for _, item := range tools { + tool, ok := item.(map[string]any) + if !ok { + continue + } + name, schema := extractToolNameAndSchema(tool) + if name == "" || schema == nil { + continue + } + out[strings.ToLower(name)] = schema + } + if len(out) == 0 { + return nil + } + return out +} + +func extractToolNameAndSchema(tool map[string]any) (string, any) { + name := strings.TrimSpace(asStringValue(tool["name"])) + schema := tool["parameters"] + if schema == nil { + schema = tool["input_schema"] + } + if fn, ok := tool["function"].(map[string]any); ok { + if name == "" { + name = strings.TrimSpace(asStringValue(fn["name"])) + } + if schema == nil { + schema = fn["parameters"] + } + if schema == nil { + schema = fn["input_schema"] + } + } + return name, schema +} + +func normalizeToolValueWithSchema(value any, schema any) (any, bool) { + if value == nil || schema == nil { + return value, false + } + schemaMap, ok := schema.(map[string]any) + if !ok || len(schemaMap) == 0 { + return value, false + } + if shouldCoerceSchemaToString(schemaMap) { + return stringifySchemaValue(value) + } + if looksLikeObjectSchema(schemaMap) { + obj, ok := value.(map[string]any) + if !ok || len(obj) == 0 { + return value, false + } + properties, _ := schemaMap["properties"].(map[string]any) + additional := schemaMap["additionalProperties"] + changed := false + out := make(map[string]any, len(obj)) + for key, current := range obj { + next := current + var fieldChanged bool + if propSchema, ok := properties[key]; ok { + next, fieldChanged = normalizeToolValueWithSchema(current, propSchema) + } else if additional != nil { + next, fieldChanged = normalizeToolValueWithSchema(current, additional) + } + out[key] = next + changed = changed || fieldChanged + } + if !changed { + return value, false + } + return out, true + } + if looksLikeArraySchema(schemaMap) { + arr, ok := value.([]any) + if !ok || len(arr) == 0 { + return value, false + } + itemsSchema := schemaMap["items"] + if itemsSchema == nil { + return value, false + } + changed := false + out := make([]any, len(arr)) + switch itemSchemas := itemsSchema.(type) { + case []any: + for i, item := range arr { + if i >= len(itemSchemas) { + out[i] = item + continue + } + next, itemChanged := normalizeToolValueWithSchema(item, itemSchemas[i]) + out[i] = next + changed = changed || itemChanged + } + default: + for i, item := range arr { + next, itemChanged := normalizeToolValueWithSchema(item, itemsSchema) + out[i] = next + changed = changed || itemChanged + } + } + if !changed { + return value, false + } + return out, true + } + return value, false +} + +func shouldCoerceSchemaToString(schema map[string]any) bool { + if schema == nil { + return false + } + if isStringConst(schema["const"]) { + return true + } + if isStringEnum(schema["enum"]) { + return true + } + switch v := schema["type"].(type) { + case string: + return strings.EqualFold(strings.TrimSpace(v), "string") + case []any: + return isOnlyStringLikeTypes(v) + case []string: + items := make([]any, 0, len(v)) + for _, item := range v { + items = append(items, item) + } + return isOnlyStringLikeTypes(items) + default: + return false + } +} + +func looksLikeObjectSchema(schema map[string]any) bool { + if schema == nil { + return false + } + if typ, ok := schema["type"].(string); ok && strings.EqualFold(strings.TrimSpace(typ), "object") { + return true + } + if _, ok := schema["properties"].(map[string]any); ok { + return true + } + _, hasAdditional := schema["additionalProperties"] + return hasAdditional +} + +func looksLikeArraySchema(schema map[string]any) bool { + if schema == nil { + return false + } + if typ, ok := schema["type"].(string); ok && strings.EqualFold(strings.TrimSpace(typ), "array") { + return true + } + _, hasItems := schema["items"] + return hasItems +} + +func isOnlyStringLikeTypes(values []any) bool { + if len(values) == 0 { + return false + } + hasString := false + for _, item := range values { + typ, ok := item.(string) + if !ok { + return false + } + switch strings.ToLower(strings.TrimSpace(typ)) { + case "string": + hasString = true + case "null": + continue + default: + return false + } + } + return hasString +} + +func isStringConst(v any) bool { + _, ok := v.(string) + return ok +} + +func isStringEnum(v any) bool { + values, ok := v.([]any) + if !ok || len(values) == 0 { + return false + } + for _, item := range values { + if _, ok := item.(string); !ok { + return false + } + } + return true +} + +func stringifySchemaValue(value any) (any, bool) { + if value == nil { + return value, false + } + if s, ok := value.(string); ok { + return s, false + } + b, err := json.Marshal(value) + if err != nil { + return value, false + } + return string(b), true +} + +func asStringValue(v any) string { + if s, ok := v.(string); ok { + return s + } + return "" +} diff --git a/internal/toolcall/toolcalls_schema_normalize_test.go b/internal/toolcall/toolcalls_schema_normalize_test.go new file mode 100644 index 0000000..7807c3f --- /dev/null +++ b/internal/toolcall/toolcalls_schema_normalize_test.go @@ -0,0 +1,112 @@ +package toolcall + +import ( + "reflect" + "testing" +) + +func TestNormalizeParsedToolCallsForSchemasCoercesDeclaredStringFieldsRecursively(t *testing.T) { + toolsRaw := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "TaskUpdate", + "parameters": map[string]any{ + "type": "object", + "properties": map[string]any{ + "taskId": map[string]any{"type": "string"}, + "payload": map[string]any{ + "type": "object", + "properties": map[string]any{ + "content": map[string]any{"type": "string"}, + "tags": map[string]any{ + "type": "array", + "items": map[string]any{"type": "string"}, + }, + "count": map[string]any{"type": "number"}, + }, + }, + }, + }, + }, + }, + } + calls := []ParsedToolCall{{ + Name: "TaskUpdate", + Input: map[string]any{ + "taskId": 1, + "payload": map[string]any{ + "content": map[string]any{"text": "hello"}, + "tags": []any{1, true, map[string]any{"k": "v"}}, + "count": 2, + }, + }, + }} + + got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw) + if len(got) != 1 { + t.Fatalf("expected one normalized call, got %#v", got) + } + if got[0].Input["taskId"] != "1" { + t.Fatalf("expected taskId coerced to string, got %#v", got[0].Input["taskId"]) + } + payload, ok := got[0].Input["payload"].(map[string]any) + if !ok { + t.Fatalf("expected payload object, got %#v", got[0].Input["payload"]) + } + if payload["content"] != `{"text":"hello"}` { + t.Fatalf("expected nested content coerced to json string, got %#v", payload["content"]) + } + if payload["count"] != 2 { + t.Fatalf("expected non-string count unchanged, got %#v", payload["count"]) + } + tags, ok := payload["tags"].([]any) + if !ok { + t.Fatalf("expected tags slice, got %#v", payload["tags"]) + } + wantTags := []any{"1", "true", `{"k":"v"}`} + if !reflect.DeepEqual(tags, wantTags) { + t.Fatalf("unexpected normalized tags: got %#v want %#v", tags, wantTags) + } +} + +func TestNormalizeParsedToolCallsForSchemasSupportsDirectToolSchemaShape(t *testing.T) { + toolsRaw := []any{ + map[string]any{ + "name": "Write", + "input_schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "content": map[string]any{"type": "string"}, + }, + }, + }, + } + calls := []ParsedToolCall{{Name: "Write", Input: map[string]any{"content": []any{"a", 1}}}} + got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw) + if got[0].Input["content"] != `["a",1]` { + t.Fatalf("expected direct-schema content coerced to string, got %#v", got[0].Input["content"]) + } +} + +func TestNormalizeParsedToolCallsForSchemasLeavesAmbiguousUnionUnchanged(t *testing.T) { + toolsRaw := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "TaskUpdate", + "parameters": map[string]any{ + "type": "object", + "properties": map[string]any{ + "taskId": map[string]any{"type": []any{"string", "integer"}}, + }, + }, + }, + }, + } + calls := []ParsedToolCall{{Name: "TaskUpdate", Input: map[string]any{"taskId": 1}}} + got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw) + if got[0].Input["taskId"] != 1 { + t.Fatalf("expected ambiguous union to stay unchanged, got %#v", got[0].Input["taskId"]) + } +}