diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go
index f4006ed..d9a1ba4 100644
--- a/internal/adapter/openai/chat_stream_runtime.go
+++ b/internal/adapter/openai/chat_stream_runtime.go
@@ -25,10 +25,11 @@ type chatStreamRuntime struct {
 	thinkingEnabled bool
 	searchEnabled   bool
 
-	firstChunkSent      bool
-	bufferToolContent   bool
-	emitEarlyToolDeltas bool
-	toolCallsEmitted    bool
+	firstChunkSent       bool
+	bufferToolContent    bool
+	emitEarlyToolDeltas  bool
+	toolCallsEmitted     bool
+	toolCallsDoneEmitted bool
 
 	toolSieve         toolStreamSieveState
 	streamToolCallIDs map[int]string
@@ -96,7 +97,7 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
 	finalThinking := s.thinking.String()
 	finalText := s.text.String()
 	detected := util.ParseToolCalls(finalText, s.toolNames)
-	if len(detected) > 0 && !s.toolCallsEmitted {
+	if len(detected) > 0 && !s.toolCallsDoneEmitted {
 		finishReason = "tool_calls"
 		delta := map[string]any{
 			"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs),
@@ -112,8 +113,29 @@
 			[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)},
 			nil,
 		))
+		s.toolCallsEmitted = true
+		s.toolCallsDoneEmitted = true
 	} else if s.bufferToolContent {
 		for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) {
+			if len(evt.ToolCalls) > 0 {
+				finishReason = "tool_calls"
+				s.toolCallsEmitted = true
+				s.toolCallsDoneEmitted = true
+				tcDelta := map[string]any{
+					"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
+				}
+				if !s.firstChunkSent {
+					tcDelta["role"] = "assistant"
+					s.firstChunkSent = true
+				}
+				s.sendChunk(openaifmt.BuildChatStreamChunk(
+					s.completionID,
+					s.created,
+					s.model,
+					[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)},
+					nil,
+				))
+			}
 			if evt.Content == "" {
 				continue
 			}
@@ -189,10 +211,14 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
 				if !s.emitEarlyToolDeltas {
 					continue
 				}
-				s.toolCallsEmitted = true
-				tcDelta := map[string]any{
-					"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs),
+				formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
+				if len(formatted) == 0 {
+					continue
 				}
+				tcDelta := map[string]any{
+					"tool_calls": formatted,
+				}
+				s.toolCallsEmitted = true
 				if !s.firstChunkSent {
 					tcDelta["role"] = "assistant"
 					s.firstChunkSent = true
@@ -202,6 +228,7 @@
 			}
 			if len(evt.ToolCalls) > 0 {
 				s.toolCallsEmitted = true
+				s.toolCallsDoneEmitted = true
 				tcDelta := map[string]any{
 					"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
 				}
diff --git a/internal/adapter/openai/deps_injection_test.go b/internal/adapter/openai/deps_injection_test.go
index baa0c11..6286c0c 100644
--- a/internal/adapter/openai/deps_injection_test.go
+++ b/internal/adapter/openai/deps_injection_test.go
@@ -31,7 +31,7 @@ func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
 		"model":    "my-model",
 		"messages": []any{map[string]any{"role": "user", "content": "hello"}},
 	}
-	out, err := normalizeOpenAIChatRequest(cfg, req)
+	out, err := normalizeOpenAIChatRequest(cfg, req, "")
 	if err != nil {
 		t.Fatalf("normalizeOpenAIChatRequest error: %v", err)
 	}
@@ -52,7 +52,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.
 	_, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
 		aliases:   map[string]string{},
 		wideInput: false,
-	}, req)
+	}, req, "")
 	if err == nil {
 		t.Fatal("expected error when wide input is disabled and only input is provided")
 	}
@@ -60,7 +60,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.
 	out, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
 		aliases:   map[string]string{},
 		wideInput: true,
-	}, req)
+	}, req, "")
 	if err != nil {
 		t.Fatalf("unexpected error when wide input is enabled: %v", err)
 	}
diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go
index e04f9e8..517c88a 100644
--- a/internal/adapter/openai/handler.go
+++ b/internal/adapter/openai/handler.go
@@ -93,7 +93,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
 		writeOpenAIError(w, http.StatusBadRequest, "invalid json")
 		return
 	}
