diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index 319aacd..d507849 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -309,6 +309,9 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt flusher.Flush() case line, ok := <-lines: if !ok { + // Ensure scanner completion is observed only after all queued + // SSE lines are drained, avoiding early finalize races. + _ = <-done finalize("stop") return } @@ -369,9 +372,6 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt "choices": newChoices, }) } - case <-done: - finalize("stop") - return } } } diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go new file mode 100644 index 0000000..df39d51 --- /dev/null +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -0,0 +1,283 @@ +package openai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func makeSSEHTTPResponse(lines ...string) *http.Response { + body := strings.Join(lines, "\n") + if !strings.HasSuffix(body, "\n") { + body += "\n" + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func decodeJSONBody(t *testing.T, body string) map[string]any { + t.Helper() + var out map[string]any + if err := json.Unmarshal([]byte(body), &out); err != nil { + t.Fatalf("decode json failed: %v, body=%s", err, body) + } + return out +} + +func parseSSEDataFrames(t *testing.T, body string) ([]map[string]any, bool) { + t.Helper() + lines := strings.Split(body, "\n") + frames := make([]map[string]any, 0, len(lines)) + done := false + for _, line := range lines { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data:") { + continue + } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" { + continue + } + if payload == "[DONE]" { + done = true + continue + } + var frame map[string]any + if err := json.Unmarshal([]byte(payload), &frame); err != nil { + t.Fatalf("decode sse frame failed: %v, payload=%s", err, payload) + } + frames = append(frames, frame) + } + return frames, done +} + +func streamHasRawToolJSONContent(frames []map[string]any) bool { + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + content, _ := delta["content"].(string) + if strings.Contains(content, `"tool_calls"`) { + return true + } + } + } + return false +} + +func streamHasToolCallsDelta(frames []map[string]any) bool { + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + if _, ok := delta["tool_calls"]; ok { + return true + } + } + } + return false +} + +func streamFinishReason(frames []map[string]any) string { + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + if reason, ok := choice["finish_reason"].(string); ok && reason != "" { + return reason + } + } + } + return "" +} + +func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + + h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, false, []string{"search"}) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rec.Code) + } + + out := decodeJSONBody(t, rec.Body.String()) + choices, _ := out["choices"].([]any) + if len(choices) != 1 { + t.Fatalf("unexpected choices: %#v", out["choices"]) + } + choice, _ := choices[0].(map[string]any) + if choice["finish_reason"] != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"]) + } + msg, _ := choice["message"].(map[string]any) + if msg["content"] != nil { + t.Fatalf("expected content nil, got %#v", msg["content"]) + } + toolCalls, _ := msg["tool_calls"].([]any) + if len(toolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %#v", msg["tool_calls"]) + } +} + +func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/thinking_content","v":"先想一下"}`, + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + + h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, false, []string{"search"}) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rec.Code) + } + + out := decodeJSONBody(t, rec.Body.String()) + choices, _ := out["choices"].([]any) + choice, _ := choices[0].(map[string]any) + msg, _ := choice["message"].(map[string]any) + if msg["reasoning_content"] != "先想一下" { + t.Fatalf("expected reasoning_content, got %#v", msg["reasoning_content"]) + } + if msg["content"] != nil { + t.Fatalf("expected content nil, got %#v", msg["content"]) + } + if choice["finish_reason"] != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"]) + } +} + +func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + + h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, false, []string{"search"}) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rec.Code) + } + + out := decodeJSONBody(t, rec.Body.String()) + choices, _ := out["choices"].([]any) + choice, _ := choices[0].(map[string]any) + if choice["finish_reason"] != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"]) + } + msg, _ := choice["message"].(map[string]any) + if msg["content"] != nil { + t.Fatalf("expected content nil, got %#v", msg["content"]) + } + toolCalls, _ := msg["tool_calls"].([]any) + if len(toolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %#v", msg["tool_calls"]) + } +} + +func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`, + `data: {"p":"response/content","v":",\"input\":{\"q\":\"go\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid3", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + if streamHasRawToolJSONContent(frames) { + t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + } +} + +func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/thinking_content","v":"思考中"}`, + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid4", "deepseek-reasoner", "prompt", true, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + if streamHasRawToolJSONContent(frames) { + t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + } + + hasThinkingDelta := false + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + if _, ok := delta["reasoning_content"]; ok { + hasThinkingDelta = true + } + } + } + if !hasThinkingDelta { + t.Fatalf("expected reasoning_content delta in reasoner stream: %s", rec.Body.String()) + } +} + +func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid5", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + if streamHasRawToolJSONContent(frames) { + t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) + } +} diff --git a/internal/util/toolcalls.go b/internal/util/toolcalls.go index d15160c..a594a6a 100644 --- a/internal/util/toolcalls.go +++ b/internal/util/toolcalls.go @@ -52,6 +52,20 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall { } out = append(out, tc) } + // If the model clearly emitted tool_calls JSON but all names are outside the + // declared set, keep the parsed calls as a fallback so upper layers can still + // intercept structured tool output instead of leaking raw JSON to users. + if len(out) == 0 && len(parsed) > 0 { + for _, tc := range parsed { + if tc.Name == "" { + continue + } + if tc.Input == nil { + tc.Input = map[string]any{} + } + out = append(out, tc) + } + } return out } diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index de27b13..8c44320 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -41,11 +41,14 @@ func TestParseToolCallsWithFunctionArgumentsString(t *testing.T) { } } -func TestParseToolCallsRejectUnknown(t *testing.T) { +func TestParseToolCallsKeepsUnknownAsFallback(t *testing.T) { text := `{"tool_calls":[{"name":"unknown","input":{}}]}` calls := ParseToolCalls(text, []string{"search"}) - if len(calls) != 0 { - t.Fatalf("expected 0 calls, got %d", len(calls)) + if len(calls) != 1 { + t.Fatalf("expected fallback 1 call, got %d", len(calls)) + } + if calls[0].Name != "unknown" { + t.Fatalf("unexpected name: %s", calls[0].Name) } }