feat: Implement request tracing and enhance tool call streaming stability by preventing speculative deltas and improving multi-call finalization.

This commit is contained in:
CJACK
2026-02-21 19:19:05 +08:00
parent e2cb07f08c
commit 13b1ec46ee
16 changed files with 549 additions and 96 deletions

View File

@@ -25,10 +25,11 @@ type chatStreamRuntime struct {
thinkingEnabled bool
searchEnabled bool
firstChunkSent bool
bufferToolContent bool
emitEarlyToolDeltas bool
toolCallsEmitted bool
firstChunkSent bool
bufferToolContent bool
emitEarlyToolDeltas bool
toolCallsEmitted bool
toolCallsDoneEmitted bool
toolSieve toolStreamSieveState
streamToolCallIDs map[int]string
@@ -96,7 +97,7 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
finalThinking := s.thinking.String()
finalText := s.text.String()
detected := util.ParseToolCalls(finalText, s.toolNames)
if len(detected) > 0 && !s.toolCallsEmitted {
if len(detected) > 0 && !s.toolCallsDoneEmitted {
finishReason = "tool_calls"
delta := map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs),
@@ -112,8 +113,29 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)},
nil,
))
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
} else if s.bufferToolContent {
for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) {
if len(evt.ToolCalls) > 0 {
finishReason = "tool_calls"
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
tcDelta := map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
}
if !s.firstChunkSent {
tcDelta["role"] = "assistant"
s.firstChunkSent = true
}
s.sendChunk(openaifmt.BuildChatStreamChunk(
s.completionID,
s.created,
s.model,
[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)},
nil,
))
}
if evt.Content == "" {
continue
}
@@ -189,10 +211,14 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
if !s.emitEarlyToolDeltas {
continue
}
s.toolCallsEmitted = true
tcDelta := map[string]any{
"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs),
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
if len(formatted) == 0 {
continue
}
tcDelta := map[string]any{
"tool_calls": formatted,
}
s.toolCallsEmitted = true
if !s.firstChunkSent {
tcDelta["role"] = "assistant"
s.firstChunkSent = true
@@ -202,6 +228,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
}
if len(evt.ToolCalls) > 0 {
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
tcDelta := map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
}

View File

@@ -31,7 +31,7 @@ func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
"model": "my-model",
"messages": []any{map[string]any{"role": "user", "content": "hello"}},
}
out, err := normalizeOpenAIChatRequest(cfg, req)
out, err := normalizeOpenAIChatRequest(cfg, req, "")
if err != nil {
t.Fatalf("normalizeOpenAIChatRequest error: %v", err)
}
@@ -52,7 +52,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.
_, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
aliases: map[string]string{},
wideInput: false,
}, req)
}, req, "")
if err == nil {
t.Fatal("expected error when wide input is disabled and only input is provided")
}
@@ -60,7 +60,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.
out, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
aliases: map[string]string{},
wideInput: true,
}, req)
}, req, "")
if err != nil {
t.Fatalf("unexpected error when wide input is enabled: %v", err)
}

View File

@@ -93,7 +93,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return

View File

@@ -735,3 +735,71 @@ func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
// TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments verifies that when
// the upstream SSE stream splits a two-element tool_calls JSON array across
// two content chunks, the chat-completions stream emits both calls as separate
// entries (distinct names, distinct indexes, non-concatenated arguments) and
// finishes with finish_reason=tool_calls.
func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) {
h := &Handler{}
// The tool_calls array is deliberately split mid-object across two SSE
// chunks to reproduce the merge bug being guarded against.
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`,
`data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
foundSearch := false
foundEval := false
foundIndex1 := false
// Tracks the largest tool_calls list seen in any single delta frame; the
// expectation is one frame carrying both calls at once.
maxToolCallsInDelta := 0
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
toolCalls, _ := delta["tool_calls"].([]any)
if len(toolCalls) > maxToolCallsInDelta {
maxToolCallsInDelta = len(toolCalls)
}
for _, tc := range toolCalls {
tcm, _ := tc.(map[string]any)
// JSON numbers decode as float64; index 1 proves the second call
// kept its own slot rather than being merged into index 0.
if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 {
foundIndex1 = true
}
fn, _ := tcm["function"].(map[string]any)
name, _ := fn["name"].(string)
switch name {
case "search_web":
foundSearch = true
case "eval_javascript":
foundEval = true
case "search_webeval_javascript":
// Sentinel for the historical bug: two names glued together.
t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String())
}
// `}{"` inside arguments would mean two JSON argument payloads
// were concatenated into one call.
if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) {
t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String())
}
}
}
}
if !foundSearch || !foundEval {
t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String())
}
if maxToolCallsInDelta != 2 {
t.Fatalf("expected one tool_calls delta containing exactly two calls, max=%d body=%s", maxToolCallsInDelta, rec.Body.String())
}
if !foundIndex1 {
t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String())
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}