-	stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
+	stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
 	if err != nil {
 		writeOpenAIError(w, http.StatusBadRequest, err.Error())
 		return
diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go
index dd2bb0f..cf2420e 100644
--- a/internal/adapter/openai/handler_toolcall_test.go
+++ b/internal/adapter/openai/handler_toolcall_test.go
@@ -735,3 +735,71 @@ func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
 		t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
 	}
 }
+
+func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) {
+	h := &Handler{}
+	resp := makeSSEHTTPResponse(
+		`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`,
+		`data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
+		`data: [DONE]`,
+	)
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
+
+	frames, done := parseSSEDataFrames(t, rec.Body.String())
+	if !done {
+		t.Fatalf("expected [DONE], body=%s", rec.Body.String())
+	}
+	if !streamHasToolCallsDelta(frames) {
+		t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
+	}
+
+	foundSearch := false
+	foundEval := false
+	foundIndex1 := false
+	maxToolCallsInDelta := 0
+	for _, frame := range frames {
+		choices, _ := frame["choices"].([]any)
+		for _, item := range choices {
+			choice, _ := item.(map[string]any)
+			delta, _ := choice["delta"].(map[string]any)
+			toolCalls, _ := delta["tool_calls"].([]any)
+			if len(toolCalls) > maxToolCallsInDelta {
+				maxToolCallsInDelta = len(toolCalls)
+			}
+			for _, tc := range toolCalls {
+				tcm, _ := tc.(map[string]any)
+				if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 {
+					foundIndex1 = true
+				}
+				fn, _ := tcm["function"].(map[string]any)
+				name, _ := fn["name"].(string)
+				switch name {
+				case "search_web":
+					foundSearch = true
+				case "eval_javascript":
+					foundEval = true
+				case "search_webeval_javascript":
+					t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String())
+				}
+				if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) {
+					t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String())
+				}
+			}
+		}
+	}
+	if !foundSearch || !foundEval {
+		t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String())
+	}
+	if maxToolCallsInDelta != 2 {
+		t.Fatalf("expected one tool_calls delta containing exactly two calls, max=%d body=%s", maxToolCallsInDelta, rec.Body.String())
+	}
+	if !foundIndex1 {
+		t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String())
+	}
+	if streamFinishReason(frames) != "tool_calls" {
+		t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
+	}
+}
diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go
index c0ab7d2..a767960 100644
--- a/internal/adapter/openai/message_normalize.go
+++ b/internal/adapter/openai/message_normalize.go
@@ -4,9 +4,11 @@ import (
 	"encoding/json"
 	"fmt"
 	"strings"
+
+	"ds2api/internal/config"
 )
 
-func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
+func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
 	out := make([]map[string]any, 0, len(raw))
 	for _, item := range raw {
 		msg, ok := item.(map[string]any)
@@ -17,7 +19,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
 		switch role {
 		case "assistant":
 			content := normalizeOpenAIContentForPrompt(msg["content"])
-			toolCalls := formatAssistantToolCallsForPrompt(msg)
+			toolCalls := formatAssistantToolCallsForPrompt(msg, traceID)
 			combined := joinNonEmpty(content, toolCalls)
 			if combined == "" {
 				continue
@@ -53,7 +55,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
 	return out
 }
 
-func formatAssistantToolCallsForPrompt(msg map[string]any) string {
+func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) string {
 	entries := make([]string, 0)
 	if calls, ok := msg["tool_calls"].([]any); ok {
 		for i, item := range calls {
@@ -86,6 +88,7 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string {
 			if args == "" {
 				args = "{}"
 			}
+			maybeWarnSuspiciousToolHistory(traceID, id, name, args)
 			entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: %s\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", id, name, args))
 		}
 	}
@@ -99,6 +102,7 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string {
 		if args == "" {
 			args = "{}"
 		}
+		maybeWarnSuspiciousToolHistory(traceID, "call_legacy", name, args)
 		entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: call_legacy\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", name, args))
 	}
@@ -190,3 +194,45 @@ func joinNonEmpty(parts ...string) string {
 	}
 	return strings.Join(nonEmpty, "\n\n")
 }
+
+func maybeWarnSuspiciousToolHistory(traceID, callID, name, args string) {
+	if !looksLikeConcatenatedJSON(args) {
+		return
+	}
+	traceID = strings.TrimSpace(traceID)
+	if traceID == "" {
+		traceID = "unknown"
+	}
+	config.Logger.Warn(
+		"[openai] suspicious tool call history payload detected",
+		"trace_id", traceID,
+		"tool_call_id", strings.TrimSpace(callID),
+		"name", strings.TrimSpace(name),
+		"arguments_preview", previewToolArgs(args, 160),
+	)
+}
+
+func looksLikeConcatenatedJSON(raw string) bool {
+	trimmed := strings.TrimSpace(raw)
+	if trimmed == "" {
+		return false
+	}
+	if strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][") {
+		return true
+	}
+	dec := json.NewDecoder(strings.NewReader(trimmed))
+	var first any
+	if err := dec.Decode(&first); err != nil {
+		return false
+	}
+	var second any
+	return dec.Decode(&second) == nil
+}
+
+func previewToolArgs(raw string, max int) string {
+	trimmed := strings.TrimSpace(raw)
+	if max <= 0 || len(trimmed) <= max {
+		return trimmed
+	}
+	return trimmed[:max]
+}
diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go
index 27849d7..30403bc 100644
--- a/internal/adapter/openai/message_normalize_test.go
+++ b/internal/adapter/openai/message_normalize_test.go
@@ -33,7 +33,7 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes
 		},
 	}
 
-	normalized := normalizeOpenAIMessagesForPrompt(raw)
+	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
 	if len(normalized) != 4 {
 		t.Fatalf("expected 4 normalized messages, got %d", len(normalized))
 	}
@@ -68,7 +68,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.
 		},
 	}
 
-	normalized := normalizeOpenAIMessagesForPrompt(raw)
+	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
 	got, _ := normalized[0]["content"].(string)
 	if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) {
 		t.Fatalf("expected serialized object in tool content, got %q", got)
@@ -89,7 +89,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) {
 		},
 	}
 
-	normalized := normalizeOpenAIMessagesForPrompt(raw)
+	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
 	got, _ := normalized[0]["content"].(string)
 	if !strings.Contains(got, "line-1\nline-2") {
 		t.Fatalf("expected joined text blocks, got %q", got)
@@ -108,7 +108,7 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
 		},
 	}
 
-	normalized := normalizeOpenAIMessagesForPrompt(raw)
+	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
 	if len(normalized) != 1 {
 		t.Fatalf("expected one normalized message, got %d", len(normalized))
 	}
@@ -120,3 +120,50 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
 		t.Fatalf("unexpected normalized function-role content: %q", got)
 	}
 }
+
+func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) {
+	raw := []any{
+		map[string]any{
+			"role": "assistant",
+			"tool_calls": []any{
+				map[string]any{
+					"id":   "call_search",
+					"type": "function",
+					"function": map[string]any{
+						"name":      "search_web",
+						"arguments": `{"query":"latest ai news"}`,
+					},
+				},
+				map[string]any{
+					"id":   "call_eval",
+					"type": "function",
+					"function": map[string]any{
+						"name":      "eval_javascript",
+						"arguments": `{"code":"1+1"}`,
+					},
+				},
+			},
+		},
+	}
+
+	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
+	if len(normalized) != 1 {
+		t.Fatalf("expected one normalized assistant message, got %d", len(normalized))
+	}
+	content, _ := normalized[0]["content"].(string)
+	if strings.Count(content, "[TOOL_CALL_HISTORY]") != 2 {
+		t.Fatalf("expected two TOOL_CALL_HISTORY blocks, got %q", content)
+	}
+	if !strings.Contains(content, "tool_call_id: call_search") || !strings.Contains(content, "function.name: search_web") {
+		t.Fatalf("missing first tool call block, got %q", content)
+	}
+	if !strings.Contains(content, "tool_call_id: call_eval") || !strings.Contains(content, "function.name: eval_javascript") {
+		t.Fatalf("missing second tool call block, got %q", content)
+	}
+	if strings.Contains(content, "search_webeval_javascript") {
+		t.Fatalf("unexpected merged function name detected: %q", content)
+	}
+	if strings.Contains(content, `}{"`) {
+		t.Fatalf("unexpected concatenated function arguments detected: %q", content)
+	}
+}
diff --git a/internal/adapter/openai/prompt_build.go b/internal/adapter/openai/prompt_build.go
index f83963f..890e3dc 100644
--- a/internal/adapter/openai/prompt_build.go
+++ b/internal/adapter/openai/prompt_build.go
@@ -4,8 +4,8 @@ import (
 	"ds2api/internal/deepseek"
 )
 
