mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 08:55:28 +08:00
feat: Implement request tracing and enhance tool call streaming stability by preventing speculative deltas and improving multi-call finalization.
This commit is contained in:
@@ -25,10 +25,11 @@ type chatStreamRuntime struct {
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
|
||||
firstChunkSent bool
|
||||
bufferToolContent bool
|
||||
emitEarlyToolDeltas bool
|
||||
toolCallsEmitted bool
|
||||
firstChunkSent bool
|
||||
bufferToolContent bool
|
||||
emitEarlyToolDeltas bool
|
||||
toolCallsEmitted bool
|
||||
toolCallsDoneEmitted bool
|
||||
|
||||
toolSieve toolStreamSieveState
|
||||
streamToolCallIDs map[int]string
|
||||
@@ -96,7 +97,7 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
detected := util.ParseToolCalls(finalText, s.toolNames)
|
||||
if len(detected) > 0 && !s.toolCallsEmitted {
|
||||
if len(detected) > 0 && !s.toolCallsDoneEmitted {
|
||||
finishReason = "tool_calls"
|
||||
delta := map[string]any{
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs),
|
||||
@@ -112,8 +113,29 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)},
|
||||
nil,
|
||||
))
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
} else if s.bufferToolContent {
|
||||
for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) {
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
tcDelta["role"] = "assistant"
|
||||
s.firstChunkSent = true
|
||||
}
|
||||
s.sendChunk(openaifmt.BuildChatStreamChunk(
|
||||
s.completionID,
|
||||
s.created,
|
||||
s.model,
|
||||
[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)},
|
||||
nil,
|
||||
))
|
||||
}
|
||||
if evt.Content == "" {
|
||||
continue
|
||||
}
|
||||
@@ -189,10 +211,14 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
if !s.emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs),
|
||||
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
continue
|
||||
}
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": formatted,
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
if !s.firstChunkSent {
|
||||
tcDelta["role"] = "assistant"
|
||||
s.firstChunkSent = true
|
||||
@@ -202,6 +228,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
|
||||
"model": "my-model",
|
||||
"messages": []any{map[string]any{"role": "user", "content": "hello"}},
|
||||
}
|
||||
out, err := normalizeOpenAIChatRequest(cfg, req)
|
||||
out, err := normalizeOpenAIChatRequest(cfg, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeOpenAIChatRequest error: %v", err)
|
||||
}
|
||||
@@ -52,7 +52,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.
|
||||
_, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
|
||||
aliases: map[string]string{},
|
||||
wideInput: false,
|
||||
}, req)
|
||||
}, req, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when wide input is disabled and only input is provided")
|
||||
}
|
||||
@@ -60,7 +60,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.
|
||||
out, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
|
||||
aliases: map[string]string{},
|
||||
wideInput: true,
|
||||
}, req)
|
||||
}, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error when wide input is enabled: %v", err)
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
|
||||
@@ -735,3 +735,71 @@ func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`,
|
||||
`data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
foundSearch := false
|
||||
foundEval := false
|
||||
foundIndex1 := false
|
||||
maxToolCallsInDelta := 0
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
if len(toolCalls) > maxToolCallsInDelta {
|
||||
maxToolCallsInDelta = len(toolCalls)
|
||||
}
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 {
|
||||
foundIndex1 = true
|
||||
}
|
||||
fn, _ := tcm["function"].(map[string]any)
|
||||
name, _ := fn["name"].(string)
|
||||
switch name {
|
||||
case "search_web":
|
||||
foundSearch = true
|
||||
case "eval_javascript":
|
||||
foundEval = true
|
||||
case "search_webeval_javascript":
|
||||
t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String())
|
||||
}
|
||||
if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundSearch || !foundEval {
|
||||
t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String())
|
||||
}
|
||||
if maxToolCallsInDelta != 2 {
|
||||
t.Fatalf("expected one tool_calls delta containing exactly two calls, max=%d body=%s", maxToolCallsInDelta, rec.Body.String())
|
||||
}
|
||||
if !foundIndex1 {
|
||||
t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,9 +4,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
|
||||
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
msg, ok := item.(map[string]any)
|
||||
@@ -17,7 +19,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
|
||||
switch role {
|
||||
case "assistant":
|
||||
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||
toolCalls := formatAssistantToolCallsForPrompt(msg)
|
||||
toolCalls := formatAssistantToolCallsForPrompt(msg, traceID)
|
||||
combined := joinNonEmpty(content, toolCalls)
|
||||
if combined == "" {
|
||||
continue
|
||||
@@ -53,7 +55,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
|
||||
return out
|
||||
}
|
||||
|
||||
func formatAssistantToolCallsForPrompt(msg map[string]any) string {
|
||||
func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) string {
|
||||
entries := make([]string, 0)
|
||||
if calls, ok := msg["tool_calls"].([]any); ok {
|
||||
for i, item := range calls {
|
||||
@@ -86,6 +88,7 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string {
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
maybeWarnSuspiciousToolHistory(traceID, id, name, args)
|
||||
entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: %s\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", id, name, args))
|
||||
}
|
||||
}
|
||||
@@ -99,6 +102,7 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string {
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
maybeWarnSuspiciousToolHistory(traceID, "call_legacy", name, args)
|
||||
entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: call_legacy\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", name, args))
|
||||
}
|
||||
|
||||
@@ -190,3 +194,45 @@ func joinNonEmpty(parts ...string) string {
|
||||
}
|
||||
return strings.Join(nonEmpty, "\n\n")
|
||||
}
|
||||
|
||||
func maybeWarnSuspiciousToolHistory(traceID, callID, name, args string) {
|
||||
if !looksLikeConcatenatedJSON(args) {
|
||||
return
|
||||
}
|
||||
traceID = strings.TrimSpace(traceID)
|
||||
if traceID == "" {
|
||||
traceID = "unknown"
|
||||
}
|
||||
config.Logger.Warn(
|
||||
"[openai] suspicious tool call history payload detected",
|
||||
"trace_id", traceID,
|
||||
"tool_call_id", strings.TrimSpace(callID),
|
||||
"name", strings.TrimSpace(name),
|
||||
"arguments_preview", previewToolArgs(args, 160),
|
||||
)
|
||||
}
|
||||
|
||||
func looksLikeConcatenatedJSON(raw string) bool {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][") {
|
||||
return true
|
||||
}
|
||||
dec := json.NewDecoder(strings.NewReader(trimmed))
|
||||
var first any
|
||||
if err := dec.Decode(&first); err != nil {
|
||||
return false
|
||||
}
|
||||
var second any
|
||||
return dec.Decode(&second) == nil
|
||||
}
|
||||
|
||||
func previewToolArgs(raw string, max int) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if max <= 0 || len(trimmed) <= max {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:max]
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw)
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 4 {
|
||||
t.Fatalf("expected 4 normalized messages, got %d", len(normalized))
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw)
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
got, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) {
|
||||
t.Fatalf("expected serialized object in tool content, got %q", got)
|
||||
@@ -89,7 +89,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw)
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
got, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(got, "line-1\nline-2") {
|
||||
t.Fatalf("expected joined text blocks, got %q", got)
|
||||
@@ -108,7 +108,7 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw)
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||
}
|
||||
@@ -120,3 +120,50 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
|
||||
t.Fatalf("unexpected normalized function-role content: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"id": "call_search",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search_web",
|
||||
"arguments": `{"query":"latest ai news"}`,
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"id": "call_eval",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "eval_javascript",
|
||||
"arguments": `{"code":"1+1"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized assistant message, got %d", len(normalized))
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if strings.Count(content, "[TOOL_CALL_HISTORY]") != 2 {
|
||||
t.Fatalf("expected two TOOL_CALL_HISTORY blocks, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "tool_call_id: call_search") || !strings.Contains(content, "function.name: search_web") {
|
||||
t.Fatalf("missing first tool call block, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "tool_call_id: call_eval") || !strings.Contains(content, "function.name: eval_javascript") {
|
||||
t.Fatalf("missing second tool call block, got %q", content)
|
||||
}
|
||||
if strings.Contains(content, "search_webeval_javascript") {
|
||||
t.Fatalf("unexpected merged function name detected: %q", content)
|
||||
}
|
||||
if strings.Contains(content, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated function arguments detected: %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) {
|
||||
messages := normalizeOpenAIMessagesForPrompt(messagesRaw)
|
||||
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
|
||||
messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID)
|
||||
toolNames := []string{}
|
||||
if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
|
||||
messages, toolNames = injectToolPrompt(messages, tools)
|
||||
|
||||
@@ -40,7 +40,7 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes
|
||||
},
|
||||
}
|
||||
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools)
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "")
|
||||
if len(toolNames) != 1 || toolNames[0] != "get_weather" {
|
||||
t.Fatalf("unexpected tool names: %#v", toolNames)
|
||||
}
|
||||
@@ -70,7 +70,7 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
|
||||
},
|
||||
}
|
||||
|
||||
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools)
|
||||
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "")
|
||||
if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") {
|
||||
t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt)
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req)
|
||||
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, requestTraceID(r))
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
|
||||
@@ -39,6 +39,7 @@ type responsesStreamRuntime struct {
|
||||
streamToolCallIDs map[int]string
|
||||
streamFunctionIDs map[int]string
|
||||
functionDone map[int]bool
|
||||
toolCallsDoneSigs map[string]bool
|
||||
reasoningItemID string
|
||||
|
||||
persistResponse func(obj map[string]any)
|
||||
@@ -73,6 +74,7 @@ func newResponsesStreamRuntime(
|
||||
streamToolCallIDs: map[int]string{},
|
||||
streamFunctionIDs: map[int]string{},
|
||||
functionDone: map[int]bool{},
|
||||
toolCallsDoneSigs: map[string]bool{},
|
||||
persistResponse: persistResponse,
|
||||
}
|
||||
}
|
||||
@@ -106,25 +108,8 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
|
||||
}
|
||||
if s.bufferToolContent {
|
||||
for _, evt := range flushToolSieve(&s.sieve, s.toolNames) {
|
||||
if evt.Content != "" {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs)))
|
||||
s.emitFunctionCallDoneEvents(evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
for _, evt := range flushToolSieve(&s.thinkingSieve, s.toolNames) {
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs)))
|
||||
s.emitFunctionCallDoneEvents(evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
|
||||
}
|
||||
// Compatibility fallback: some streams only emit incremental tool deltas.
|
||||
// Ensure final function_call_arguments.done is emitted at least once.
|
||||
@@ -141,9 +126,10 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
if !s.toolCallsDoneEmitted {
|
||||
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs)))
|
||||
s.emitToolCallsDone(detected)
|
||||
} else {
|
||||
s.emitFunctionCallDoneEvents(detected)
|
||||
}
|
||||
s.emitFunctionCallDoneEvents(detected)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,22 +172,7 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
|
||||
s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
|
||||
if s.bufferToolContent {
|
||||
for _, evt := range processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames) {
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !s.emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)))
|
||||
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs)))
|
||||
s.emitFunctionCallDoneEvents(evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -211,30 +182,56 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
|
||||
continue
|
||||
}
|
||||
for _, evt := range processToolSieveChunk(&s.sieve, p.Text, s.toolNames) {
|
||||
if evt.Content != "" {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
|
||||
}
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !s.emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)))
|
||||
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs)))
|
||||
s.emitFunctionCallDoneEvents(evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
|
||||
for _, evt := range events {
|
||||
if emitContent && evt.Content != "" {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
|
||||
}
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !s.emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
continue
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
|
||||
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.emitToolCallsDone(evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
|
||||
if len(calls) == 0 {
|
||||
return
|
||||
}
|
||||
sig := toolCallListSignature(calls)
|
||||
if sig != "" && s.toolCallsDoneSigs[sig] {
|
||||
return
|
||||
}
|
||||
if sig != "" {
|
||||
s.toolCallsDoneSigs[sig] = true
|
||||
}
|
||||
formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
return
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
|
||||
s.emitFunctionCallDoneEvents(calls)
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureReasoningItemID() string {
|
||||
if strings.TrimSpace(s.reasoningItemID) != "" {
|
||||
return s.reasoningItemID
|
||||
@@ -356,3 +353,20 @@ func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toolCallListSignature(calls []util.ParsedToolCall) string {
|
||||
if len(calls) == 0 {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for i, tc := range calls {
|
||||
if i > 0 {
|
||||
b.WriteString("|")
|
||||
}
|
||||
b.WriteString(strings.TrimSpace(tc.Name))
|
||||
b.WriteString(":")
|
||||
args, _ := json.Marshal(tc.Input)
|
||||
b.Write(args)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
@@ -246,6 +246,141 @@ func TestHandleResponsesStreamDetectsToolCallsFromThinkingChannel(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
|
||||
sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.output_tool_call.done") {
|
||||
t.Fatalf("expected response.output_tool_call.done event, body=%s", body)
|
||||
}
|
||||
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
if len(donePayloads) != 2 {
|
||||
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
|
||||
}
|
||||
|
||||
seenNames := map[string]string{}
|
||||
for _, payload := range donePayloads {
|
||||
name := strings.TrimSpace(asString(payload["name"]))
|
||||
callID := strings.TrimSpace(asString(payload["call_id"]))
|
||||
args := strings.TrimSpace(asString(payload["arguments"]))
|
||||
if callID == "" {
|
||||
t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
|
||||
}
|
||||
if strings.Contains(args, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated arguments in done payload: %#v", payload)
|
||||
}
|
||||
if name == "search_webeval_javascript" {
|
||||
t.Fatalf("unexpected merged tool name in done payload: %#v", payload)
|
||||
}
|
||||
if name != "search_web" && name != "eval_javascript" {
|
||||
t.Fatalf("unexpected tool name in done payload: %#v", payload)
|
||||
}
|
||||
seenNames[name] = callID
|
||||
}
|
||||
if seenNames["search_web"] == "" || seenNames["eval_javascript"] == "" {
|
||||
t.Fatalf("expected done events for both tools, got %#v", seenNames)
|
||||
}
|
||||
if seenNames["search_web"] == seenNames["eval_javascript"] {
|
||||
t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
|
||||
}
|
||||
|
||||
completed, ok := extractSSEEventPayload(body, "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed event, body=%s", body)
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
functionCallIDs := map[string]string{}
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil || m["type"] != "function_call" {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(asString(m["name"]))
|
||||
callID := strings.TrimSpace(asString(m["call_id"]))
|
||||
if name != "" && callID != "" {
|
||||
functionCallIDs[name] = callID
|
||||
}
|
||||
}
|
||||
if functionCallIDs["search_web"] != seenNames["search_web"] {
|
||||
t.Fatalf("search_web call_id mismatch between done and completed: done=%q completed=%q", seenNames["search_web"], functionCallIDs["search_web"])
|
||||
}
|
||||
if functionCallIDs["eval_javascript"] != seenNames["eval_javascript"] {
|
||||
t.Fatalf("eval_javascript call_id mismatch between done and completed: done=%q completed=%q", seenNames["eval_javascript"], functionCallIDs["eval_javascript"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamMultiToolCallFromThinkingChannel(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(path, v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": path,
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
|
||||
sseLine("response/thinking_content", `{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"search_web", "eval_javascript"})
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.reasoning_text.delta") {
|
||||
t.Fatalf("expected reasoning stream events, body=%s", body)
|
||||
}
|
||||
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
if len(donePayloads) != 2 {
|
||||
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
|
||||
}
|
||||
seen := map[string]bool{}
|
||||
for _, payload := range donePayloads {
|
||||
name := strings.TrimSpace(asString(payload["name"]))
|
||||
if name == "search_webeval_javascript" {
|
||||
t.Fatalf("unexpected merged tool name in thinking channel done payload: %#v", payload)
|
||||
}
|
||||
if name != "search_web" && name != "eval_javascript" {
|
||||
t.Fatalf("unexpected tool name in thinking channel done payload: %#v", payload)
|
||||
}
|
||||
args := strings.TrimSpace(asString(payload["arguments"]))
|
||||
if strings.Contains(args, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated arguments in thinking channel done payload: %#v", payload)
|
||||
}
|
||||
seen[name] = true
|
||||
}
|
||||
if !seen["search_web"] || !seen["eval_javascript"] {
|
||||
t.Fatalf("expected both tools in thinking channel done events, got %#v", seen)
|
||||
}
|
||||
}
|
||||
|
||||
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
matched := false
|
||||
@@ -271,3 +406,30 @@ func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// extractAllSSEEventPayloads collects every JSON "data:" payload that follows
// an "event:" line matching targetEvent in an SSE response body. Payloads that
// are empty, the "[DONE]" sentinel, or not valid JSON objects are skipped.
// Multiple data lines under one matching event are all collected until the
// next event line switches the match state.
func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
	out := make([]map[string]any, 0, 2)
	inTarget := false
	sc := bufio.NewScanner(strings.NewReader(body))
	for sc.Scan() {
		line := strings.TrimSpace(sc.Text())
		switch {
		case strings.HasPrefix(line, "event: "):
			// Each event line re-arms or disarms collection for the data
			// lines that follow it.
			inTarget = strings.TrimSpace(strings.TrimPrefix(line, "event: ")) == targetEvent
		case inTarget && strings.HasPrefix(line, "data: "):
			payloadJSON := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
			if payloadJSON == "" || payloadJSON == "[DONE]" {
				continue
			}
			var payload map[string]any
			// Non-JSON data lines are silently ignored, matching the
			// forgiving behavior expected of a test helper.
			if json.Unmarshal([]byte(payloadJSON), &payload) == nil {
				out = append(out, payload)
			}
		}
	}
	return out
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
|
||||
func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
|
||||
@@ -23,7 +23,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St
|
||||
if responseModel == "" {
|
||||
responseModel = resolvedModel
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
@@ -41,7 +41,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
|
||||
func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
|
||||
model, _ := req["model"].(string)
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
@@ -67,7 +67,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (ut
|
||||
if len(messagesRaw) == 0 {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestNormalizeOpenAIChatRequest(t *testing.T) {
|
||||
"temperature": 0.3,
|
||||
"stream": true,
|
||||
}
|
||||
n, err := normalizeOpenAIChatRequest(store, req)
|
||||
n, err := normalizeOpenAIChatRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
|
||||
"input": "ping",
|
||||
"instructions": "system",
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req)
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ type toolStreamSieveState struct {
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
recentTextTail string
|
||||
disableDeltas bool
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
@@ -35,6 +36,7 @@ const toolSieveCaptureLimit = 8 * 1024
|
||||
const toolSieveContextTailLimit = 256
|
||||
|
||||
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
||||
s.disableDeltas = false
|
||||
s.toolNameSent = false
|
||||
s.toolName = ""
|
||||
s.toolArgsStart = -1
|
||||
@@ -239,17 +241,8 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
}
|
||||
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
|
||||
if len(parsed) == 0 {
|
||||
if state.toolNameSent {
|
||||
return prefixPart, nil, suffixPart, true
|
||||
}
|
||||
return captured, nil, "", true
|
||||
}
|
||||
if state.toolNameSent {
|
||||
if len(parsed) > 1 {
|
||||
return prefixPart, parsed[1:], suffixPart, true
|
||||
}
|
||||
return prefixPart, nil, suffixPart, true
|
||||
}
|
||||
return prefixPart, parsed, suffixPart, true
|
||||
}
|
||||
|
||||
@@ -296,6 +289,9 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
}
|
||||
|
||||
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
|
||||
if state.disableDeltas {
|
||||
return nil
|
||||
}
|
||||
captured := state.capture.String()
|
||||
if captured == "" {
|
||||
return nil
|
||||
@@ -312,6 +308,16 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
|
||||
if insideCodeFence(state.recentTextTail + captured[:start]) {
|
||||
return nil
|
||||
}
|
||||
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
|
||||
if hasMultiple {
|
||||
state.disableDeltas = true
|
||||
return nil
|
||||
}
|
||||
if !certainSingle {
|
||||
// In uncertain phases (e.g. first call arrived but array not closed yet),
|
||||
// avoid speculative deltas and wait for final parsed tool_calls payload.
|
||||
return nil
|
||||
}
|
||||
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -363,6 +369,68 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
|
||||
return deltas
|
||||
}
|
||||
|
||||
func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
|
||||
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
i := skipSpaces(text, arrStart+1)
|
||||
if i >= len(text) || text[i] != '{' {
|
||||
return false, false
|
||||
}
|
||||
count := 0
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for ; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
if depth == 0 {
|
||||
count++
|
||||
if count > 1 {
|
||||
return false, true
|
||||
}
|
||||
}
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
if depth > 0 {
|
||||
depth--
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == ',' && depth == 0 {
|
||||
// top-level separator means at least one more tool call exists
|
||||
// (or is expected). Treat as multi-call and stop incremental deltas.
|
||||
return false, true
|
||||
}
|
||||
if ch == ']' && depth == 0 {
|
||||
return count == 1, false
|
||||
}
|
||||
}
|
||||
// array not closed yet: still uncertain whether more calls will appear
|
||||
return false, false
|
||||
}
|
||||
|
||||
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
|
||||
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
||||
if !ok {
|
||||
|
||||
21
internal/adapter/openai/trace.go
Normal file
21
internal/adapter/openai/trace.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
func requestTraceID(r *http.Request) string {
|
||||
if r == nil {
|
||||
return ""
|
||||
}
|
||||
if q := strings.TrimSpace(r.URL.Query().Get("__trace_id")); q != "" {
|
||||
return q
|
||||
}
|
||||
if h := strings.TrimSpace(r.Header.Get("X-Ds2-Test-Trace")); h != "" {
|
||||
return h
|
||||
}
|
||||
return strings.TrimSpace(middleware.GetReqID(r.Context()))
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
||||
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
|
||||
return
|
||||
}
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user