View File

@@ -4,9 +4,11 @@ import (
"encoding/json"
"fmt"
"strings"
"ds2api/internal/config"
)
func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
out := make([]map[string]any, 0, len(raw))
for _, item := range raw {
msg, ok := item.(map[string]any)
@@ -17,7 +19,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
switch role {
case "assistant":
content := normalizeOpenAIContentForPrompt(msg["content"])
toolCalls := formatAssistantToolCallsForPrompt(msg)
toolCalls := formatAssistantToolCallsForPrompt(msg, traceID)
combined := joinNonEmpty(content, toolCalls)
if combined == "" {
continue
@@ -53,7 +55,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
return out
}
func formatAssistantToolCallsForPrompt(msg map[string]any) string {
func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) string {
entries := make([]string, 0)
if calls, ok := msg["tool_calls"].([]any); ok {
for i, item := range calls {
@@ -86,6 +88,7 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string {
if args == "" {
args = "{}"
}
maybeWarnSuspiciousToolHistory(traceID, id, name, args)
entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: %s\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", id, name, args))
}
}
@@ -99,6 +102,7 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string {
if args == "" {
args = "{}"
}
maybeWarnSuspiciousToolHistory(traceID, "call_legacy", name, args)
entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: call_legacy\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", name, args))
}
@@ -190,3 +194,45 @@ func joinNonEmpty(parts ...string) string {
}
return strings.Join(nonEmpty, "\n\n")
}
// maybeWarnSuspiciousToolHistory logs a structured warning when an assistant
// tool-call history entry carries arguments that look like several JSON
// payloads glued together (see looksLikeConcatenatedJSON), which usually
// points at upstream tool-call merging. It is a diagnostic aid only and
// never alters the prompt being built.
func maybeWarnSuspiciousToolHistory(traceID, callID, name, args string) {
	if !looksLikeConcatenatedJSON(args) {
		return
	}
	// Assemble the key/value pairs up front; slot 1 holds the trace id and
	// defaults to "unknown" when the caller supplied a blank one.
	fields := []any{
		"trace_id", "unknown",
		"tool_call_id", strings.TrimSpace(callID),
		"name", strings.TrimSpace(name),
		"arguments_preview", previewToolArgs(args, 160),
	}
	if trimmed := strings.TrimSpace(traceID); trimmed != "" {
		fields[1] = trimmed
	}
	config.Logger.Warn("[openai] suspicious tool call history payload detected", fields...)
}
// looksLikeConcatenatedJSON reports whether raw appears to contain more than
// one JSON value glued together (e.g. `{"a":1}{"b":2}` or `[1][2]`), which is
// the fingerprint of merged tool-call arguments.
//
// Fix: the original checked strings.Contains(raw, "}{") BEFORE parsing, so a
// perfectly valid single JSON value whose string literal happened to contain
// "}{" or "][" (e.g. {"code":"}{"}) was flagged as suspicious. We now parse
// first: exactly one complete top-level value is never concatenated; the
// textual heuristic only applies to input that does not fully parse (such as
// a truncated second value).
func looksLikeConcatenatedJSON(raw string) bool {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return false
	}
	dec := json.NewDecoder(strings.NewReader(trimmed))
	var first any
	if err := dec.Decode(&first); err == nil {
		if !dec.More() {
			// Exactly one complete JSON value: not concatenated, even if a
			// string literal inside it contains "}{" or "][".
			return false
		}
		var second any
		if dec.Decode(&second) == nil {
			// Two complete top-level values back to back.
			return true
		}
	}
	// Fallback heuristic for payloads that are not (fully) parseable,
	// e.g. a truncated second value such as `{"a":1}{"b`.
	return strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][")
}
// previewToolArgs returns a log-safe preview of tool-call arguments: the
// trimmed input, truncated to at most max bytes. max <= 0 disables
// truncation entirely (the full trimmed string is returned).
//
// Fix: the original sliced at a raw byte index, which could cut a multi-byte
// UTF-8 rune in half and emit invalid UTF-8 into the log line. The cut point
// now backs up over UTF-8 continuation bytes so the preview stays valid.
func previewToolArgs(raw string, max int) string {
	trimmed := strings.TrimSpace(raw)
	if max <= 0 || len(trimmed) <= max {
		return trimmed
	}
	cut := max
	// 0x80..0xBF are UTF-8 continuation bytes; never start a cut inside a rune.
	for cut > 0 && trimmed[cut]&0xC0 == 0x80 {
		cut--
	}
	return trimmed[:cut]
}

View File

@@ -33,7 +33,7 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw)
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
if len(normalized) != 4 {
t.Fatalf("expected 4 normalized messages, got %d", len(normalized))
}
@@ -68,7 +68,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw)
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
got, _ := normalized[0]["content"].(string)
if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) {
t.Fatalf("expected serialized object in tool content, got %q", got)
@@ -89,7 +89,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) {
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw)
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
got, _ := normalized[0]["content"].(string)
if !strings.Contains(got, "line-1\nline-2") {
t.Fatalf("expected joined text blocks, got %q", got)
@@ -108,7 +108,7 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw)
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
if len(normalized) != 1 {
t.Fatalf("expected one normalized message, got %d", len(normalized))
}
@@ -120,3 +120,50 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
t.Fatalf("unexpected normalized function-role content: %q", got)
}
}
// TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated
// verifies that an assistant message carrying two tool_calls is normalized into
// a single prompt message containing two distinct [TOOL_CALL_HISTORY] blocks,
// with per-call ids/names intact and no merged names or concatenated arguments.
func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) {
raw := []any{
map[string]any{
"role": "assistant",
"tool_calls": []any{
map[string]any{
"id": "call_search",
"type": "function",
"function": map[string]any{
"name": "search_web",
"arguments": `{"query":"latest ai news"}`,
},
},
map[string]any{
"id": "call_eval",
"type": "function",
"function": map[string]any{
"name": "eval_javascript",
"arguments": `{"code":"1+1"}`,
},
},
},
},
}
// Empty trace id: the suspicious-history warning path is not under test here.
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
if len(normalized) != 1 {
t.Fatalf("expected one normalized assistant message, got %d", len(normalized))
}
content, _ := normalized[0]["content"].(string)
if strings.Count(content, "[TOOL_CALL_HISTORY]") != 2 {
t.Fatalf("expected two TOOL_CALL_HISTORY blocks, got %q", content)
}
if !strings.Contains(content, "tool_call_id: call_search") || !strings.Contains(content, "function.name: search_web") {
t.Fatalf("missing first tool call block, got %q", content)
}
if !strings.Contains(content, "tool_call_id: call_eval") || !strings.Contains(content, "function.name: eval_javascript") {
t.Fatalf("missing second tool call block, got %q", content)
}
// Guard against the historical merge bug: glued names or `}{"` argument seams.
if strings.Contains(content, "search_webeval_javascript") {
t.Fatalf("unexpected merged function name detected: %q", content)
}
if strings.Contains(content, `}{"`) {
t.Fatalf("unexpected concatenated function arguments detected: %q", content)
}
}

