mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 17:05:32 +08:00
feat: implement automatic completion continuation for incomplete DeepSeek responses
This commit is contained in:
@@ -31,6 +31,7 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payloa
|
||||
if captureSession != nil {
|
||||
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
|
||||
}
|
||||
resp = c.wrapCompletionWithAutoContinue(ctx, a, payload, resp)
|
||||
return resp, nil
|
||||
}
|
||||
if captureSession != nil {
|
||||
|
||||
240
internal/deepseek/client_continue.go
Normal file
240
internal/deepseek/client_continue.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
const defaultAutoContinueLimit = 8
|
||||
|
||||
type continueOpenFunc func(context.Context, string, int) (*http.Response, error)
|
||||
|
||||
type continueState struct {
|
||||
sessionID string
|
||||
responseMessageID int
|
||||
lastStatus string
|
||||
finished bool
|
||||
}
|
||||
|
||||
// wrapCompletionWithAutoContinue wraps the completion response body so that
|
||||
// if the upstream indicates the response is incomplete (WIP / INCOMPLETE /
|
||||
// AUTO_CONTINUE), ds2api will automatically call the DeepSeek continue
|
||||
// endpoint and splice the continuation SSE stream onto the original.
|
||||
// The caller sees a single, seamless SSE stream.
|
||||
func (c *Client) wrapCompletionWithAutoContinue(ctx context.Context, a *auth.RequestAuth, payload map[string]any, resp *http.Response) *http.Response {
|
||||
if resp == nil || resp.Body == nil {
|
||||
return resp
|
||||
}
|
||||
sessionID, _ := payload["chat_session_id"].(string)
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
if sessionID == "" {
|
||||
return resp
|
||||
}
|
||||
config.Logger.Debug("[auto_continue] wrapping completion response", "session_id", sessionID)
|
||||
resp.Body = newAutoContinueBody(ctx, resp.Body, sessionID, defaultAutoContinueLimit, func(ctx context.Context, sessionID string, responseMessageID int) (*http.Response, error) {
|
||||
return c.callContinue(ctx, a, sessionID, responseMessageID)
|
||||
})
|
||||
return resp
|
||||
}
|
||||
|
||||
// callContinue sends a continue request to DeepSeek to resume generation.
|
||||
func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionID string, responseMessageID int) (*http.Response, error) {
|
||||
if strings.TrimSpace(sessionID) == "" || responseMessageID <= 0 {
|
||||
return nil, errors.New("missing continue identifiers")
|
||||
}
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"message_id": responseMessageID,
|
||||
"fallback_to_resume": true,
|
||||
}
|
||||
config.Logger.Info("[auto_continue] calling continue", "session_id", sessionID, "message_id", responseMessageID)
|
||||
captureSession := c.capture.Start("deepseek_continue", DeepSeekContinueURL, a.AccountID, payload)
|
||||
resp, err := c.streamPost(ctx, DeepSeekContinueURL, headers, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if captureSession != nil {
|
||||
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = resp.Body.Close()
|
||||
return nil, errors.New("continue failed")
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// newAutoContinueBody returns a new ReadCloser that transparently pumps
|
||||
// continuation rounds via an io.Pipe.
|
||||
func newAutoContinueBody(ctx context.Context, initial io.ReadCloser, sessionID string, maxRounds int, openContinue continueOpenFunc) io.ReadCloser {
|
||||
if initial == nil || strings.TrimSpace(sessionID) == "" || openContinue == nil {
|
||||
return initial
|
||||
}
|
||||
if maxRounds <= 0 {
|
||||
maxRounds = defaultAutoContinueLimit
|
||||
}
|
||||
pr, pw := io.Pipe()
|
||||
go pumpAutoContinue(ctx, pw, initial, continueState{sessionID: sessionID}, maxRounds, openContinue)
|
||||
return pr
|
||||
}
|
||||
|
||||
// pumpAutoContinue is the goroutine that drives the auto-continue loop.
|
||||
// It reads the initial SSE body, checks whether a continue is required,
|
||||
// and if so opens a new continue stream and splices it onto the pipe writer.
|
||||
func pumpAutoContinue(ctx context.Context, pw *io.PipeWriter, initial io.ReadCloser, state continueState, maxRounds int, openContinue continueOpenFunc) {
|
||||
defer func() { _ = pw.Close() }()
|
||||
current := initial
|
||||
rounds := 0
|
||||
for {
|
||||
hadDone, err := streamBodyWithContinueState(ctx, pw, current, &state)
|
||||
_ = current.Close()
|
||||
if err != nil {
|
||||
_ = pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
if state.shouldContinue() && rounds < maxRounds {
|
||||
rounds++
|
||||
config.Logger.Info("[auto_continue] continuing", "round", rounds, "session_id", state.sessionID, "message_id", state.responseMessageID, "status", state.lastStatus)
|
||||
nextResp, err := openContinue(ctx, state.sessionID, state.responseMessageID)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[auto_continue] continue request failed", "round", rounds, "error", err)
|
||||
_ = pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
current = nextResp.Body
|
||||
state.prepareForNextRound()
|
||||
continue
|
||||
}
|
||||
// Emit the final [DONE] sentinel if the upstream had one.
|
||||
if hadDone {
|
||||
if _, err := io.Copy(pw, bytes.NewBufferString("data: [DONE]\n")); err != nil {
|
||||
_ = pw.CloseWithError(err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// streamBodyWithContinueState scans an SSE body line-by-line, writing each
|
||||
// line through to pw while observing state signals. Intermediate [DONE]
|
||||
// sentinels are consumed (not forwarded) so that the downstream only sees
|
||||
// one final [DONE] at the very end.
|
||||
func streamBodyWithContinueState(ctx context.Context, pw *io.PipeWriter, body io.Reader, state *continueState) (bool, error) {
|
||||
scanner := bufio.NewScanner(body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
|
||||
hadDone := false
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return hadDone, ctx.Err()
|
||||
default:
|
||||
}
|
||||
line := append([]byte{}, scanner.Bytes()...)
|
||||
trimmed := strings.TrimSpace(string(line))
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "data:") {
|
||||
data := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||||
if data == "[DONE]" {
|
||||
hadDone = true
|
||||
continue
|
||||
}
|
||||
state.observe(data)
|
||||
}
|
||||
if _, err := io.Copy(pw, bytes.NewReader(append(line, '\n'))); err != nil {
|
||||
return hadDone, err
|
||||
}
|
||||
}
|
||||
return hadDone, scanner.Err()
|
||||
}
|
||||
|
||||
// observe extracts continue-relevant signals from an SSE JSON chunk.
|
||||
func (s *continueState) observe(data string) {
|
||||
if s == nil || strings.TrimSpace(data) == "" {
|
||||
return
|
||||
}
|
||||
var chunk map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
return
|
||||
}
|
||||
// Top-level response_message_id
|
||||
if id := intFrom(chunk["response_message_id"]); id > 0 {
|
||||
s.responseMessageID = id
|
||||
}
|
||||
// Path-based status: {"p": "response/status", "v": "FINISHED"}
|
||||
if p, _ := chunk["p"].(string); p == "response/status" {
|
||||
if status, _ := chunk["v"].(string); status != "" {
|
||||
s.lastStatus = strings.TrimSpace(status)
|
||||
if strings.EqualFold(s.lastStatus, "FINISHED") {
|
||||
s.finished = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// Nested v.response
|
||||
v, _ := chunk["v"].(map[string]any)
|
||||
if response, _ := v["response"].(map[string]any); response != nil {
|
||||
if id := intFrom(response["message_id"]); id > 0 {
|
||||
s.responseMessageID = id
|
||||
}
|
||||
if status, _ := response["status"].(string); status != "" {
|
||||
s.lastStatus = strings.TrimSpace(status)
|
||||
if strings.EqualFold(s.lastStatus, "FINISHED") {
|
||||
s.finished = true
|
||||
}
|
||||
}
|
||||
if autoContinue, ok := response["auto_continue"].(bool); ok && autoContinue {
|
||||
s.lastStatus = "AUTO_CONTINUE"
|
||||
}
|
||||
}
|
||||
// Nested message.response
|
||||
if message, _ := chunk["message"].(map[string]any); message != nil {
|
||||
if response, _ := message["response"].(map[string]any); response != nil {
|
||||
if id := intFrom(response["message_id"]); id > 0 {
|
||||
s.responseMessageID = id
|
||||
}
|
||||
if status, _ := response["status"].(string); status != "" {
|
||||
s.lastStatus = strings.TrimSpace(status)
|
||||
if strings.EqualFold(s.lastStatus, "FINISHED") {
|
||||
s.finished = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shouldContinue returns true when the upstream indicates the response is
|
||||
// not yet finished and we have enough information to issue a continue request.
|
||||
func (s *continueState) shouldContinue() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.finished || s.responseMessageID <= 0 || strings.TrimSpace(s.sessionID) == "" {
|
||||
return false
|
||||
}
|
||||
switch strings.ToUpper(strings.TrimSpace(s.lastStatus)) {
|
||||
case "WIP", "INCOMPLETE", "AUTO_CONTINUE":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// prepareForNextRound resets ephemeral state before processing the next
|
||||
// continuation stream.
|
||||
func (s *continueState) prepareForNextRound() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.finished = false
|
||||
s.lastStatus = ""
|
||||
}
|
||||
@@ -11,6 +11,7 @@ const (
|
||||
DeepSeekCreateSessionURL = "https://chat.deepseek.com/api/v0/chat_session/create"
|
||||
DeepSeekCreatePowURL = "https://chat.deepseek.com/api/v0/chat/create_pow_challenge"
|
||||
DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion"
|
||||
DeepSeekContinueURL = "https://chat.deepseek.com/api/v0/chat/continue"
|
||||
DeepSeekFetchSessionURL = "https://chat.deepseek.com/api/v0/chat_session/fetch_page"
|
||||
DeepSeekDeleteSessionURL = "https://chat.deepseek.com/api/v0/chat_session/delete"
|
||||
DeepSeekDeleteAllSessionsURL = "https://chat.deepseek.com/api/v0/chat_session/delete_all"
|
||||
|
||||
Reference in New Issue
Block a user