-func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) {
-	messages := normalizeOpenAIMessagesForPrompt(messagesRaw)
+func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
+	messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID)
 	toolNames := []string{}
 	if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
 		messages, toolNames = injectToolPrompt(messages, tools)
diff --git a/internal/adapter/openai/prompt_build_test.go b/internal/adapter/openai/prompt_build_test.go
index 878af73..bd6223e 100644
--- a/internal/adapter/openai/prompt_build_test.go
+++ b/internal/adapter/openai/prompt_build_test.go
@@ -40,7 +40,7 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes
 		},
 	}
 
-	finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools)
+	finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "")
 	if len(toolNames) != 1 || toolNames[0] != "get_weather" {
 		t.Fatalf("unexpected tool names: %#v", toolNames)
 	}
@@ -70,7 +70,7 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
 		},
 	}
 
-	finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools)
+	finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "")
 	if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") {
 		t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt)
 	}
diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go
index 522521d..bd9ff3a 100644
--- a/internal/adapter/openai/responses_handler.go
+++ b/internal/adapter/openai/responses_handler.go
@@ -68,7 +68,7 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
 		writeOpenAIError(w, http.StatusBadRequest, "invalid json")
 		return
 	}
-	stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req)
+	stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, requestTraceID(r))
 	if err != nil {
 		writeOpenAIError(w, http.StatusBadRequest, err.Error())
 		return
diff --git a/internal/adapter/openai/responses_stream_runtime.go b/internal/adapter/openai/responses_stream_runtime.go
index 11c64ce..d059ca1 100644
--- a/internal/adapter/openai/responses_stream_runtime.go
+++ b/internal/adapter/openai/responses_stream_runtime.go
@@ -39,6 +39,7 @@ type responsesStreamRuntime struct {
 	streamToolCallIDs map[int]string
 	streamFunctionIDs map[int]string
 	functionDone      map[int]bool
+	toolCallsDoneSigs map[string]bool
 	reasoningItemID   string
 
 	persistResponse func(obj map[string]any)
@@ -73,6 +74,7 @@ func newResponsesStreamRuntime(
 		streamToolCallIDs: map[int]string{},
 		streamFunctionIDs: map[int]string{},
 		functionDone:      map[int]bool{},
+		toolCallsDoneSigs: map[string]bool{},
 		persistResponse:   persistResponse,
 	}
 }
@@ -106,25 +108,8 @@ func (s *responsesStreamRuntime) finalize() {
 		s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
 	}
 	if s.bufferToolContent {
-		for _, evt := range flushToolSieve(&s.sieve, s.toolNames) {
-			if evt.Content != "" {
-				s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
-			}
-			if len(evt.ToolCalls) > 0 {
-				s.toolCallsEmitted = true
-				s.toolCallsDoneEmitted = true
-				s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs)))
-				s.emitFunctionCallDoneEvents(evt.ToolCalls)
-			}
-		}
-		for _, evt := range flushToolSieve(&s.thinkingSieve, s.toolNames) {
-			if len(evt.ToolCalls) > 0 {
-				s.toolCallsEmitted = true
-				s.toolCallsDoneEmitted = true
-				s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs)))
-				s.emitFunctionCallDoneEvents(evt.ToolCalls)
-			}
-		}
+		s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
+		s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
 	}
 	// Compatibility fallback: some streams only emit incremental tool deltas.
 	// Ensure final function_call_arguments.done is emitted at least once.