View File

@@ -4,8 +4,8 @@ import (
"ds2api/internal/deepseek"
)
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) {
messages := normalizeOpenAIMessagesForPrompt(messagesRaw)
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID)
toolNames := []string{}
if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
messages, toolNames = injectToolPrompt(messages, tools)

View File

@@ -40,7 +40,7 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes
},
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools)
finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "")
if len(toolNames) != 1 || toolNames[0] != "get_weather" {
t.Fatalf("unexpected tool names: %#v", toolNames)
}
@@ -70,7 +70,7 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
},
}
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools)
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "")
if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") {
t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt)
}

View File

@@ -68,7 +68,7 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req)
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, requestTraceID(r))
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return

View File

@@ -39,6 +39,7 @@ type responsesStreamRuntime struct {
streamToolCallIDs map[int]string
streamFunctionIDs map[int]string
functionDone map[int]bool
toolCallsDoneSigs map[string]bool
reasoningItemID string
persistResponse func(obj map[string]any)
@@ -73,6 +74,7 @@ func newResponsesStreamRuntime(
streamToolCallIDs: map[int]string{},
streamFunctionIDs: map[int]string{},
functionDone: map[int]bool{},
toolCallsDoneSigs: map[string]bool{},
persistResponse: persistResponse,
}
}
@@ -106,25 +108,8 @@ func (s *responsesStreamRuntime) finalize() {
s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
}
if s.bufferToolContent {
for _, evt := range flushToolSieve(&s.sieve, s.toolNames) {
if evt.Content != "" {
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
}
if len(evt.ToolCalls) > 0 {
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs)))
s.emitFunctionCallDoneEvents(evt.ToolCalls)
}
}
for _, evt := range flushToolSieve(&s.thinkingSieve, s.toolNames) {
if len(evt.ToolCalls) > 0 {
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs)))
s.emitFunctionCallDoneEvents(evt.ToolCalls)
}
}
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
}
// Compatibility fallback: some streams only emit incremental tool deltas.
// Ensure final function_call_arguments.done is emitted at least once.
@@ -141,9 +126,10 @@ func (s *responsesStreamRuntime) finalize() {
}
if len(detected) > 0 {
if !s.toolCallsDoneEmitted {
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs)))
s.emitToolCallsDone(detected)
} else {
s.emitFunctionCallDoneEvents(detected)
}
s.emitFunctionCallDoneEvents(detected)
}
}
@@ -186,22 +172,7 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
if s.bufferToolContent {
for _, evt := range processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames) {
if len(evt.ToolCallDeltas) > 0 {
if !s.emitEarlyToolDeltas {
continue
}
s.toolCallsEmitted = true
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)))
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
}
if len(evt.ToolCalls) > 0 {
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs)))
s.emitFunctionCallDoneEvents(evt.ToolCalls)
}
}
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
}
continue
}
@@ -211,30 +182,56 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
continue
}
for _, evt := range processToolSieveChunk(&s.sieve, p.Text, s.toolNames) {
if evt.Content != "" {
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
}
if len(evt.ToolCallDeltas) > 0 {
if !s.emitEarlyToolDeltas {
continue
}
s.toolCallsEmitted = true
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)))
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
}
if len(evt.ToolCalls) > 0 {
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs)))
s.emitFunctionCallDoneEvents(evt.ToolCalls)
}
}
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
// processToolStreamEvents forwards tool-sieve events to the Responses SSE
// stream. Plain text is emitted as output_text deltas only when emitContent
// is true (content channel; the thinking channel passes false), incremental
// tool-call deltas are emitted when early deltas are enabled and formatting
// yields something, and finalized tool calls are routed through
// emitToolCallsDone (which deduplicates across channels).
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
for _, evt := range events {
if emitContent && evt.Content != "" {
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
}
if len(evt.ToolCallDeltas) > 0 {
if !s.emitEarlyToolDeltas {
// NOTE(review): this continue also skips the evt.ToolCalls branch
// below for the same event — confirm a single event never carries
// both incremental deltas and finalized calls.
continue
}
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
if len(formatted) == 0 {
// Nothing presentable yet (speculative delta suppressed); wait for
// the finalized tool_calls payload instead.
continue
}
s.toolCallsEmitted = true
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
}
if len(evt.ToolCalls) > 0 {
s.emitToolCallsDone(evt.ToolCalls)
}
}
}
// emitToolCallsDone emits the response.output_tool_call.done event plus the
// per-call function_call_arguments.done events, at most once per distinct
// tool-call list (deduplicated via toolCallListSignature, so the same calls
// seen on both the content and thinking sieves are not re-emitted). It also
// marks toolCallsEmitted/toolCallsDoneEmitted so finalize() skips fallbacks.
func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
if len(calls) == 0 {
return
}
sig := toolCallListSignature(calls)
if sig != "" && s.toolCallsDoneSigs[sig] {
// This exact list was already emitted — suppress the duplicate.
return
}
if sig != "" {
// NOTE(review): the signature is recorded before the formatted-empty
// check below; if formatting ever returns nothing for a non-empty list,
// a later retry with the same calls would be silently suppressed —
// confirm this ordering is intended.
s.toolCallsDoneSigs[sig] = true
}
formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
if len(formatted) == 0 {
return
}
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
s.emitFunctionCallDoneEvents(calls)
}
func (s *responsesStreamRuntime) ensureReasoningItemID() string {
if strings.TrimSpace(s.reasoningItemID) != "" {
return s.reasoningItemID
@@ -356,3 +353,20 @@ func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any)
}
}
}
// toolCallListSignature produces a deterministic fingerprint for a list of
// parsed tool calls, used to deduplicate done-event emission. Each call
// contributes "name:jsonArgs"; entries are joined with "|". An empty list
// yields the empty string (callers treat "" as "no signature").
func toolCallListSignature(calls []util.ParsedToolCall) string {
	if len(calls) == 0 {
		return ""
	}
	parts := make([]string, 0, len(calls))
	for _, call := range calls {
		// Marshal errors are ignored on purpose: a nil payload simply
		// contributes an empty argument segment, matching prior behavior.
		encoded, _ := json.Marshal(call.Input)
		parts = append(parts, strings.TrimSpace(call.Name)+":"+string(encoded))
	}
	return strings.Join(parts, "|")
}

