diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index cf2420e..2027729 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -759,16 +759,18 @@ func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) { foundSearch := false foundEval := false foundIndex1 := false - maxToolCallsInDelta := 0 + toolCallsDeltaLens := make([]int, 0, 2) for _, frame := range frames { choices, _ := frame["choices"].([]any) for _, item := range choices { choice, _ := item.(map[string]any) delta, _ := choice["delta"].(map[string]any) - toolCalls, _ := delta["tool_calls"].([]any) - if len(toolCalls) > maxToolCallsInDelta { - maxToolCallsInDelta = len(toolCalls) + rawToolCalls, hasToolCalls := delta["tool_calls"] + if !hasToolCalls { + continue } + toolCalls, _ := rawToolCalls.([]any) + toolCallsDeltaLens = append(toolCallsDeltaLens, len(toolCalls)) for _, tc := range toolCalls { tcm, _ := tc.(map[string]any) if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 { @@ -793,8 +795,8 @@ func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) { if !foundSearch || !foundEval { t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String()) } - if maxToolCallsInDelta != 2 { - t.Fatalf("expected one tool_calls delta containing exactly two calls, max=%d body=%s", maxToolCallsInDelta, rec.Body.String()) + if len(toolCallsDeltaLens) != 1 || toolCallsDeltaLens[0] != 2 { + t.Fatalf("expected exactly one tool_calls delta with two calls, got lens=%v body=%s", toolCallsDeltaLens, rec.Body.String()) } if !foundIndex1 { t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String()) diff --git a/internal/adapter/openai/trace_test.go b/internal/adapter/openai/trace_test.go new file mode 100644 index 0000000..cbacbf3 --- /dev/null +++ b/internal/adapter/openai/trace_test.go @@ -0,0 +1,47 @@ +package openai + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5/middleware" +) + +func traceIDViaMiddleware(req *http.Request) string { + if req == nil { + return requestTraceID(nil) + } + var got string + h := middleware.RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + got = requestTraceID(r) + })) + h.ServeHTTP(httptest.NewRecorder(), req) + return got +} + +func TestRequestTraceIDPriority(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions?__trace_id=query-trace", nil) + req.Header.Set("X-Ds2-Test-Trace", "header-trace") + got := traceIDViaMiddleware(req) + if got != "query-trace" { + t.Fatalf("expected query trace id to win, got %q", got) + } +} + +func TestRequestTraceIDHeaderFallback(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + req.Header.Set("X-Ds2-Test-Trace", "header-trace") + got := traceIDViaMiddleware(req) + if got != "header-trace" { + t.Fatalf("expected header trace id to win when query missing, got %q", got) + } +} + +func TestRequestTraceIDReqIDFallback(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + got := traceIDViaMiddleware(req) + if got == "" { + t.Fatal("expected middleware request id fallback to be non-empty") + } +}