@@ -141,9 +126,10 @@
 	}
 	if len(detected) > 0 {
 		if !s.toolCallsDoneEmitted {
-			s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs)))
+			s.emitToolCallsDone(detected)
+		} else {
+			s.emitFunctionCallDoneEvents(detected)
 		}
-		s.emitFunctionCallDoneEvents(detected)
 	}
 }
 
@@ -186,22 +172,7 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
 			s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
 			s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
 			if s.bufferToolContent {
-				for _, evt := range processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames) {
-					if len(evt.ToolCallDeltas) > 0 {
-						if !s.emitEarlyToolDeltas {
-							continue
-						}
-						s.toolCallsEmitted = true
-						s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)))
-						s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
-					}
-					if len(evt.ToolCalls) > 0 {
-						s.toolCallsEmitted = true
-						s.toolCallsDoneEmitted = true
-						s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs)))
-						s.emitFunctionCallDoneEvents(evt.ToolCalls)
-					}
-				}
+				s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
 			}
 			continue
 		}
@@ -211,30 +182,56 @@
 			s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
 			continue
 		}
-		for _, evt := range processToolSieveChunk(&s.sieve, p.Text, s.toolNames) {
-			if evt.Content != "" {
-				s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
-			}
-			if len(evt.ToolCallDeltas) > 0 {
-				if !s.emitEarlyToolDeltas {
-					continue
-				}
-				s.toolCallsEmitted = true
-				s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)))
-				s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
-			}
-			if len(evt.ToolCalls) > 0 {
-				s.toolCallsEmitted = true
-				s.toolCallsDoneEmitted = true
-				s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs)))
-				s.emitFunctionCallDoneEvents(evt.ToolCalls)
-			}
-		}
+		s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
 	}
 	return streamengine.ParsedDecision{ContentSeen: contentSeen}
 }