View File

@@ -246,6 +246,141 @@ func TestHandleResponsesStreamDetectsToolCallsFromThinkingChannel(t *testing.T)
}
}
// TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned verifies
// that a two-element tool_calls payload split across content chunks produces
// two distinct function_call_arguments.done events (correct names, distinct
// call_ids, no concatenated arguments) and that those call_ids match the
// function_call items inside the final response.completed output.
func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
// sseLine wraps a content fragment in the upstream SSE envelope.
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
// The tool_calls array is split between the two calls on purpose.
streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
"data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
body := rec.Body.String()
if !strings.Contains(body, "event: response.output_tool_call.done") {
t.Fatalf("expected response.output_tool_call.done event, body=%s", body)
}
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
if len(donePayloads) != 2 {
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
}
// Map tool name -> call_id from the done events for later cross-checking.
seenNames := map[string]string{}
for _, payload := range donePayloads {
name := strings.TrimSpace(asString(payload["name"]))
callID := strings.TrimSpace(asString(payload["call_id"]))
args := strings.TrimSpace(asString(payload["arguments"]))
if callID == "" {
t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
}
// `}{"` in arguments would mean two JSON payloads got glued together.
if strings.Contains(args, `}{"`) {
t.Fatalf("unexpected concatenated arguments in done payload: %#v", payload)
}
if name == "search_webeval_javascript" {
t.Fatalf("unexpected merged tool name in done payload: %#v", payload)
}
if name != "search_web" && name != "eval_javascript" {
t.Fatalf("unexpected tool name in done payload: %#v", payload)
}
seenNames[name] = callID
}
if seenNames["search_web"] == "" || seenNames["eval_javascript"] == "" {
t.Fatalf("expected done events for both tools, got %#v", seenNames)
}
if seenNames["search_web"] == seenNames["eval_javascript"] {
t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
}
completed, ok := extractSSEEventPayload(body, "response.completed")
if !ok {
t.Fatalf("expected response.completed event, body=%s", body)
}
// Collect name -> call_id from the completed response's function_call items.
responseObj, _ := completed["response"].(map[string]any)
output, _ := responseObj["output"].([]any)
functionCallIDs := map[string]string{}
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil || m["type"] != "function_call" {
continue
}
name := strings.TrimSpace(asString(m["name"]))
callID := strings.TrimSpace(asString(m["call_id"]))
if name != "" && callID != "" {
functionCallIDs[name] = callID
}
}
// The streaming done events and the completed object must agree per tool.
if functionCallIDs["search_web"] != seenNames["search_web"] {
t.Fatalf("search_web call_id mismatch between done and completed: done=%q completed=%q", seenNames["search_web"], functionCallIDs["search_web"])
}
if functionCallIDs["eval_javascript"] != seenNames["eval_javascript"] {
t.Fatalf("eval_javascript call_id mismatch between done and completed: done=%q completed=%q", seenNames["eval_javascript"], functionCallIDs["eval_javascript"])
}
}
// TestHandleResponsesStreamMultiToolCallFromThinkingChannel verifies that two
// tool calls arriving via the thinking channel (response/thinking_content)
// still produce two separate function_call_arguments.done events with correct
// names and non-concatenated arguments, alongside reasoning stream events.
func TestHandleResponsesStreamMultiToolCallFromThinkingChannel(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
// sseLine wraps a fragment in the upstream SSE envelope for channel `path`.
sseLine := func(path, v string) string {
b, _ := json.Marshal(map[string]any{
"p": path,
"v": v,
})
return "data: " + string(b) + "\n"
}
// Split the tool_calls array across two thinking-channel chunks.
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
sseLine("response/thinking_content", `{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
"data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
// thinkingEnabled=true so the thinking sieve processes the chunks.
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"search_web", "eval_javascript"})
body := rec.Body.String()
if !strings.Contains(body, "event: response.reasoning_text.delta") {
t.Fatalf("expected reasoning stream events, body=%s", body)
}
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
if len(donePayloads) != 2 {
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
}
seen := map[string]bool{}
for _, payload := range donePayloads {
name := strings.TrimSpace(asString(payload["name"]))
if name == "search_webeval_javascript" {
t.Fatalf("unexpected merged tool name in thinking channel done payload: %#v", payload)
}
if name != "search_web" && name != "eval_javascript" {
t.Fatalf("unexpected tool name in thinking channel done payload: %#v", payload)
}
args := strings.TrimSpace(asString(payload["arguments"]))
// `}{"` would indicate concatenated argument payloads.
if strings.Contains(args, `}{"`) {
t.Fatalf("unexpected concatenated arguments in thinking channel done payload: %#v", payload)
}
seen[name] = true
}
if !seen["search_web"] || !seen["eval_javascript"] {
t.Fatalf("expected both tools in thinking channel done events, got %#v", seen)
}
}
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
scanner := bufio.NewScanner(strings.NewReader(body))
matched := false
@@ -271,3 +406,30 @@ func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
}
return nil, false
}
// extractAllSSEEventPayloads collects every JSON data payload that follows an
// "event: <targetEvent>" line in an SSE body. Matching stays in effect for
// all subsequent "data: " lines until a different event line appears; empty
// and "[DONE]" data lines, as well as unparseable JSON, are skipped.
func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
	const eventPrefix = "event: "
	const dataPrefix = "data: "
	payloads := make([]map[string]any, 0, 2)
	inTarget := false
	sc := bufio.NewScanner(strings.NewReader(body))
	for sc.Scan() {
		line := strings.TrimSpace(sc.Text())
		switch {
		case strings.HasPrefix(line, eventPrefix):
			// An event line resets the match state for the lines that follow.
			inTarget = strings.TrimSpace(line[len(eventPrefix):]) == targetEvent
		case inTarget && strings.HasPrefix(line, dataPrefix):
			raw := strings.TrimSpace(line[len(dataPrefix):])
			if raw == "" || raw == "[DONE]" {
				continue
			}
			var payload map[string]any
			if json.Unmarshal([]byte(raw), &payload) == nil {
				payloads = append(payloads, payload)
			}
		}
	}
	return payloads
}

