diff --git a/internal/httpapi/claude/handler_helpers_misc.go b/internal/httpapi/claude/handler_helpers_misc.go index 7b89734..6062dc6 100644 --- a/internal/httpapi/claude/handler_helpers_misc.go +++ b/internal/httpapi/claude/handler_helpers_misc.go @@ -1,6 +1,7 @@ package claude import ( + "ds2api/internal/toolcall" "fmt" "strings" ) @@ -31,30 +32,9 @@ func extractClaudeToolNames(tools []any) []string { } func extractClaudeToolMeta(m map[string]any) (string, string, any) { - name, _ := m["name"].(string) - desc, _ := m["description"].(string) - schemaObj := m["input_schema"] - if schemaObj == nil { - schemaObj = m["parameters"] - } - - if fn, ok := m["function"].(map[string]any); ok { - if strings.TrimSpace(name) == "" { - name, _ = fn["name"].(string) - } - if strings.TrimSpace(desc) == "" { - desc, _ = fn["description"].(string) - } - if schemaObj == nil { - if v, ok := fn["input_schema"]; ok { - schemaObj = v - } - } - if schemaObj == nil { - if v, ok := fn["parameters"]; ok { - schemaObj = v - } - } + name, desc, schemaObj := toolcall.ExtractToolMeta(m) + if strings.TrimSpace(desc) == "" { + desc = "No description available" } return strings.TrimSpace(name), strings.TrimSpace(desc), schemaObj } diff --git a/internal/httpapi/claude/handler_messages.go b/internal/httpapi/claude/handler_messages.go index ad8f54e..de47d28 100644 --- a/internal/httpapi/claude/handler_messages.go +++ b/internal/httpapi/claude/handler_messages.go @@ -177,7 +177,7 @@ func stripClaudeThinkingBlocks(raw []byte) []byte { return out } -func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) { +func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -205,6 +205,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ searchEnabled, h.compatStripReferenceMarkers(), toolNames, + toolsRaw, ) streamRuntime.sendMessageStart() diff --git a/internal/httpapi/claude/standard_request.go b/internal/httpapi/claude/standard_request.go index 3f3e238..3f10723 100644 --- a/internal/httpapi/claude/standard_request.go +++ b/internal/httpapi/claude/standard_request.go @@ -53,6 +53,7 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma ResolvedModel: dsModel, ResponseModel: strings.TrimSpace(model), Messages: payload["messages"].([]any), + ToolsRaw: toolsRequested, FinalPrompt: finalPrompt, ToolNames: toolNames, Stream: util.ToBool(req["stream"]), diff --git a/internal/httpapi/claude/stream_runtime_core.go b/internal/httpapi/claude/stream_runtime_core.go index beb2d40..49fde53 100644 --- a/internal/httpapi/claude/stream_runtime_core.go +++ b/internal/httpapi/claude/stream_runtime_core.go @@ -18,6 +18,7 @@ type claudeStreamRuntime struct { model string toolNames []string messages []any + toolsRaw any thinkingEnabled bool searchEnabled bool @@ -47,6 +48,7 @@ func newClaudeStreamRuntime( searchEnabled bool, stripReferenceMarkers bool, toolNames []string, + toolsRaw any, ) *claudeStreamRuntime { return &claudeStreamRuntime{ w: w, @@ -59,6 +61,7 @@ func newClaudeStreamRuntime( bufferToolContent: len(toolNames) > 0, stripReferenceMarkers: stripReferenceMarkers, toolNames: toolNames, + toolsRaw: toolsRaw, messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), thinkingBlockIndex: -1, textBlockIndex: -1, diff --git a/internal/httpapi/claude/stream_runtime_finalize.go b/internal/httpapi/claude/stream_runtime_finalize.go index 241ff7a..32e9b5f 100644 --- a/internal/httpapi/claude/stream_runtime_finalize.go +++ b/internal/httpapi/claude/stream_runtime_finalize.go @@ -52,6 +52,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string) { detected = toolcall.ParseStandaloneToolCalls(finalThinking, s.toolNames) } if len(detected) > 0 { + detected = toolcall.NormalizeParsedToolCallsForSchemas(detected, s.toolsRaw) stopReason = "tool_use" for i, tc := range detected { idx := s.nextBlockIndex + i