+
+func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
+	for _, evt := range events {
+		if emitContent && evt.Content != "" {
+			s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
+		}
+		if len(evt.ToolCallDeltas) > 0 {
+			if !s.emitEarlyToolDeltas {
+				continue
+			}
+			formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
+			if len(formatted) == 0 {
+				continue
+			}
+			s.toolCallsEmitted = true
+			s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
+			s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
+		}
+		if len(evt.ToolCalls) > 0 {
+			s.emitToolCallsDone(evt.ToolCalls)
+		}
+	}
+}
+
+func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
+	if len(calls) == 0 {
+		return
+	}
+	sig := toolCallListSignature(calls)
+	if sig != "" && s.toolCallsDoneSigs[sig] {
+		return
+	}
+	if sig != "" {
+		s.toolCallsDoneSigs[sig] = true
+	}
+	formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
+	if len(formatted) == 0 {
+		return
+	}
+	s.toolCallsEmitted = true
+	s.toolCallsDoneEmitted = true
+	s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
+	s.emitFunctionCallDoneEvents(calls)
+}
+
 func (s *responsesStreamRuntime) ensureReasoningItemID() string {
 	if strings.TrimSpace(s.reasoningItemID) != "" {
 		return s.reasoningItemID
@@ -356,3 +353,20 @@ func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any)
 		}
 	}
 }
+
+func toolCallListSignature(calls []util.ParsedToolCall) string {
+	if len(calls) == 0 {
+		return ""
+	}
+	var b strings.Builder
+	for i, tc := range calls {
+		if i > 0 {
+			b.WriteString("|")
+		}
+		b.WriteString(strings.TrimSpace(tc.Name))
+		b.WriteString(":")
+		args, _ := json.Marshal(tc.Input)
+		b.Write(args)
+	}
+	return b.String()
+}
diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go
index f938c44..a47903c 100644
--- a/internal/adapter/openai/responses_stream_test.go
+++ b/internal/adapter/openai/responses_stream_test.go
@@ -246,6 +246,141 @@ func TestHandleResponsesStreamDetectsToolCallsFromThinkingChannel(t *testing.T)
 	}
 }
 
