mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-04 16:35:27 +08:00
refactor: centralize SSE stream parsing logic into a new sse package and update the PoW solver to honor context cancellation during module acquisition.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -220,20 +219,6 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
|
||||
if !canFlush {
|
||||
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
|
||||
}
|
||||
lines := make(chan []byte, 128)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 2*1024*1024)
|
||||
for scanner.Scan() {
|
||||
b := append([]byte{}, scanner.Bytes()...)
|
||||
lines <- b
|
||||
}
|
||||
close(lines)
|
||||
done <- scanner.Err()
|
||||
}()
|
||||
|
||||
send := func(event string, v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = w.Write([]byte("event: "))
|
||||
@@ -276,10 +261,11 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
|
||||
},
|
||||
})
|
||||
|
||||
currentType := "text"
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
currentType = "thinking"
|
||||
initialType = "thinking"
|
||||
}
|
||||
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
hasContent := false
|
||||
lastContent := time.Now()
|
||||
@@ -412,7 +398,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
send("ping", map[string]any{"type": "ping"})
|
||||
case line, ok := <-lines:
|
||||
case parsed, ok := <-parsedLines:
|
||||
if !ok {
|
||||
if err := <-done; err != nil {
|
||||
sendError(err.Error())
|
||||
@@ -421,9 +407,6 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
|
||||
finalize("end_turn")
|
||||
return
|
||||
}
|
||||
|
||||
parsed := sse.ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
|
||||
currentType = parsed.NextType
|
||||
if !parsed.Parsed {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -189,29 +188,16 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
|
||||
}
|
||||
|
||||
lines := make(chan []byte, 128)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 2*1024*1024)
|
||||
for scanner.Scan() {
|
||||
b := append([]byte{}, scanner.Bytes()...)
|
||||
lines <- b
|
||||
}
|
||||
close(lines)
|
||||
done <- scanner.Err()
|
||||
}()
|
||||
|
||||
created := time.Now().Unix()
|
||||
firstChunkSent := false
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
var toolSieve toolStreamSieveState
|
||||
toolCallsEmitted := false
|
||||
currentType := "text"
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
currentType = "thinking"
|
||||
initialType = "thinking"
|
||||
}
|
||||
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
|
||||
thinking := strings.Builder{}
|
||||
text := strings.Builder{}
|
||||
lastContent := time.Now()
|
||||
@@ -321,7 +307,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
_, _ = w.Write([]byte(": keep-alive\n\n"))
|
||||
_ = rc.Flush()
|
||||
}
|
||||
case line, ok := <-lines:
|
||||
case parsed, ok := <-parsedLines:
|
||||
if !ok {
|
||||
// Ensure scanner completion is observed only after all queued
|
||||
// SSE lines are drained, avoiding early finalize races.
|
||||
@@ -329,8 +315,6 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
parsed := sse.ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
|
||||
currentType = parsed.NextType
|
||||
if !parsed.Parsed {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -167,12 +167,15 @@ func (p *PowSolver) createModule(ctx context.Context) (*pooledModule, error) {
|
||||
|
||||
func (p *PowSolver) acquireModule(ctx context.Context) (*pooledModule, error) {
|
||||
if p.pool != nil {
|
||||
select {
|
||||
case pm := <-p.pool:
|
||||
if pm != nil {
|
||||
return pm, nil
|
||||
for {
|
||||
select {
|
||||
case pm := <-p.pool:
|
||||
if pm != nil {
|
||||
return pm, nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
return p.createModule(ctx)
|
||||
|
||||
@@ -3,6 +3,7 @@ package deepseek
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPowPoolSizeFromEnv(t *testing.T) {
|
||||
@@ -35,6 +36,26 @@ func TestPowSolverAcquireReleaseReusesModule(t *testing.T) {
|
||||
solver.releaseModule(pm2)
|
||||
}
|
||||
|
||||
func TestPowSolverAcquireHonorsContextWhenPoolExhausted(t *testing.T) {
|
||||
t.Setenv("DS2API_POW_POOL_SIZE", "1")
|
||||
solver := NewPowSolver("missing-file.wasm")
|
||||
if err := solver.init(context.Background()); err != nil {
|
||||
t.Fatalf("init failed: %v", err)
|
||||
}
|
||||
|
||||
held, err := solver.acquireModule(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("acquire held module failed: %v", err)
|
||||
}
|
||||
defer solver.releaseModule(held)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
if _, err := solver.acquireModule(ctx); err == nil {
|
||||
t.Fatalf("expected context cancellation while pool is exhausted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientPreloadPowUsesClientSolver(t *testing.T) {
|
||||
t.Setenv("DS2API_POW_POOL_SIZE", "1")
|
||||
client := NewClient(nil, nil)
|
||||
|
||||
40
internal/sse/stream.go
Normal file
40
internal/sse/stream.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package sse
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
parsedLineBufferSize = 128
|
||||
scannerBufferSize = 64 * 1024
|
||||
maxScannerLineSize = 2 * 1024 * 1024
|
||||
)
|
||||
|
||||
// StartParsedLinePump scans an upstream DeepSeek SSE body and emits normalized
|
||||
// line parse results. It centralizes scanner setup + current fragment type
|
||||
// tracking for all streaming adapters.
|
||||
func StartParsedLinePump(ctx context.Context, body io.Reader, thinkingEnabled bool, initialType string) (<-chan LineResult, <-chan error) {
|
||||
out := make(chan LineResult, parsedLineBufferSize)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(out)
|
||||
scanner := bufio.NewScanner(body)
|
||||
scanner.Buffer(make([]byte, 0, scannerBufferSize), maxScannerLineSize)
|
||||
currentType := initialType
|
||||
for scanner.Scan() {
|
||||
line := append([]byte{}, scanner.Bytes()...)
|
||||
result := ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
|
||||
currentType = result.NextType
|
||||
select {
|
||||
case out <- result:
|
||||
case <-ctx.Done():
|
||||
done <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
done <- scanner.Err()
|
||||
}()
|
||||
return out, done
|
||||
}
|
||||
30
internal/sse/stream_test.go
Normal file
30
internal/sse/stream_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package sse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStartParsedLinePumpParsesAndStops(t *testing.T) {
|
||||
body := strings.NewReader("data: {\"p\":\"response/content\",\"v\":\"hi\"}\n\ndata: [DONE]\n")
|
||||
results, done := StartParsedLinePump(context.Background(), body, false, "text")
|
||||
|
||||
collected := make([]LineResult, 0, 2)
|
||||
for r := range results {
|
||||
collected = append(collected, r)
|
||||
}
|
||||
if err := <-done; err != nil {
|
||||
t.Fatalf("unexpected scanner error: %v", err)
|
||||
}
|
||||
if len(collected) < 2 {
|
||||
t.Fatalf("expected at least 2 parsed results, got %d", len(collected))
|
||||
}
|
||||
if !collected[0].Parsed || len(collected[0].Parts) == 0 {
|
||||
t.Fatalf("expected first line to contain parsed content")
|
||||
}
|
||||
last := collected[len(collected)-1]
|
||||
if !last.Parsed || !last.Stop {
|
||||
t.Fatalf("expected last line to stop stream, got parsed=%v stop=%v", last.Parsed, last.Stop)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user