diff --git a/API.en.md b/API.en.md index 122d9c9..8836bc5 100644 --- a/API.en.md +++ b/API.en.md @@ -260,24 +260,39 @@ If tool use is detected, `stop_reason` becomes `tool_use` and `content` contains ### Claude Streaming (`stream=true`) -Still SSE, but current implementation writes `data:` lines only (no `event:` lines). Event type is carried in JSON `type`. +SSE uses paired `event:` + `data:` lines. Event type is also carried in JSON `type`. Example: ```text +event: message_start data: {"type":"message_start","message":{...}} +event: content_block_start data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} +event: content_block_delta data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hello"}} +event: ping +data: {"type":"ping"} + +event: content_block_stop data: {"type":"content_block_stop","index":0} +event: message_delta data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":12}} +event: message_stop data: {"type":"message_stop"} ``` +Notes: + +- Thinking-enabled models stream `thinking_delta`. +- `signature_delta` is not emitted because DeepSeek does not provide verifiable thinking signatures. +- In `tools` mode, the stream prioritizes avoiding raw tool JSON leakage and does not force `input_json_delta` partials. + ### `POST /anthropic/v1/messages/count_tokens` Request example: diff --git a/API.md b/API.md index e66ba30..243838f 100644 --- a/API.md +++ b/API.md @@ -264,24 +264,39 @@ anthropic-version: 2023-06-01 ### Claude 流式(`stream=true`) -返回同样是 SSE,但当前实现仅写入 `data:` 行,不输出 `event:` 行。每条 JSON 内包含 `type` 字段。 +返回 SSE,包含 `event:` + `data:` 双行;JSON 中仍保留 `type` 字段。 示例: ```text +event: message_start data: {"type":"message_start","message":{...}} +event: content_block_start data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} +event: content_block_delta data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hello"}} +event: ping +data: {"type":"ping"} + +event: content_block_stop data: {"type":"content_block_stop","index":0} +event: message_delta data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":12}} +event: message_stop data: {"type":"message_stop"} ``` +说明: + +- 开启思维模型时会输出 `thinking_delta`。 +- 当前不会输出 `signature_delta`(上游 DeepSeek 未提供可验证签名)。 +- `tools` 场景优先避免泄露原始工具 JSON,不强制发送 `input_json_delta`。 + ### `POST /anthropic/v1/messages/count_tokens` 请求示例: diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index 8abd829..19d32ae 100644 --- a/internal/adapter/claude/handler.go +++ b/internal/adapter/claude/handler.go @@ -24,6 +24,12 @@ type Handler struct { DS *deepseek.Client } +var ( + claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second + claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second + claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount +) + func RegisterRoutes(r chi.Router, h *Handler) { r.Get("/anthropic/v1/models", h.ListModels) r.Post("/anthropic/v1/messages", h.Messages) @@ -74,7 +80,6 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { thinkingEnabled = false searchEnabled = false } - _ = searchEnabled finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"])) sessionID, err := h.DS.CreateSession(r.Context(), a, 3) @@ -107,13 +112,13 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { return } - fullText, fullThinking := collectDeepSeek(resp, thinkingEnabled) toolNames := extractClaudeToolNames(toolsRequested) - detected := util.ParseToolCalls(fullText, toolNames) if toBool(req["stream"]) { - h.writeClaudeStream(w, r, model, normalized, fullText, detected) + h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames) return } + fullText, fullThinking := collectDeepSeek(resp, thinkingEnabled) + detected := util.ParseToolCalls(fullText, toolNames) content := make([]map[string]any, 0, 4) if fullThinking != "" { content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking}) @@ -228,7 +233,14 @@ func collectDeepSeek(resp *http.Response, thinkingEnabled bool) (string, string) return text.String(), thinking.String() } -func (h *Handler) writeClaudeStream(w http.ResponseWriter, r *http.Request, model string, messages []any, fullText string, detected []util.ParsedToolCall) { +func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}}) + return + } + w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Connection", "keep-alive") @@ -238,8 +250,25 @@ func (h *Handler) writeClaudeStream(w http.ResponseWriter, r *http.Request, mode if !canFlush { config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered") } - send := func(v any) { + lines := make(chan []byte, 128) + done := make(chan error, 1) + go func() { + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 2*1024*1024) + for scanner.Scan() { + b := append([]byte{}, scanner.Bytes()...) + lines <- b + } + close(lines) + done <- scanner.Err() + }() + + send := func(event string, v any) { b, _ := json.Marshal(v) + _, _ = w.Write([]byte("event: ")) + _, _ = w.Write([]byte(event)) + _, _ = w.Write([]byte("\n")) _, _ = w.Write([]byte("data: ")) _, _ = w.Write(b) _, _ = w.Write([]byte("\n\n")) @@ -247,9 +276,23 @@ func (h *Handler) writeClaudeStream(w http.ResponseWriter, r *http.Request, mode _ = rc.Flush() } } + sendError := func(message string) { + msg := strings.TrimSpace(message) + if msg == "" { + msg = "upstream stream error" + } + send("error", map[string]any{ + "type": "error", + "error": map[string]any{ + "type": "api_error", + "message": msg, + }, + }) + } + messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano()) inputTokens := util.EstimateTokens(fmt.Sprintf("%v", messages)) - send(map[string]any{ + send("message_start", map[string]any{ "type": "message_start", "message": map[string]any{ "id": messageID, @@ -262,26 +305,247 @@ func (h *Handler) writeClaudeStream(w http.ResponseWriter, r *http.Request, mode "usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0}, }, }) - outputTokens := 0 - stopReason := "end_turn" - if len(detected) > 0 { - stopReason = "tool_use" - for i, tc := range detected { - send(map[string]any{"type": "content_block_start", "index": i, "content_block": map[string]any{"type": "tool_use", "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), "name": tc.Name, "input": tc.Input}}) - send(map[string]any{"type": "content_block_stop", "index": i}) - outputTokens += util.EstimateTokens(fmt.Sprintf("%v", tc.Input)) + + currentType := "text" + if thinkingEnabled { + currentType = "thinking" + } + bufferToolContent := len(toolNames) > 0 + hasContent := false + lastContent := time.Now() + keepaliveCount := 0 + + thinking := strings.Builder{} + text := strings.Builder{} + + nextBlockIndex := 0 + thinkingBlockOpen := false + thinkingBlockIndex := -1 + textBlockOpen := false + textBlockIndex := -1 + ended := false + + closeThinkingBlock := func() { + if !thinkingBlockOpen { + return } - } else { - if fullText != "" { - send(map[string]any{"type": "content_block_start", "index": 0, "content_block": map[string]any{"type": "text", "text": ""}}) - send(map[string]any{"type": "content_block_delta", "index": 0, "delta": map[string]any{"type": "text_delta", "text": fullText}}) - send(map[string]any{"type": "content_block_stop", "index": 0}) - outputTokens = util.EstimateTokens(fullText) + send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": thinkingBlockIndex, + }) + thinkingBlockOpen = false + thinkingBlockIndex = -1 + } + closeTextBlock := func() { + if !textBlockOpen { + return + } + send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": textBlockIndex, + }) + textBlockOpen = false + textBlockIndex = -1 + } + + finalize := func(stopReason string) { + if ended { + return + } + ended = true + + closeThinkingBlock() + closeTextBlock() + + finalThinking := thinking.String() + finalText := text.String() + + if bufferToolContent { + detected := util.ParseToolCalls(finalText, toolNames) + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + idx := nextBlockIndex + i + send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]any{ + "type": "tool_use", + "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx), + "name": tc.Name, + "input": tc.Input, + }, + }) + send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": idx, + }) + } + nextBlockIndex += len(detected) + } else if finalText != "" { + idx := nextBlockIndex + nextBlockIndex++ + send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": idx, + "delta": map[string]any{ + "type": "text_delta", + "text": finalText, + }, + }) + send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": idx, + }) + } + } + + outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText) + send("message_delta", map[string]any{ + "type": "message_delta", + "delta": map[string]any{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": map[string]any{ + "output_tokens": outputTokens, + }, + }) + send("message_stop", map[string]any{"type": "message_stop"}) + } + + pingTicker := time.NewTicker(claudeStreamPingInterval) + defer pingTicker.Stop() + + for { + select { + case <-r.Context().Done(): + return + case <-pingTicker.C: + if !hasContent { + keepaliveCount++ + if keepaliveCount >= claudeStreamMaxKeepaliveCnt { + finalize("end_turn") + return + } + } + if hasContent && time.Since(lastContent) > claudeStreamIdleTimeout { + finalize("end_turn") + return + } + send("ping", map[string]any{"type": "ping"}) + case line, ok := <-lines: + if !ok { + if err := <-done; err != nil { + sendError(err.Error()) + return + } + finalize("end_turn") + return + } + + chunk, doneSignal, parsed := sse.ParseDeepSeekSSELine(line) + if !parsed { + continue + } + if doneSignal { + finalize("end_turn") + return + } + if errObj, hasErr := chunk["error"]; hasErr { + sendError(fmt.Sprintf("%v", errObj)) + return + } + if code, _ := chunk["code"].(string); code == "content_filter" { + sendError("content filtered by upstream") + return + } + parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType) + currentType = newType + if finished { + finalize("end_turn") + return + } + + for _, p := range parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) { + continue + } + + hasContent = true + lastContent = time.Now() + keepaliveCount = 0 + + if p.Type == "thinking" { + if !thinkingEnabled { + continue + } + thinking.WriteString(p.Text) + closeTextBlock() + if !thinkingBlockOpen { + thinkingBlockIndex = nextBlockIndex + nextBlockIndex++ + send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": thinkingBlockIndex, + "content_block": map[string]any{ + "type": "thinking", + "thinking": "", + }, + }) + thinkingBlockOpen = true + } + send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": thinkingBlockIndex, + "delta": map[string]any{ + "type": "thinking_delta", + "thinking": p.Text, + }, + }) + continue + } + + text.WriteString(p.Text) + if bufferToolContent { + continue + } + closeThinkingBlock() + if !textBlockOpen { + textBlockIndex = nextBlockIndex + nextBlockIndex++ + send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": textBlockIndex, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + textBlockOpen = true + } + send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": textBlockIndex, + "delta": map[string]any{ + "type": "text_delta", + "text": p.Text, + }, + }) + } } } - send(map[string]any{"type": "message_delta", "delta": map[string]any{"stop_reason": stopReason, "stop_sequence": nil}, "usage": map[string]any{"output_tokens": outputTokens}}) - send(map[string]any{"type": "message_stop"}) - _ = r } func normalizeClaudeMessages(messages []any) []any { diff --git a/internal/adapter/claude/handler_stream_test.go b/internal/adapter/claude/handler_stream_test.go new file mode 100644 index 0000000..56f7ea9 --- /dev/null +++ b/internal/adapter/claude/handler_stream_test.go @@ -0,0 +1,256 @@ +package claude + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +type claudeFrame struct { + Event string + Payload map[string]any +} + +func makeClaudeSSEHTTPResponse(lines ...string) *http.Response { + body := strings.Join(lines, "\n") + if !strings.HasSuffix(body, "\n") { + body += "\n" + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func parseClaudeFrames(t *testing.T, body string) []claudeFrame { + t.Helper() + chunks := strings.Split(body, "\n\n") + frames := make([]claudeFrame, 0, len(chunks)) + for _, chunk := range chunks { + chunk = strings.TrimSpace(chunk) + if chunk == "" { + continue + } + lines := strings.Split(chunk, "\n") + eventName := "" + dataPayload := "" + for _, line := range lines { + line = strings.TrimSpace(line) + switch { + case strings.HasPrefix(line, "event:"): + eventName = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + case strings.HasPrefix(line, "data:"): + dataPayload = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + if eventName == "" || dataPayload == "" { + continue + } + var payload map[string]any + if err := json.Unmarshal([]byte(dataPayload), &payload); err != nil { + t.Fatalf("decode frame failed: %v, payload=%s", err, dataPayload) + } + frames = append(frames, claudeFrame{Event: eventName, Payload: payload}) + } + return frames +} + +func findClaudeFrames(frames []claudeFrame, event string) []claudeFrame { + out := make([]claudeFrame, 0) + for _, f := range frames { + if f.Event == event { + out = append(out, f) + } + } + return out +} + +func TestHandleClaudeStreamRealtimeTextIncrementsWithEventHeaders(t *testing.T) { + h := &Handler{} + resp := makeClaudeSSEHTTPResponse( + `data: {"p":"response/content","v":"Hel"}`, + `data: {"p":"response/content","v":"lo"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil) + + h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-20250514", []any{map[string]any{"role": "user", "content": "hi"}}, false, false, nil) + + body := rec.Body.String() + if !strings.Contains(body, "event: message_start") { + t.Fatalf("missing event header: message_start, body=%s", body) + } + if !strings.Contains(body, "event: content_block_delta") { + t.Fatalf("missing event header: content_block_delta, body=%s", body) + } + if !strings.Contains(body, "event: message_stop") { + t.Fatalf("missing event header: message_stop, body=%s", body) + } + + frames := parseClaudeFrames(t, body) + deltas := findClaudeFrames(frames, "content_block_delta") + if len(deltas) < 2 { + t.Fatalf("expected at least 2 text deltas, got=%d body=%s", len(deltas), body) + } + combined := strings.Builder{} + for _, f := range deltas { + delta, _ := f.Payload["delta"].(map[string]any) + if delta["type"] == "text_delta" { + combined.WriteString(asString(delta["text"])) + } + } + if combined.String() != "Hello" { + t.Fatalf("unexpected combined text: %q body=%s", combined.String(), body) + } +} + +func TestHandleClaudeStreamRealtimeThinkingDelta(t *testing.T) { + h := &Handler{} + resp := makeClaudeSSEHTTPResponse( + `data: {"p":"response/thinking_content","v":"思"}`, + `data: {"p":"response/thinking_content","v":"考"}`, + `data: {"p":"response/content","v":"ok"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil) + + h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-20250514", []any{map[string]any{"role": "user", "content": "hi"}}, true, false, nil) + + frames := parseClaudeFrames(t, rec.Body.String()) + foundThinkingDelta := false + for _, f := range findClaudeFrames(frames, "content_block_delta") { + delta, _ := f.Payload["delta"].(map[string]any) + if delta["type"] == "thinking_delta" { + foundThinkingDelta = true + break + } + } + if !foundThinkingDelta { + t.Fatalf("expected thinking_delta event, body=%s", rec.Body.String()) + } +} + +func TestHandleClaudeStreamRealtimeToolSafety(t *testing.T) { + h := &Handler{} + resp := makeClaudeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`, + `data: {"p":"response/content","v":",\"input\":{\"q\":\"go\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil) + + h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-20250514", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"search"}) + + frames := parseClaudeFrames(t, rec.Body.String()) + for _, f := range findClaudeFrames(frames, "content_block_delta") { + delta, _ := f.Payload["delta"].(map[string]any) + if delta["type"] == "text_delta" && strings.Contains(asString(delta["text"]), `"tool_calls"`) { + t.Fatalf("raw tool_calls JSON leaked in text delta: body=%s", rec.Body.String()) + } + } + + foundToolUse := false + for _, f := range findClaudeFrames(frames, "content_block_start") { + contentBlock, _ := f.Payload["content_block"].(map[string]any) + if contentBlock["type"] == "tool_use" { + foundToolUse = true + break + } + } + if !foundToolUse { + t.Fatalf("expected tool_use block in stream, body=%s", rec.Body.String()) + } + + foundToolUseStop := false + for _, f := range findClaudeFrames(frames, "message_delta") { + delta, _ := f.Payload["delta"].(map[string]any) + if delta["stop_reason"] == "tool_use" { + foundToolUseStop = true + break + } + } + if !foundToolUseStop { + t.Fatalf("expected stop_reason=tool_use, body=%s", rec.Body.String()) + } +} + +func TestHandleClaudeStreamRealtimeUpstreamErrorEvent(t *testing.T) { + h := &Handler{} + resp := makeClaudeSSEHTTPResponse( + `data: {"error":{"message":"boom"}}`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil) + + h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-20250514", []any{map[string]any{"role": "user", "content": "hi"}}, false, false, nil) + + frames := parseClaudeFrames(t, rec.Body.String()) + errFrames := findClaudeFrames(frames, "error") + if len(errFrames) == 0 { + t.Fatalf("expected error event frame, body=%s", rec.Body.String()) + } + if errFrames[0].Payload["type"] != "error" { + t.Fatalf("expected error payload type, body=%s", rec.Body.String()) + } +} + +func TestHandleClaudeStreamRealtimePingEvent(t *testing.T) { + h := &Handler{} + oldPing := claudeStreamPingInterval + oldIdle := claudeStreamIdleTimeout + oldKeepalive := claudeStreamMaxKeepaliveCnt + claudeStreamPingInterval = 10 * time.Millisecond + claudeStreamIdleTimeout = 300 * time.Millisecond + claudeStreamMaxKeepaliveCnt = 50 + defer func() { + claudeStreamPingInterval = oldPing + claudeStreamIdleTimeout = oldIdle + claudeStreamMaxKeepaliveCnt = oldKeepalive + }() + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: make(http.Header), Body: pr} + go func() { + time.Sleep(40 * time.Millisecond) + _, _ = io.WriteString(pw, "data: {\"p\":\"response/content\",\"v\":\"hi\"}\n") + _, _ = io.WriteString(pw, "data: [DONE]\n") + _ = pw.Close() + }() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil) + h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-20250514", []any{map[string]any{"role": "user", "content": "hi"}}, false, false, nil) + + frames := parseClaudeFrames(t, rec.Body.String()) + if len(findClaudeFrames(frames, "ping")) == 0 { + t.Fatalf("expected ping event in stream, body=%s", rec.Body.String()) + } +} + +func TestCollectDeepSeekRegression(t *testing.T) { + resp := makeClaudeSSEHTTPResponse( + `data: {"p":"response/thinking_content","v":"想"}`, + `data: {"p":"response/content","v":"答"}`, + `data: [DONE]`, + ) + text, thinking := collectDeepSeek(resp, true) + if thinking != "想" { + t.Fatalf("unexpected thinking: %q", thinking) + } + if text != "答" { + t.Fatalf("unexpected text: %q", text) + } +} + +func asString(v any) string { + s, _ := v.(string) + return s +} diff --git a/internal/testsuite/edge_cases.go b/internal/testsuite/edge_cases.go new file mode 100644 index 0000000..784f8ab --- /dev/null +++ b/internal/testsuite/edge_cases.go @@ -0,0 +1,494 @@ +package testsuite + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "time" +) + +func (r *Runner) caseConcurrencyThresholdLimit(ctx context.Context, cc *caseContext) error { + status, err := r.fetchQueueStatus(ctx, cc) + if err != nil { + return err + } + total := toInt(status["total"]) + maxInflight := toInt(status["max_inflight_per_account"]) + maxQueue := toInt(status["max_queue_size"]) + if total <= 0 || maxInflight <= 0 { + cc.assert("queue_capacity_known", false, fmt.Sprintf("queue_status=%v", status)) + return nil + } + capacity := total*maxInflight + maxQueue + if capacity <= 0 { + capacity = total * maxInflight + } + n := capacity + 8 + if n < 8 { + n = 8 + } + type one struct { + Status int + Err string + } + res := make([]one, n) + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": fmt.Sprintf("并发边界测试 #%d,请输出不少于300字。", idx)}, + }, + "stream": true, + }, + Stream: true, + Retryable: true, + }) + if err != nil { + res[idx] = one{Err: err.Error()} + return + } + res[idx] = one{Status: resp.StatusCode} + }(i) + } + wg.Wait() + + dist := map[int]int{} + for _, it := range res { + if it.Status > 0 { + dist[it.Status]++ + } + } + cc.assert("has_200", dist[http.StatusOK] > 0, fmt.Sprintf("distribution=%v", dist)) + cc.assert("has_429_when_over_capacity", dist[http.StatusTooManyRequests] > 0, fmt.Sprintf("distribution=%v capacity=%d n=%d", dist, capacity, n)) + _, has5xx := has5xx(dist) + cc.assert("no_5xx", !has5xx, fmt.Sprintf("distribution=%v", dist)) + return nil +} + +func (r *Runner) caseStreamAbortRelease(ctx context.Context, cc *caseContext) error { + before, err := r.fetchQueueStatus(ctx, cc) + if err != nil { + return err + } + baseInUse := toInt(before["in_use"]) + for i := 0; i < 3; i++ { + if err := cc.abortStreamRequest(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": fmt.Sprintf("中断释放测试 #%d,请流式回复", i)}, + }, + "stream": true, + }, + Stream: true, + }); err != nil { + cc.assert("abort_request_no_error", false, err.Error()) + } + } + + deadline := time.Now().Add(25 * time.Second) + recovered := false + lastInUse := -1 + for time.Now().Before(deadline) { + st, err := r.fetchQueueStatus(ctx, cc) + if err != nil { + time.Sleep(500 * time.Millisecond) + continue + } + lastInUse = toInt(st["in_use"]) + if lastInUse <= baseInUse { + recovered = true + break + } + time.Sleep(time.Second) + } + cc.assert("in_use_recovered_after_abort", recovered, fmt.Sprintf("base=%d last=%d", baseInUse, lastInUse)) + return nil +} + +func (cc *caseContext) abortStreamRequest(ctx context.Context, spec requestSpec) error { + cc.seq++ + traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), cc.seq) + cc.traceIDsSet[traceID] = struct{}{} + fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID) + if err != nil { + return err + } + headers := map[string]string{} + for k, v := range spec.Headers { + headers[k] = v + } + headers["X-Ds2-Test-Trace"] = traceID + bodyBytes, _ := json.Marshal(spec.Body) + headers["Content-Type"] = "application/json" + cc.requests = append(cc.requests, requestLog{ + Seq: cc.seq, + Attempt: 1, + TraceID: traceID, + Method: spec.Method, + URL: fullURL, + Headers: headers, + Body: spec.Body, + Timestamp: time.Now().Format(time.RFC3339Nano), + }) + + reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout) + defer cancel() + req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes)) + if err != nil { + return err + } + for k, v := range headers { + req.Header.Set(k, v) + } + start := time.Now() + resp, err := cc.runner.httpClient.Do(req) + if err != nil { + cc.responses = append(cc.responses, responseLog{ + Seq: cc.seq, + Attempt: 1, + TraceID: traceID, + StatusCode: 0, + DurationMS: time.Since(start).Milliseconds(), + NetworkErr: err.Error(), + ReceivedAt: time.Now().Format(time.RFC3339Nano), + }) + return err + } + defer resp.Body.Close() + buf := make([]byte, 512) + _, _ = resp.Body.Read(buf) + _ = resp.Body.Close() + cc.responses = append(cc.responses, responseLog{ + Seq: cc.seq, + Attempt: 1, + TraceID: traceID, + StatusCode: resp.StatusCode, + Headers: resp.Header, + BodyText: "aborted_after_first_chunk", + DurationMS: time.Since(start).Milliseconds(), + ReceivedAt: time.Now().Format(time.RFC3339Nano), + }) + return nil +} + +func (r *Runner) caseToolcallStreamMixed(ctx context.Context, cc *caseContext) error { + payload := toolcallPayload(true) + payload["messages"] = []map[string]any{ + { + "role": "user", + "content": "请先输出一句普通文本,再调用工具 search 查询 golang,最后再输出一句普通文本。", + }, + } + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: payload, + Stream: true, + Retryable: false, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + frames, done := parseSSEFrames(resp.Body) + hasTool := false + hasText := false + rawLeak := false + for _, f := range frames { + choices, _ := f["choices"].([]any) + for _, c := range choices { + ch, _ := c.(map[string]any) + delta, _ := ch["delta"].(map[string]any) + if _, ok := delta["tool_calls"]; ok { + hasTool = true + } + content := asString(delta["content"]) + if content != "" { + hasText = true + } + if strings.Contains(strings.ToLower(content), `"tool_calls"`) { + rawLeak = true + } + } + } + cc.assert("tool_calls_delta_present", hasTool, "tool_calls delta missing") + cc.assert("no_raw_tool_json_leak", !rawLeak, "raw tool_calls leaked") + cc.assert("done_terminated", done, "expected [DONE]") + if !(hasTool && hasText) { + r.warnings = append(r.warnings, "toolcall mixed stream did not produce both text and tool_calls in this run (model-side behavior dependent)") + } + return nil +} + +func (r *Runner) caseSSEJSONIntegrity(ctx context.Context, cc *caseContext) error { + openaiResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": "输出一句话"}, + }, + "stream": true, + }, + Stream: true, + Retryable: false, + }) + if err != nil { + return err + } + cc.assert("openai_status_200", openaiResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", openaiResp.StatusCode)) + badOpenAI := countMalformedSSEJSONLines(openaiResp.Body) + cc.assert("openai_sse_json_valid", badOpenAI == 0, fmt.Sprintf("malformed=%d", badOpenAI)) + + anthropicResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/anthropic/v1/messages", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + "anthropic-version": "2023-06-01", + }, + Body: map[string]any{ + "model": "claude-sonnet-4-20250514", + "messages": []map[string]any{ + {"role": "user", "content": "stream json integrity"}, + }, + "stream": true, + }, + Stream: true, + Retryable: false, + }) + if err != nil { + return err + } + cc.assert("anthropic_status_200", anthropicResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", anthropicResp.StatusCode)) + badAnthropic := countMalformedSSEJSONLines(anthropicResp.Body) + cc.assert("anthropic_sse_json_valid", badAnthropic == 0, fmt.Sprintf("malformed=%d", badAnthropic)) + return nil +} + +func (r *Runner) caseInvalidModel(ctx context.Context, cc *caseContext) error { + resp, err := cc.requestOnce(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-not-exists", + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + "stream": false, + }, + Retryable: false, + }, 1) + if err != nil { + return err + } + cc.assert("status_503", resp.StatusCode == http.StatusServiceUnavailable, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + e, _ := m["error"].(map[string]any) + cc.assert("error_type_service_unavailable", asString(e["type"]) == "service_unavailable_error", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseMissingMessages(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "stream": false, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_400", resp.StatusCode == http.StatusBadRequest, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + e, _ := m["error"].(map[string]any) + cc.assert("error_type_invalid_request", asString(e["type"]) == "invalid_request_error", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseAdminUnauthorized(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/config", + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode)) + return nil +} + +func (r *Runner) caseTokenRefreshManagedAccount(ctx context.Context, cc *caseContext) error { + if len(r.configRaw.Accounts) == 0 { + cc.assert("account_present", false, "no account in config") + return nil + } + acc := r.configRaw.Accounts[0] + id := strings.TrimSpace(acc.Email) + if id == "" { + id = strings.TrimSpace(acc.Mobile) + } + if id == "" { + cc.assert("account_identifier", false, "first account has no identifier") + return nil + } + if strings.TrimSpace(acc.Password) == "" { + r.warnings = append(r.warnings, "token refresh edge case skipped strict check: first account password empty") + cc.assert("account_password_present", true, "skipped strict refresh check due empty password") + return nil + } + invalidToken := "invalid-testsuite-refresh-token-" + sanitizeID(r.runID) + update := map[string]any{ + "keys": r.configRaw.Keys, + "accounts": []map[string]any{ + { + "email": acc.Email, + "mobile": acc.Mobile, + "password": acc.Password, + "token": invalidToken, + }, + }, + } + updResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/admin/config", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Body: update, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("update_config_status_200", updResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", updResp.StatusCode)) + + chatResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + "X-Ds2-Target-Account": id, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": "token refresh test"}, + }, + "stream": false, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("chat_status_200", chatResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d body=%s", chatResp.StatusCode, string(chatResp.Body))) + + cfgResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/config", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Retryable: true, + }) + if err != nil { + return err + } + var cfg map[string]any + _ = json.Unmarshal(cfgResp.Body, &cfg) + accounts, _ := cfg["accounts"].([]any) + preview := "" + hasToken := false + for _, item := range accounts { + m, _ := item.(map[string]any) + e := asString(m["email"]) + mo := asString(m["mobile"]) + if e == acc.Email && mo == acc.Mobile { + preview = asString(m["token_preview"]) + hasToken, _ = m["has_token"].(bool) + break + } + } + cc.assert("has_token_after_refresh", hasToken, fmt.Sprintf("config=%s", string(cfgResp.Body))) + cc.assert("token_preview_changed_from_invalid", !strings.HasPrefix(preview, invalidToken[:20]), fmt.Sprintf("preview=%s invalid_prefix=%s", preview, invalidToken[:20])) + return nil +} + +func (r *Runner) fetchQueueStatus(ctx context.Context, cc *caseContext) (map[string]any, error) { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/queue/status", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Retryable: true, + }) + if err != nil { + return nil, err + } + var m map[string]any + if err := json.Unmarshal(resp.Body, &m); err != nil { + return nil, err + } + return m, nil +} + +func countMalformedSSEJSONLines(body []byte) int { + lines := strings.Split(string(body), "\n") + bad := 0 + for _, line := range lines { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data:") { + continue + } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" || payload == "[DONE]" { + continue + } + var v any + if err := json.Unmarshal([]byte(payload), &v); err != nil { + bad++ + } + } + return bad +} diff --git a/internal/testsuite/runner.go b/internal/testsuite/runner.go index 3a0f297..20fbdcc 100644 --- a/internal/testsuite/runner.go +++ b/internal/testsuite/runner.go @@ -89,6 +89,7 @@ type caseContext struct { id string dir string startedAt time.Time + mu sync.Mutex seq int assertions []assertionResult requests []requestLog @@ -144,8 +145,10 @@ type Runner struct { type runConfig struct { Keys []string `json:"keys"` Accounts []struct { - Email string `json:"email,omitempty"` - Mobile string `json:"mobile,omitempty"` + Email string `json:"email,omitempty"` + Mobile string `json:"mobile,omitempty"` + Password string `json:"password,omitempty"` + Token string `json:"token,omitempty"` } `json:"accounts"` } @@ -500,6 +503,8 @@ func (r *Runner) runCase(ctx context.Context, c caseDef) { } func (cc *caseContext) assert(name string, ok bool, detail string) { + cc.mu.Lock() + defer cc.mu.Unlock() cc.assertions = append(cc.assertions, assertionResult{ Name: name, Passed: ok, @@ -532,9 +537,12 @@ func (cc *caseContext) request(ctx context.Context, spec requestSpec) (*response } func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attempt int) (*responseResult, error) { + cc.mu.Lock() cc.seq++ - traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), cc.seq) + seq := cc.seq + traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), seq) cc.traceIDsSet[traceID] = struct{}{} + cc.mu.Unlock() fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID) if err != nil { @@ -558,8 +566,9 @@ func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attemp bodyAny = spec.Body headers["Content-Type"] = "application/json" } + cc.mu.Lock() cc.requests = append(cc.requests, requestLog{ - Seq: cc.seq, + Seq: seq, Attempt: attempt, TraceID: traceID, Method: spec.Method, @@ -568,6 +577,7 @@ func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attemp Body: bodyAny, Timestamp: time.Now().Format(time.RFC3339Nano), }) + cc.mu.Unlock() reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout) defer cancel() @@ -581,8 +591,9 @@ func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attemp start := time.Now() resp, err := cc.runner.httpClient.Do(req) if err != nil { + cc.mu.Lock() cc.responses = append(cc.responses, responseLog{ - Seq: cc.seq, + Seq: seq, Attempt: attempt, TraceID: traceID, StatusCode: 0, @@ -590,13 +601,15 @@ func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attemp NetworkErr: err.Error(), ReceivedAt: time.Now().Format(time.RFC3339Nano), }) + cc.mu.Unlock() return nil, err } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) + cc.mu.Lock() cc.responses = append(cc.responses, responseLog{ - Seq: cc.seq, + Seq: seq, Attempt: attempt, TraceID: traceID, StatusCode: resp.StatusCode, @@ -611,6 +624,7 @@ func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attemp cc.streamRaw.Write(body) cc.streamRaw.WriteString("\n\n") } + cc.mu.Unlock() return &responseResult{ StatusCode: resp.StatusCode, @@ -700,7 +714,15 @@ func (r *Runner) cases() []caseDef { {ID: "anthropic_count_tokens", Run: r.caseAnthropicCountTokens}, {ID: "admin_account_test_single", Run: r.caseAdminAccountTest}, {ID: "concurrency_burst", Run: r.caseConcurrencyBurst}, + {ID: "concurrency_threshold_limit", Run: r.caseConcurrencyThresholdLimit}, + {ID: "stream_abort_release", Run: r.caseStreamAbortRelease}, + {ID: "toolcall_stream_mixed", Run: r.caseToolcallStreamMixed}, + {ID: "sse_json_integrity", Run: r.caseSSEJSONIntegrity}, + {ID: "error_contract_invalid_model", Run: r.caseInvalidModel}, + {ID: "error_contract_missing_messages", Run: r.caseMissingMessages}, + {ID: "admin_unauthorized_contract", Run: r.caseAdminUnauthorized}, {ID: "config_write_isolated", Run: r.caseConfigWriteIsolated}, + {ID: "token_refresh_managed_account", Run: r.caseTokenRefreshManagedAccount}, {ID: "error_contract_invalid_key", Run: r.caseInvalidKey}, } }