+func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) {
+	h := &Handler{}
+	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+	rec := httptest.NewRecorder()
+
+	sseLine := func(v string) string {
+		b, _ := json.Marshal(map[string]any{
+			"p": "response/content",
+			"v": v,
+		})
+		return "data: " + string(b) + "\n"
+	}
+
+	streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
+		sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
+		"data: [DONE]\n"
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Body:       io.NopCloser(strings.NewReader(streamBody)),
+	}
+
+	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
+
+	body := rec.Body.String()
+	if !strings.Contains(body, "event: response.output_tool_call.done") {
+		t.Fatalf("expected response.output_tool_call.done event, body=%s", body)
+	}
+	donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
+	if len(donePayloads) != 2 {
+		t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
+	}
+
+	seenNames := map[string]string{}
+	for _, payload := range donePayloads {
+		name := strings.TrimSpace(asString(payload["name"]))
+		callID := strings.TrimSpace(asString(payload["call_id"]))
+		args := strings.TrimSpace(asString(payload["arguments"]))
+		if callID == "" {
+			t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
+		}
+		if strings.Contains(args, `}{"`) {
+			t.Fatalf("unexpected concatenated arguments in done payload: %#v", payload)
+		}
+		if name == "search_webeval_javascript" {
+			t.Fatalf("unexpected merged tool name in done payload: %#v", payload)
+		}
+		if name != "search_web" && name != "eval_javascript" {
+			t.Fatalf("unexpected tool name in done payload: %#v", payload)
+		}
+		seenNames[name] = callID
+	}
+	if seenNames["search_web"] == "" || seenNames["eval_javascript"] == "" {
+		t.Fatalf("expected done events for both tools, got %#v", seenNames)
+	}
+	if seenNames["search_web"] == seenNames["eval_javascript"] {
+		t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
+	}
+
+	completed, ok := extractSSEEventPayload(body, "response.completed")
+	if !ok {
+		t.Fatalf("expected response.completed event, body=%s", body)
+	}
+	responseObj, _ := completed["response"].(map[string]any)
+	output, _ := responseObj["output"].([]any)
+	functionCallIDs := map[string]string{}
+	for _, item := range output {
+		m, _ := item.(map[string]any)
+		if m == nil || m["type"] != "function_call" {
+			continue
+		}
+		name := strings.TrimSpace(asString(m["name"]))
+		callID := strings.TrimSpace(asString(m["call_id"]))
+		if name != "" && callID != "" {
+			functionCallIDs[name] = callID
+		}
+	}
+	if functionCallIDs["search_web"] != seenNames["search_web"] {
+		t.Fatalf("search_web call_id mismatch between done and completed: done=%q completed=%q", seenNames["search_web"], functionCallIDs["search_web"])
+	}
+	if functionCallIDs["eval_javascript"] != seenNames["eval_javascript"] {
+		t.Fatalf("eval_javascript call_id mismatch between done and completed: done=%q completed=%q", seenNames["eval_javascript"], functionCallIDs["eval_javascript"])
+	}
+}
+
+func TestHandleResponsesStreamMultiToolCallFromThinkingChannel(t *testing.T) {
+	h := &Handler{}
+	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+	rec := httptest.NewRecorder()
+
+	sseLine := func(path, v string) string {
+		b, _ := json.Marshal(map[string]any{
+			"p": path,
+			"v": v,
+		})
+		return "data: " + string(b) + "\n"
+	}
+
+	streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
+		sseLine("response/thinking_content", `{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
+		"data: [DONE]\n"
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Body:       io.NopCloser(strings.NewReader(streamBody)),
+	}
+
+	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"search_web", "eval_javascript"})
+
+	body := rec.Body.String()
+	if !strings.Contains(body, "event: response.reasoning_text.delta") {
+		t.Fatalf("expected reasoning stream events, body=%s", body)
+	}
+	donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
+	if len(donePayloads) != 2 {
+		t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
+	}
+	seen := map[string]bool{}
+	for _, payload := range donePayloads {
+		name := strings.TrimSpace(asString(payload["name"]))
+		if name == "search_webeval_javascript" {
+			t.Fatalf("unexpected merged tool name in thinking channel done payload: %#v", payload)
+		}
+		if name != "search_web" && name != "eval_javascript" {
+			t.Fatalf("unexpected tool name in thinking channel done payload: %#v", payload)
+		}
+		args := strings.TrimSpace(asString(payload["arguments"]))
+		if strings.Contains(args, `}{"`) {
+			t.Fatalf("unexpected concatenated arguments in thinking channel done payload: %#v", payload)
+		}
+		seen[name] = true
+	}
+	if !seen["search_web"] || !seen["eval_javascript"] {
+		t.Fatalf("expected both tools in thinking channel done events, got %#v", seen)
+	}
+}
+
 func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
 	scanner := bufio.NewScanner(strings.NewReader(body))
 	matched := false
