diff --git a/internal/adapter/claude/stream_status_test.go b/internal/adapter/claude/stream_status_test.go new file mode 100644 index 0000000..c3936de --- /dev/null +++ b/internal/adapter/claude/stream_status_test.go @@ -0,0 +1,100 @@ +package claude + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + chimw "github.com/go-chi/chi/v5/middleware" + + "ds2api/internal/auth" +) + +type streamStatusClaudeAuthStub struct{} + +func (streamStatusClaudeAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) { + return &auth.RequestAuth{ + UseConfigToken: false, + DeepSeekToken: "direct-token", + CallerID: "caller:test", + TriedAccounts: map[string]bool{}, + }, nil +} + +func (streamStatusClaudeAuthStub) Release(_ *auth.RequestAuth) {} + +type streamStatusClaudeDSStub struct{} + +func (streamStatusClaudeDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "session-id", nil +} + +func (streamStatusClaudeDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "pow", nil +} + +func (streamStatusClaudeDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) { + body := "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n" + "data: [DONE]\n" + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: ioNopCloser{strings.NewReader(body)}, + }, nil +} + +type ioNopCloser struct { + *strings.Reader +} + +func (ioNopCloser) Close() error { return nil } + +type streamStatusClaudeStoreStub struct{} + +func (streamStatusClaudeStoreStub) ClaudeMapping() map[string]string { + return map[string]string{ + "fast": "deepseek-chat", + "slow": "deepseek-reasoner", + } +} + +func captureClaudeStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor) + next.ServeHTTP(ww, r) + *statuses = append(*statuses, ww.Status()) + }) + } +} + +func TestClaudeMessagesStreamStatusCapturedAs200(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: streamStatusClaudeStoreStub{}, + Auth: streamStatusClaudeAuthStub{}, + DS: streamStatusClaudeDSStub{}, + } + r := chi.NewRouter() + r.Use(captureClaudeStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}` + req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 { + t.Fatalf("expected one captured status, got %d", len(statuses)) + } + if statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0]) + } +} diff --git a/internal/adapter/openai/stream_status_test.go b/internal/adapter/openai/stream_status_test.go new file mode 100644 index 0000000..4f8305a --- /dev/null +++ b/internal/adapter/openai/stream_status_test.go @@ -0,0 +1,185 @@ +package openai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + chimw "github.com/go-chi/chi/v5/middleware" + + "ds2api/internal/auth" +) + +type streamStatusAuthStub struct{} + +func (streamStatusAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) { + return &auth.RequestAuth{ + UseConfigToken: false, + DeepSeekToken: "direct-token", + CallerID: "caller:test", + TriedAccounts: map[string]bool{}, + }, nil +} + +func (streamStatusAuthStub) DetermineCaller(_ *http.Request) (*auth.RequestAuth, error) { + return &auth.RequestAuth{ + UseConfigToken: false, + DeepSeekToken: "direct-token", + CallerID: "caller:test", + TriedAccounts: map[string]bool{}, + }, nil +} + +func (streamStatusAuthStub) Release(_ *auth.RequestAuth) {} + +type streamStatusDSStub struct { + resp *http.Response +} + +func (m streamStatusDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "session-id", nil +} + +func (m streamStatusDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "pow", nil +} + +func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) { + return m.resp, nil +} + +func makeOpenAISSEHTTPResponse(lines ...string) *http.Response { + body := strings.Join(lines, "\n") + if !strings.HasSuffix(body, "\n") { + body += "\n" + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func captureStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor) + next.ServeHTTP(ww, r) + *statuses = append(*statuses, ww.Status()) + }) + } +} + +func TestChatCompletionsStreamStatusCapturedAs200(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")}, + } + r := chi.NewRouter() + r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}` + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 { + t.Fatalf("expected one captured status, got %d", len(statuses)) + } + if statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0]) + } +} + +func TestResponsesStreamStatusCapturedAs200(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")}, + } + r := chi.NewRouter() + r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","input":"hi","stream":true}` + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 { + t.Fatalf("expected one captured status, got %d", len(statuses)) + } + if statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0]) + } +} + +func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) { + statuses := make([]int, 0, 1) + content, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": "我来调用工具\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}", + }) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse("data: "+string(content), "data: [DONE]")}, + } + r := chi.NewRouter() + r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","input":"请调用工具","tools":[{"type":"function","function":{"name":"read_file","description":"read","parameters":{"type":"object","properties":{"path":{"type":"string"}}}}}],"stream":false}` + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 || statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200, got %#v", statuses) + } + + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String()) + } + outputText, _ := out["output_text"].(string) + if outputText != "" { + t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText) + } + output, _ := out["output"].([]any) + hasFunctionCall := false + for _, item := range output { + m, _ := item.(map[string]any) + if m != nil && m["type"] == "function_call" { + hasFunctionCall = true + break + } + } + if !hasFunctionCall { + t.Fatalf("expected function_call output item, got %#v", output) + } +}