diff --git a/internal/adapter/claude/handler_stream_test.go b/internal/adapter/claude/handler_stream_test.go index 701c8d7..42358aa 100644 --- a/internal/adapter/claude/handler_stream_test.go +++ b/internal/adapter/claude/handler_stream_test.go @@ -183,6 +183,32 @@ func TestHandleClaudeStreamRealtimeToolSafety(t *testing.T) { } } +func TestHandleClaudeStreamRealtimeToolDetectionFromThinkingFallback(t *testing.T) { + h := &Handler{} + resp := makeClaudeSSEHTTPResponse( + `data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"search\""}`, + `data: {"p":"response/thinking_content","v":",\"input\":{\"q\":\"go\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil) + + h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, true, false, []string{"search"}) + + frames := parseClaudeFrames(t, rec.Body.String()) + foundToolUse := false + for _, f := range findClaudeFrames(frames, "content_block_start") { + contentBlock, _ := f.Payload["content_block"].(map[string]any) + if contentBlock["type"] == "tool_use" && contentBlock["name"] == "search" { + foundToolUse = true + break + } + } + if !foundToolUse { + t.Fatalf("expected tool_use block from thinking fallback, body=%s", rec.Body.String()) + } +} + func TestHandleClaudeStreamRealtimeUpstreamErrorEvent(t *testing.T) { h := &Handler{} resp := makeClaudeSSEHTTPResponse( diff --git a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go index ae75d8e..f5b0ad5 100644 --- a/internal/adapter/claude/handler_util_test.go +++ b/internal/adapter/claude/handler_util_test.go @@ -141,6 +141,34 @@ func TestBuildClaudeToolPromptMultipleTools(t *testing.T) { } } +func TestBuildClaudeToolPromptSupportsOpenAIStyleFunctionTool(t *testing.T) { + tools := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + "description": "Search via function tool", + "parameters": map[string]any{ + "type": "object", + "properties": map[string]any{ + "q": map[string]any{"type": "string"}, + }, + }, + }, + }, + } + prompt := buildClaudeToolPrompt(tools) + if !containsStr(prompt, "Tool: search") { + t.Fatalf("expected OpenAI-style function tool name in prompt, got: %q", prompt) + } + if !containsStr(prompt, "Search via function tool") { + t.Fatalf("expected OpenAI-style function tool description in prompt, got: %q", prompt) + } + if !containsStr(prompt, "\"q\"") { + t.Fatalf("expected parameters schema serialized in prompt, got: %q", prompt) + } +} + func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) { tools := []any{"not a map"} prompt := buildClaudeToolPrompt(tools) @@ -237,6 +265,21 @@ func TestExtractClaudeToolNamesNil(t *testing.T) { } } +func TestExtractClaudeToolNamesSupportsOpenAIStyleFunctionTool(t *testing.T) { + tools := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + }, + }, + } + names := extractClaudeToolNames(tools) + if len(names) != 1 || names[0] != "search" { + t.Fatalf("expected [search], got %v", names) + } +} + // ─── toMessageMaps ─────────────────────────────────────────────────── func TestToMessageMapsNormal(t *testing.T) { diff --git a/internal/adapter/claude/handler_utils.go b/internal/adapter/claude/handler_utils.go index df4c6b2..3728ffb 100644 --- a/internal/adapter/claude/handler_utils.go +++ b/internal/adapter/claude/handler_utils.go @@ -46,9 +46,8 @@ func buildClaudeToolPrompt(tools []any) string { if !ok { continue } - name, _ := m["name"].(string) - desc, _ := m["description"].(string) - schema, _ := json.Marshal(m["input_schema"]) + name, desc, schemaObj := extractClaudeToolMeta(m) + schema, _ := json.Marshal(schemaObj) parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema)) } parts = append(parts, @@ -98,13 +97,43 @@ func extractClaudeToolNames(tools []any) []string { if !ok { continue } - if name, ok := m["name"].(string); ok && name != "" { + name, _, _ := extractClaudeToolMeta(m) + if name != "" { out = append(out, name) } } return out } +func extractClaudeToolMeta(m map[string]any) (string, string, any) { + name, _ := m["name"].(string) + desc, _ := m["description"].(string) + schemaObj := m["input_schema"] + if schemaObj == nil { + schemaObj = m["parameters"] + } + + if fn, ok := m["function"].(map[string]any); ok { + if strings.TrimSpace(name) == "" { + name, _ = fn["name"].(string) + } + if strings.TrimSpace(desc) == "" { + desc, _ = fn["description"].(string) + } + if schemaObj == nil { + if v, ok := fn["input_schema"]; ok { + schemaObj = v + } + } + if schemaObj == nil { + if v, ok := fn["parameters"]; ok { + schemaObj = v + } + } + } + return strings.TrimSpace(name), strings.TrimSpace(desc), schemaObj +} + func toMessageMaps(v any) []map[string]any { arr, ok := v.([]any) if !ok { diff --git a/internal/adapter/claude/stream_runtime_finalize.go b/internal/adapter/claude/stream_runtime_finalize.go index f957ba1..12d9510 100644 --- a/internal/adapter/claude/stream_runtime_finalize.go +++ b/internal/adapter/claude/stream_runtime_finalize.go @@ -46,6 +46,9 @@ func (s *claudeStreamRuntime) finalize(stopReason string) { if s.bufferToolContent { detected := util.ParseToolCalls(finalText, s.toolNames) + if len(detected) == 0 && finalThinking != "" { + detected = util.ParseToolCalls(finalThinking, s.toolNames) + } if len(detected) > 0 { stopReason = "tool_use" for i, tc := range detected { diff --git a/internal/format/claude/render.go b/internal/format/claude/render.go index fdba055..4675398 100644 --- a/internal/format/claude/render.go +++ b/internal/format/claude/render.go @@ -9,6 +9,9 @@ import ( func BuildMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any { detected := util.ParseToolCalls(finalText, toolNames) + if len(detected) == 0 && finalThinking != "" { + detected = util.ParseToolCalls(finalThinking, toolNames) + } content := make([]map[string]any, 0, 4) if finalThinking != "" { content = append(content, map[string]any{"type": "thinking", "thinking": finalThinking}) diff --git a/internal/format/claude/render_test.go b/internal/format/claude/render_test.go new file mode 100644 index 0000000..389eaee --- /dev/null +++ b/internal/format/claude/render_test.go @@ -0,0 +1,29 @@ +package claude + +import "testing" + +func TestBuildMessageResponseDetectsToolCallsFromThinkingFallback(t *testing.T) { + resp := BuildMessageResponse( + "msg_1", + "claude-sonnet-4-5", + []any{map[string]any{"role": "user", "content": "hi"}}, + `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`, + "", + []string{"search"}, + ) + + if resp["stop_reason"] != "tool_use" { + t.Fatalf("expected stop_reason=tool_use, got=%#v", resp["stop_reason"]) + } + content, _ := resp["content"].([]map[string]any) + if len(content) < 2 { + t.Fatalf("expected thinking + tool_use content blocks, got=%#v", resp["content"]) + } + last := content[len(content)-1] + if last["type"] != "tool_use" { + t.Fatalf("expected last content block tool_use, got=%#v", last["type"]) + } + if last["name"] != "search" { + t.Fatalf("expected tool name search, got=%#v", last["name"]) + } +} diff --git a/internal/util/toolcalls_parse.go b/internal/util/toolcalls_parse.go index 5b386c2..fdace8e 100644 --- a/internal/util/toolcalls_parse.go +++ b/internal/util/toolcalls_parse.go @@ -89,8 +89,17 @@ func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []string) ([]ParsedToolCall, []string) { allowed := map[string]struct{}{} + allowedCanonical := map[string]string{} for _, name := range availableToolNames { - allowed[name] = struct{}{} + trimmed := strings.TrimSpace(name) + if trimmed == "" { + continue + } + allowed[trimmed] = struct{}{} + lower := strings.ToLower(trimmed) + if _, exists := allowedCanonical[lower]; !exists { + allowedCanonical[lower] = trimmed + } } if len(allowed) == 0 { rejectedSet := map[string]struct{}{} @@ -112,10 +121,17 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin if tc.Name == "" { continue } - if _, ok := allowed[tc.Name]; !ok { + matchedName := "" + if _, ok := allowed[tc.Name]; ok { + matchedName = tc.Name + } else if canonical, ok := allowedCanonical[strings.ToLower(tc.Name)]; ok { + matchedName = canonical + } + if matchedName == "" { rejectedSet[tc.Name] = struct{}{} continue } + tc.Name = matchedName if tc.Input == nil { tc.Input = map[string]any{} } diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index 0e823c0..1287102 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -46,6 +46,17 @@ func TestParseToolCallsRejectsUnknownToolName(t *testing.T) { } } +func TestParseToolCallsAllowsCaseInsensitiveToolNameAndCanonicalizes(t *testing.T) { + text := `{"tool_calls":[{"name":"Bash","input":{"command":"ls -al"}}]}` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + if calls[0].Name != "bash" { + t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + } +} + func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) { text := `{"tool_calls":[{"name":"unknown","input":{}}]}` res := ParseToolCallsDetailed(text, []string{"search"})