@@ -271,3 +406,30 @@ func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
 	}
 	return nil, false
 }
+
+func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
+	scanner := bufio.NewScanner(strings.NewReader(body))
+	matched := false
+	out := make([]map[string]any, 0, 2)
+	for scanner.Scan() {
+		line := strings.TrimSpace(scanner.Text())
+		if strings.HasPrefix(line, "event: ") {
+			evt := strings.TrimSpace(strings.TrimPrefix(line, "event: "))
+			matched = evt == targetEvent
+			continue
+		}
+		if !matched || !strings.HasPrefix(line, "data: ") {
+			continue
+		}
+		raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
+		if raw == "" || raw == "[DONE]" {
+			continue
+		}
+		var payload map[string]any
+		if err := json.Unmarshal([]byte(raw), &payload); err != nil {
+			continue
+		}
+		out = append(out, payload)
+	}
+	return out
+}
diff --git a/internal/adapter/openai/standard_request.go b/internal/adapter/openai/standard_request.go
index 5883d03..7683ee7 100644
--- a/internal/adapter/openai/standard_request.go
+++ b/internal/adapter/openai/standard_request.go
@@ -8,7 +8,7 @@ import (
 	"ds2api/internal/util"
 )
 
-func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
+func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
 	model, _ := req["model"].(string)
 	messagesRaw, _ := req["messages"].([]any)
 	if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
@@ -23,7 +23,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St
 	if responseModel == "" {
 		responseModel = resolvedModel
 	}
-	finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
+	finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
 	passThrough := collectOpenAIChatPassThrough(req)
 
 	return util.StandardRequest{
@@ -41,7 +41,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St
 	}, nil
 }
 