View File

@@ -8,7 +8,7 @@ import (
"ds2api/internal/util"
)
func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
model, _ := req["model"].(string)
messagesRaw, _ := req["messages"].([]any)
if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
@@ -23,7 +23,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St
if responseModel == "" {
responseModel = resolvedModel
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
passThrough := collectOpenAIChatPassThrough(req)
return util.StandardRequest{
@@ -41,7 +41,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St
}, nil
}
func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
model, _ := req["model"].(string)
model = strings.TrimSpace(model)
if model == "" {
@@ -67,7 +67,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (ut
if len(messagesRaw) == 0 {
return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
passThrough := collectOpenAIChatPassThrough(req)
return util.StandardRequest{

View File

@@ -22,7 +22,7 @@ func TestNormalizeOpenAIChatRequest(t *testing.T) {
"temperature": 0.3,
"stream": true,
}
n, err := normalizeOpenAIChatRequest(store, req)
n, err := normalizeOpenAIChatRequest(store, req, "")
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
@@ -47,7 +47,7 @@ func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
"input": "ping",
"instructions": "system",
}
n, err := normalizeOpenAIResponsesRequest(store, req)
n, err := normalizeOpenAIResponsesRequest(store, req, "")
if err != nil {
t.Fatalf("normalize failed: %v", err)
}

View File

@@ -11,6 +11,7 @@ type toolStreamSieveState struct {
capture strings.Builder
capturing bool
recentTextTail string
disableDeltas bool
toolNameSent bool
toolName string
toolArgsStart int
@@ -35,6 +36,7 @@ const toolSieveCaptureLimit = 8 * 1024
const toolSieveContextTailLimit = 256
func (s *toolStreamSieveState) resetIncrementalToolState() {
s.disableDeltas = false
s.toolNameSent = false
s.toolName = ""
s.toolArgsStart = -1
@@ -239,17 +241,8 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
}
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
if len(parsed) == 0 {
if state.toolNameSent {
return prefixPart, nil, suffixPart, true
}
return captured, nil, "", true
}
if state.toolNameSent {
if len(parsed) > 1 {
return prefixPart, parsed[1:], suffixPart, true
}
return prefixPart, nil, suffixPart, true
}
return prefixPart, parsed, suffixPart, true
}
@@ -296,6 +289,9 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) {
}
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
if state.disableDeltas {
return nil
}
captured := state.capture.String()
if captured == "" {
return nil
@@ -312,6 +308,16 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
if insideCodeFence(state.recentTextTail + captured[:start]) {
return nil
}
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
if hasMultiple {
state.disableDeltas = true
return nil
}
if !certainSingle {
// In uncertain phases (e.g. first call arrived but array not closed yet),
// avoid speculative deltas and wait for final parsed tool_calls payload.
return nil
}
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
if !ok {
return nil
@@ -363,6 +369,68 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
return deltas
}
// classifyToolCallsIncrementalSafety scans the captured text starting at the
// "tool_calls" key (keyIdx) and reports whether emitting incremental deltas
// is safe. certainSingle is true only once the array has closed with exactly
// one call object; hasMultiple is true as soon as a second top-level object
// (or a top-level comma) is seen, telling the caller to disable incremental
// deltas so multiple calls are never merged. While the array is still open,
// both results are false ("uncertain — wait for the final payload").
func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
if !ok {
return false, false
}
// The first non-space byte after '[' must open an object, otherwise the
// shape is not (yet) recognizable.
i := skipSpaces(text, arrStart+1)
if i >= len(text) || text[i] != '{' {
return false, false
}
// count: top-level call objects seen; depth: brace nesting inside the array;
// quote/escaped: string-literal state so braces inside strings are ignored.
count := 0
depth := 0
quote := byte(0)
escaped := false
for ; i < len(text); i++ {
ch := text[i]
if quote != 0 {
// Inside a string literal: only track escapes and the closing quote.
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
// Single quotes are tolerated defensively even though strict JSON only
// uses double quotes.
if ch == '"' || ch == '\'' {
quote = ch
continue
}
if ch == '{' {
if depth == 0 {
count++
if count > 1 {
// Second top-level object: definitely multiple calls.
return false, true
}
}
depth++
continue
}
if ch == '}' {
if depth > 0 {
depth--
}
continue
}
if ch == ',' && depth == 0 {
// top-level separator means at least one more tool call exists
// (or is expected). Treat as multi-call and stop incremental deltas.
return false, true
}
if ch == ']' && depth == 0 {
// Array closed: single-call certainty only if exactly one object seen.
return count == 1, false
}
}
// array not closed yet: still uncertain whether more calls will appear
return false, false
}
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
if !ok {

View File

@@ -0,0 +1,21 @@
package openai
import (
"net/http"
"strings"
"github.com/go-chi/chi/v5/middleware"
)
// requestTraceID resolves a trace identifier for the request, checking in
// priority order: the "__trace_id" query parameter, the "X-Ds2-Test-Trace"
// header, and finally the chi request-id middleware value from the context.
// A nil request yields the empty string; all results are whitespace-trimmed.
func requestTraceID(r *http.Request) string {
	if r == nil {
		return ""
	}
	// Explicit overrides (query, then header) win over the middleware id.
	overrides := []string{
		r.URL.Query().Get("__trace_id"),
		r.Header.Get("X-Ds2-Test-Trace"),
	}
	for _, candidate := range overrides {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			return trimmed
		}
	}
	return strings.TrimSpace(middleware.GetReqID(r.Context()))
}

View File

@@ -56,7 +56,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
return
}
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return