mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-04 00:15:28 +08:00
Replace bufio.Scanner with bufio.NewReaderSize + ReadBytes('\n') across all
SSE read paths to preserve long single-line data (e.g. write_file content).
Add quasi_status and auto_continue handling as direct path-based patches in
both Go continue observer and Node vercel_stream_impl, mirroring existing
batch-patch logic. Add 2MiB+ line throughput tests at every SSE layer.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
330 lines
9.6 KiB
Go
330 lines
9.6 KiB
Go
package client
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
dsprotocol "ds2api/internal/deepseek/protocol"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"ds2api/internal/auth"
|
|
"ds2api/internal/config"
|
|
)
|
|
|
|
// defaultAutoContinueLimit is the number of continuation rounds used when a
// caller does not supply a positive round limit.
const defaultAutoContinueLimit = 8
|
|
|
|
// continueOpenFunc opens one continuation stream. Within this file it is
// invoked with the chat session ID and the response message ID to resume.
type continueOpenFunc func(context.Context, string, int) (*http.Response, error)
|
|
|
|
// continueState accumulates the signals observed on an SSE stream that decide
// whether another continuation round has to be requested.
type continueState struct {
	// sessionID is the chat session the continuation belongs to.
	sessionID string
	// responseMessageID is the latest positive message ID seen on the stream.
	responseMessageID int
	// lastStatus is the most recent non-empty status (or "AUTO_CONTINUE").
	lastStatus string
	// finished is set once a terminal status has been observed.
	finished bool
}
|
|
|
|
// wrapCompletionWithAutoContinue wraps the completion response body so that
|
|
// if the upstream indicates the response is incomplete (INCOMPLETE /
|
|
// AUTO_CONTINUE), ds2api will automatically call the DeepSeek continue
|
|
// endpoint and splice the continuation SSE stream onto the original.
|
|
// The caller sees a single, seamless SSE stream.
|
|
func (c *Client) wrapCompletionWithAutoContinue(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, resp *http.Response) *http.Response {
|
|
if resp == nil || resp.Body == nil {
|
|
return resp
|
|
}
|
|
sessionID, _ := payload["chat_session_id"].(string)
|
|
sessionID = strings.TrimSpace(sessionID)
|
|
if sessionID == "" {
|
|
return resp
|
|
}
|
|
config.Logger.Debug("[auto_continue] wrapping completion response", "session_id", sessionID)
|
|
resp.Body = newAutoContinueBody(ctx, resp.Body, sessionID, defaultAutoContinueLimit, func(ctx context.Context, sessionID string, responseMessageID int) (*http.Response, error) {
|
|
return c.callContinue(ctx, a, sessionID, responseMessageID, powResp)
|
|
})
|
|
return resp
|
|
}
|
|
|
|
// callContinue sends a continue request to DeepSeek to resume generation.
|
|
func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionID string, responseMessageID int, powResp string) (*http.Response, error) {
|
|
if strings.TrimSpace(sessionID) == "" || responseMessageID <= 0 {
|
|
return nil, errors.New("missing continue identifiers")
|
|
}
|
|
clients := c.requestClientsForAuth(ctx, a)
|
|
headers := c.authHeaders(a.DeepSeekToken)
|
|
headers["x-ds-pow-response"] = powResp
|
|
payload := map[string]any{
|
|
"chat_session_id": sessionID,
|
|
"message_id": responseMessageID,
|
|
"fallback_to_resume": true,
|
|
}
|
|
config.Logger.Info("[auto_continue] calling continue", "session_id", sessionID, "message_id", responseMessageID)
|
|
captureSession := c.capture.Start("deepseek_continue", dsprotocol.DeepSeekContinueURL, a.AccountID, payload)
|
|
resp, err := c.streamPost(ctx, clients.stream, dsprotocol.DeepSeekContinueURL, headers, payload)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if captureSession != nil {
|
|
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
_ = resp.Body.Close()
|
|
return nil, errors.New("continue failed")
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
// newAutoContinueBody returns a new ReadCloser that transparently pumps
|
|
// continuation rounds via an io.Pipe.
|
|
func newAutoContinueBody(ctx context.Context, initial io.ReadCloser, sessionID string, maxRounds int, openContinue continueOpenFunc) io.ReadCloser {
|
|
if initial == nil || strings.TrimSpace(sessionID) == "" || openContinue == nil {
|
|
return initial
|
|
}
|
|
if maxRounds <= 0 {
|
|
maxRounds = defaultAutoContinueLimit
|
|
}
|
|
pr, pw := io.Pipe()
|
|
go pumpAutoContinue(ctx, pw, initial, continueState{sessionID: sessionID}, maxRounds, openContinue)
|
|
return pr
|
|
}
|
|
|
|
// pumpAutoContinue is the goroutine that drives the auto-continue loop.
|
|
// It reads the initial SSE body, checks whether a continue is required,
|
|
// and if so opens a new continue stream and splices it onto the pipe writer.
|
|
func pumpAutoContinue(ctx context.Context, pw *io.PipeWriter, initial io.ReadCloser, state continueState, maxRounds int, openContinue continueOpenFunc) {
|
|
defer func() { _ = pw.Close() }()
|
|
current := initial
|
|
rounds := 0
|
|
for {
|
|
hadDone, err := streamBodyWithContinueState(ctx, pw, current, &state)
|
|
_ = current.Close()
|
|
if err != nil {
|
|
_ = pw.CloseWithError(err)
|
|
return
|
|
}
|
|
if state.shouldContinue() && rounds < maxRounds {
|
|
rounds++
|
|
config.Logger.Info("[auto_continue] continuing", "round", rounds, "session_id", state.sessionID, "message_id", state.responseMessageID, "status", state.lastStatus)
|
|
nextResp, err := openContinue(ctx, state.sessionID, state.responseMessageID)
|
|
if err != nil {
|
|
config.Logger.Warn("[auto_continue] continue request failed", "round", rounds, "error", err)
|
|
_ = pw.CloseWithError(err)
|
|
return
|
|
}
|
|
current = nextResp.Body
|
|
state.prepareForNextRound()
|
|
continue
|
|
}
|
|
// Emit the final [DONE] sentinel if the upstream had one.
|
|
if hadDone {
|
|
if _, err := io.Copy(pw, bytes.NewBufferString("data: [DONE]\n")); err != nil {
|
|
_ = pw.CloseWithError(err)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
// streamBodyWithContinueState scans an SSE body line-by-line, writing each
|
|
// line through to pw while observing state signals. Intermediate [DONE]
|
|
// sentinels are consumed (not forwarded) so that the downstream only sees
|
|
// one final [DONE] at the very end.
|
|
func streamBodyWithContinueState(ctx context.Context, pw *io.PipeWriter, body io.Reader, state *continueState) (bool, error) {
|
|
reader := bufio.NewReaderSize(body, 64*1024)
|
|
hadDone := false
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return hadDone, ctx.Err()
|
|
default:
|
|
}
|
|
line, err := reader.ReadBytes('\n')
|
|
if len(line) == 0 && err != nil {
|
|
if err == io.EOF {
|
|
return hadDone, nil
|
|
}
|
|
return hadDone, err
|
|
}
|
|
trimmed := strings.TrimSpace(string(line))
|
|
if trimmed != "" {
|
|
if strings.HasPrefix(trimmed, "data:") {
|
|
data := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
|
if data == "[DONE]" {
|
|
hadDone = true
|
|
if err != nil && err != io.EOF {
|
|
return hadDone, err
|
|
}
|
|
if err == io.EOF {
|
|
return hadDone, nil
|
|
}
|
|
continue
|
|
}
|
|
state.observe(data)
|
|
}
|
|
if !strings.HasSuffix(string(line), "\n") {
|
|
line = append(line, '\n')
|
|
}
|
|
if _, copyErr := io.Copy(pw, bytes.NewReader(line)); copyErr != nil {
|
|
return hadDone, copyErr
|
|
}
|
|
}
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
return hadDone, nil
|
|
}
|
|
return hadDone, err
|
|
}
|
|
}
|
|
}
|
|
|
|
// observe extracts continue-relevant signals from an SSE JSON chunk.
|
|
func (s *continueState) observe(data string) {
|
|
if s == nil || strings.TrimSpace(data) == "" {
|
|
return
|
|
}
|
|
var chunk map[string]any
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
return
|
|
}
|
|
// Top-level response_message_id
|
|
if id := intFrom(chunk["response_message_id"]); id > 0 {
|
|
s.responseMessageID = id
|
|
}
|
|
s.observeDirectPatch(asString(chunk["p"]), chunk["v"])
|
|
if p, _ := chunk["p"].(string); p == "response" {
|
|
s.observeBatchPatches("response", chunk["v"])
|
|
} else {
|
|
s.observeBatchPatches("", chunk["v"])
|
|
}
|
|
if v, _ := chunk["v"].(map[string]any); v != nil {
|
|
s.observeResponseObject(v["response"])
|
|
}
|
|
if message, _ := chunk["message"].(map[string]any); message != nil {
|
|
s.observeResponseObject(message["response"])
|
|
}
|
|
}
|
|
|
|
func (s *continueState) observeDirectPatch(path string, value any) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
switch strings.Trim(strings.TrimSpace(path), "/") {
|
|
case "response/status", "status", "response/quasi_status", "quasi_status":
|
|
s.setStatus(asString(value))
|
|
case "response/auto_continue", "auto_continue":
|
|
if v, ok := value.(bool); ok && v {
|
|
s.lastStatus = "AUTO_CONTINUE"
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *continueState) observeResponseObject(raw any) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
response, _ := raw.(map[string]any)
|
|
if response == nil {
|
|
return
|
|
}
|
|
if id := intFrom(response["message_id"]); id > 0 {
|
|
s.responseMessageID = id
|
|
}
|
|
s.setStatus(asString(response["status"]))
|
|
if autoContinue, ok := response["auto_continue"].(bool); ok && autoContinue {
|
|
s.lastStatus = "AUTO_CONTINUE"
|
|
}
|
|
}
|
|
|
|
func (s *continueState) observeBatchPatches(parentPath string, raw any) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
patches, ok := raw.([]any)
|
|
if !ok {
|
|
return
|
|
}
|
|
for _, patch := range patches {
|
|
m, ok := patch.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
path := strings.TrimSpace(asString(m["p"]))
|
|
if path == "" {
|
|
continue
|
|
}
|
|
fullPath := path
|
|
if parent := strings.Trim(strings.TrimSpace(parentPath), "/"); parent != "" && !strings.Contains(path, "/") {
|
|
fullPath = parent + "/" + path
|
|
}
|
|
switch strings.Trim(strings.TrimSpace(fullPath), "/") {
|
|
case "response/status", "status", "response/quasi_status", "quasi_status":
|
|
s.setStatus(asString(m["v"]))
|
|
case "response/auto_continue", "auto_continue":
|
|
if v, ok := m["v"].(bool); ok && v {
|
|
s.lastStatus = "AUTO_CONTINUE"
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *continueState) setStatus(status string) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
normalized := strings.TrimSpace(status)
|
|
if normalized == "" {
|
|
return
|
|
}
|
|
s.lastStatus = normalized
|
|
if strings.EqualFold(normalized, "FINISHED") || strings.EqualFold(normalized, "CONTENT_FILTER") {
|
|
s.finished = true
|
|
}
|
|
}
|
|
|
|
// shouldContinue returns true when the upstream explicitly indicates the
|
|
// response is incomplete and we have enough information to issue a continue
|
|
// request. Plain WIP is not sufficient because normal streams begin in WIP.
|
|
func (s *continueState) shouldContinue() bool {
|
|
if s == nil {
|
|
return false
|
|
}
|
|
if s.finished || s.responseMessageID <= 0 || strings.TrimSpace(s.sessionID) == "" {
|
|
return false
|
|
}
|
|
switch strings.ToUpper(strings.TrimSpace(s.lastStatus)) {
|
|
case "INCOMPLETE", "AUTO_CONTINUE":
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// prepareForNextRound resets ephemeral state before processing the next
|
|
// continuation stream.
|
|
func (s *continueState) prepareForNextRound() {
|
|
if s == nil {
|
|
return
|
|
}
|
|
s.finished = false
|
|
s.lastStatus = ""
|
|
}
|
|
|
|
// asString converts a loosely typed value to a string.
//
// nil yields "". A string is returned unchanged (it is not trimmed). Any
// other value is formatted with fmt.Sprint, stripped of NUL characters and
// surrounding whitespace, and mapped to "" when the result is "<nil>".
func asString(v any) string {
	if v == nil {
		return ""
	}
	if str, isString := v.(string); isString {
		return str
	}

	text := strings.TrimSpace(fmt.Sprint(v))
	text = strings.ReplaceAll(text, "\u0000", "")
	text = strings.TrimSpace(text)
	if text == "<nil>" {
		return ""
	}
	return text
}
|