From 07de35a093de3233f6a8d1e0c05255bd26075ef2 Mon Sep 17 00:00:00 2001 From: CJACK Date: Tue, 17 Feb 2026 04:40:01 +0800 Subject: [PATCH] refactor: centralize SSE stream parsing logic into a new `sse` package and update the PoW solver to honor context cancellation during module acquisition. --- internal/adapter/claude/handler.go | 25 +++---------------- internal/adapter/openai/handler.go | 24 +++--------------- internal/deepseek/pow.go | 13 ++++++---- internal/deepseek/pow_test.go | 21 ++++++++++++++++ internal/sse/stream.go | 40 ++++++++++++++++++++++++++++++ internal/sse/stream_test.go | 30 ++++++++++++++++++++++ 6 files changed, 107 insertions(+), 46 deletions(-) create mode 100644 internal/sse/stream.go create mode 100644 internal/sse/stream_test.go diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index 08b8e85..b9ecd27 100644 --- a/internal/adapter/claude/handler.go +++ b/internal/adapter/claude/handler.go @@ -1,7 +1,6 @@ package claude import ( - "bufio" "encoding/json" "fmt" "io" @@ -220,20 +219,6 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ if !canFlush { config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered") } - lines := make(chan []byte, 128) - done := make(chan error, 1) - go func() { - scanner := bufio.NewScanner(resp.Body) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - b := append([]byte{}, scanner.Bytes()...) - lines <- b - } - close(lines) - done <- scanner.Err() - }() - send := func(event string, v any) { b, _ := json.Marshal(v) _, _ = w.Write([]byte("event: ")) @@ -276,10 +261,11 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ }, }) - currentType := "text" + initialType := "text" if thinkingEnabled { - currentType = "thinking" + initialType = "thinking" } + parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) bufferToolContent := len(toolNames) > 0 hasContent := false lastContent := time.Now() @@ -412,7 +398,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ return } send("ping", map[string]any{"type": "ping"}) - case line, ok := <-lines: + case parsed, ok := <-parsedLines: if !ok { if err := <-done; err != nil { sendError(err.Error()) @@ -421,9 +407,6 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ finalize("end_turn") return } - - parsed := sse.ParseDeepSeekContentLine(line, thinkingEnabled, currentType) - currentType = parsed.NextType if !parsed.Parsed { continue } diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index 0df6d11..d0a2f1d 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -1,7 +1,6 @@ package openai import ( - "bufio" "context" "encoding/json" "fmt" @@ -189,29 +188,16 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered") } - lines := make(chan []byte, 128) - done := make(chan error, 1) - go func() { - scanner := bufio.NewScanner(resp.Body) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - b := append([]byte{}, scanner.Bytes()...) - lines <- b - } - close(lines) - done <- scanner.Err() - }() - created := time.Now().Unix() firstChunkSent := false bufferToolContent := len(toolNames) > 0 var toolSieve toolStreamSieveState toolCallsEmitted := false - currentType := "text" + initialType := "text" if thinkingEnabled { - currentType = "thinking" + initialType = "thinking" } + parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) thinking := strings.Builder{} text := strings.Builder{} lastContent := time.Now() @@ -321,7 +307,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt _, _ = w.Write([]byte(": keep-alive\n\n")) _ = rc.Flush() } - case line, ok := <-lines: + case parsed, ok := <-parsedLines: if !ok { // Ensure scanner completion is observed only after all queued // SSE lines are drained, avoiding early finalize races. @@ -329,8 +315,6 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt finalize("stop") return } - parsed := sse.ParseDeepSeekContentLine(line, thinkingEnabled, currentType) - currentType = parsed.NextType if !parsed.Parsed { continue } diff --git a/internal/deepseek/pow.go b/internal/deepseek/pow.go index 5dda8cf..95d86b8 100644 --- a/internal/deepseek/pow.go +++ b/internal/deepseek/pow.go @@ -167,12 +167,15 @@ func (p *PowSolver) createModule(ctx context.Context) (*pooledModule, error) { func (p *PowSolver) acquireModule(ctx context.Context) (*pooledModule, error) { if p.pool != nil { - select { - case pm := <-p.pool: - if pm != nil { - return pm, nil + for { + select { + case pm := <-p.pool: + if pm != nil { + return pm, nil + } + case <-ctx.Done(): + return nil, ctx.Err() } - default: } } return p.createModule(ctx) diff --git a/internal/deepseek/pow_test.go b/internal/deepseek/pow_test.go index 3e8af6c..6ebcd2a 100644 --- a/internal/deepseek/pow_test.go +++ b/internal/deepseek/pow_test.go @@ -3,6 +3,7 @@ package deepseek import ( "context" "testing" + "time" ) func TestPowPoolSizeFromEnv(t *testing.T) { @@ -35,6 +36,26 @@ func TestPowSolverAcquireReleaseReusesModule(t *testing.T) { solver.releaseModule(pm2) } +func TestPowSolverAcquireHonorsContextWhenPoolExhausted(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "1") + solver := NewPowSolver("missing-file.wasm") + if err := solver.init(context.Background()); err != nil { + t.Fatalf("init failed: %v", err) + } + + held, err := solver.acquireModule(context.Background()) + if err != nil { + t.Fatalf("acquire held module failed: %v", err) + } + defer solver.releaseModule(held) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + if _, err := solver.acquireModule(ctx); err == nil { + t.Fatalf("expected context cancellation while pool is exhausted") + } +} + func TestClientPreloadPowUsesClientSolver(t *testing.T) { t.Setenv("DS2API_POW_POOL_SIZE", "1") client := NewClient(nil, nil) diff --git a/internal/sse/stream.go b/internal/sse/stream.go new file mode 100644 index 0000000..4aa2d39 --- /dev/null +++ b/internal/sse/stream.go @@ -0,0 +1,40 @@ +package sse + +import ( + "bufio" + "context" + "io" +) + +const ( + parsedLineBufferSize = 128 + scannerBufferSize = 64 * 1024 + maxScannerLineSize = 2 * 1024 * 1024 +) + +// StartParsedLinePump scans an upstream DeepSeek SSE body and emits normalized +// line parse results. It centralizes scanner setup + current fragment type +// tracking for all streaming adapters. +func StartParsedLinePump(ctx context.Context, body io.Reader, thinkingEnabled bool, initialType string) (<-chan LineResult, <-chan error) { + out := make(chan LineResult, parsedLineBufferSize) + done := make(chan error, 1) + go func() { + defer close(out) + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, 0, scannerBufferSize), maxScannerLineSize) + currentType := initialType + for scanner.Scan() { + line := append([]byte{}, scanner.Bytes()...) + result := ParseDeepSeekContentLine(line, thinkingEnabled, currentType) + currentType = result.NextType + select { + case out <- result: + case <-ctx.Done(): + done <- ctx.Err() + return + } + } + done <- scanner.Err() + }() + return out, done +} diff --git a/internal/sse/stream_test.go b/internal/sse/stream_test.go new file mode 100644 index 0000000..a4fd2bb --- /dev/null +++ b/internal/sse/stream_test.go @@ -0,0 +1,30 @@ +package sse + +import ( + "context" + "strings" + "testing" +) + +func TestStartParsedLinePumpParsesAndStops(t *testing.T) { + body := strings.NewReader("data: {\"p\":\"response/content\",\"v\":\"hi\"}\n\ndata: [DONE]\n") + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + collected := make([]LineResult, 0, 2) + for r := range results { + collected = append(collected, r) + } + if err := <-done; err != nil { + t.Fatalf("unexpected scanner error: %v", err) + } + if len(collected) < 2 { + t.Fatalf("expected at least 2 parsed results, got %d", len(collected)) + } + if !collected[0].Parsed || len(collected[0].Parts) == 0 { + t.Fatalf("expected first line to contain parsed content") + } + last := collected[len(collected)-1] + if !last.Parsed || !last.Stop { + t.Fatalf("expected last line to stop stream, got parsed=%v stop=%v", last.Parsed, last.Stop) + } +}