-func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
+func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
 	model, _ := req["model"].(string)
 	model = strings.TrimSpace(model)
 	if model == "" {
@@ -67,7 +67,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (ut
 	if len(messagesRaw) == 0 {
 		return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
 	}
-	finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
+	finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
 	passThrough := collectOpenAIChatPassThrough(req)
 
 	return util.StandardRequest{
diff --git a/internal/adapter/openai/standard_request_test.go b/internal/adapter/openai/standard_request_test.go
index f3453a2..a876364 100644
--- a/internal/adapter/openai/standard_request_test.go
+++ b/internal/adapter/openai/standard_request_test.go
@@ -22,7 +22,7 @@ func TestNormalizeOpenAIChatRequest(t *testing.T) {
 		"temperature": 0.3,
 		"stream":      true,
 	}
-	n, err := normalizeOpenAIChatRequest(store, req)
+	n, err := normalizeOpenAIChatRequest(store, req, "")
 	if err != nil {
 		t.Fatalf("normalize failed: %v", err)
 	}
@@ -47,7 +47,7 @@ func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
 		"input":        "ping",
 		"instructions": "system",
 	}
-	n, err := normalizeOpenAIResponsesRequest(store, req)
+	n, err := normalizeOpenAIResponsesRequest(store, req, "")
 	if err != nil {
 		t.Fatalf("normalize failed: %v", err)
 	}
diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve.go
index fd7222b..9c46649 100644
--- a/internal/adapter/openai/tool_sieve.go
+++ b/internal/adapter/openai/tool_sieve.go
@@ -11,6 +11,7 @@ type toolStreamSieveState struct {
 	capture        strings.Builder
 	capturing      bool
 	recentTextTail string
+	disableDeltas  bool
 	toolNameSent   bool
 	toolName       string
 	toolArgsStart  int
@@ -35,6 +36,7 @@ const toolSieveCaptureLimit = 8 * 1024
 const toolSieveContextTailLimit = 256
 
 func (s *toolStreamSieveState) resetIncrementalToolState() {
+	s.disableDeltas = false
 	s.toolNameSent = false
 	s.toolName = ""
 	s.toolArgsStart = -1
@@ -239,17 +241,8 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
 	}
 	parsed := util.ParseStandaloneToolCalls(obj, toolNames)
 	if len(parsed) == 0 {
-		if state.toolNameSent {
-			return prefixPart, nil, suffixPart, true
-		}
 		return captured, nil, "", true
 	}
-	if state.toolNameSent {
-		if len(parsed) > 1 {
-			return prefixPart, parsed[1:], suffixPart, true
-		}
-		return prefixPart, nil, suffixPart, true
-	}
 	return prefixPart, parsed, suffixPart, true
 }
 
@@ -296,6 +289,9 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) {
 }
 
 func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
+	if state.disableDeltas {
+		return nil
+	}
 	captured := state.capture.String()
 	if captured == "" {
 		return nil
@@ -312,6 +308,16 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
 	if insideCodeFence(state.recentTextTail + captured[:start]) {
 		return nil
 	}
+	certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
+	if hasMultiple {
+		state.disableDeltas = true
+		return nil
+	}
+	if !certainSingle {
+		// In uncertain phases (e.g. first call arrived but array not closed yet),
+		// avoid speculative deltas and wait for final parsed tool_calls payload.
+		return nil
+	}
 	callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
 	if !ok {
 		return nil
@@ -363,6 +369,68 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
 	return deltas
 }
 
+func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
+	arrStart, ok := findToolCallsArrayStart(text, keyIdx)
+	if !ok {
+		return false, false
+	}
+	i := skipSpaces(text, arrStart+1)
+	if i >= len(text) || text[i] != '{' {
+		return false, false
+	}
+	count := 0
+	depth := 0
+	quote := byte(0)
+	escaped := false
+	for ; i < len(text); i++ {
+		ch := text[i]
+		if quote != 0 {
+			if escaped {
+				escaped = false
+				continue
+			}
+			if ch == '\\' {
+				escaped = true
+				continue
+			}
+			if ch == quote {
+				quote = 0
+			}
+			continue
+		}
+		if ch == '"' || ch == '\'' {
+			quote = ch
+			continue
+		}
+		if ch == '{' {
+			if depth == 0 {
+				count++
+				if count > 1 {
+					return false, true
+				}
+			}
+			depth++
+			continue
+		}
+		if ch == '}' {
+			if depth > 0 {
+				depth--
+			}
+			continue
+		}
+		if ch == ',' && depth == 0 {
+			// top-level separator means at least one more tool call exists
+			// (or is expected). Treat as multi-call and stop incremental deltas.
+			return false, true
+		}
+		if ch == ']' && depth == 0 {
+			return count == 1, false
+		}
+	}
+	// array not closed yet: still uncertain whether more calls will appear
+	return false, false
+}
+
 func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
 	arrStart, ok := findToolCallsArrayStart(text, keyIdx)
 	if !ok {
diff --git a/internal/adapter/openai/trace.go b/internal/adapter/openai/trace.go
new file mode 100644
index 0000000..8ea58f0
--- /dev/null
+++ b/internal/adapter/openai/trace.go
@@ -0,0 +1,21 @@
+package openai
+
+import (
+	"net/http"
+	"strings"
+
+	"github.com/go-chi/chi/v5/middleware"
+)
+
+func requestTraceID(r *http.Request) string {
+	if r == nil {
+		return ""
+	}
+	if q := strings.TrimSpace(r.URL.Query().Get("__trace_id")); q != "" {
+		return q
+	}
+	if h := strings.TrimSpace(r.Header.Get("X-Ds2-Test-Trace")); h != "" {
+		return h
+	}
+	return strings.TrimSpace(middleware.GetReqID(r.Context()))
+}
diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go
index 65006c4..f34ea8b 100644
--- a/internal/adapter/openai/vercel_stream.go
+++ b/internal/adapter/openai/vercel_stream.go
@@ -56,7 +56,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
 		writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
 		return
 	}
-	stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
+	stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
 	if err != nil {
 		writeOpenAIError(w, http.StatusBadRequest, err.Error())
 		return