diff --git a/internal/deepseek/client_completion.go b/internal/deepseek/client_completion.go index 051bffe..d9b65b8 100644 --- a/internal/deepseek/client_completion.go +++ b/internal/deepseek/client_completion.go @@ -31,6 +31,7 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payloa if captureSession != nil { resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode) } + resp = c.wrapCompletionWithAutoContinue(ctx, a, payload, resp) return resp, nil } if captureSession != nil { diff --git a/internal/deepseek/client_continue.go b/internal/deepseek/client_continue.go new file mode 100644 index 0000000..f5a33c6 --- /dev/null +++ b/internal/deepseek/client_continue.go @@ -0,0 +1,240 @@ +package deepseek + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "strings" + + "ds2api/internal/auth" + "ds2api/internal/config" +) + +const defaultAutoContinueLimit = 8 + +type continueOpenFunc func(context.Context, string, int) (*http.Response, error) + +type continueState struct { + sessionID string + responseMessageID int + lastStatus string + finished bool +} + +// wrapCompletionWithAutoContinue wraps the completion response body so that +// if the upstream indicates the response is incomplete (WIP / INCOMPLETE / +// AUTO_CONTINUE), ds2api will automatically call the DeepSeek continue +// endpoint and splice the continuation SSE stream onto the original. +// The caller sees a single, seamless SSE stream. +func (c *Client) wrapCompletionWithAutoContinue(ctx context.Context, a *auth.RequestAuth, payload map[string]any, resp *http.Response) *http.Response { + if resp == nil || resp.Body == nil { + return resp + } + sessionID, _ := payload["chat_session_id"].(string) + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return resp + } + config.Logger.Debug("[auto_continue] wrapping completion response", "session_id", sessionID) + resp.Body = newAutoContinueBody(ctx, resp.Body, sessionID, defaultAutoContinueLimit, func(ctx context.Context, sessionID string, responseMessageID int) (*http.Response, error) { + return c.callContinue(ctx, a, sessionID, responseMessageID) + }) + return resp +} + +// callContinue sends a continue request to DeepSeek to resume generation. +func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionID string, responseMessageID int) (*http.Response, error) { + if strings.TrimSpace(sessionID) == "" || responseMessageID <= 0 { + return nil, errors.New("missing continue identifiers") + } + headers := c.authHeaders(a.DeepSeekToken) + payload := map[string]any{ + "chat_session_id": sessionID, + "message_id": responseMessageID, + "fallback_to_resume": true, + } + config.Logger.Info("[auto_continue] calling continue", "session_id", sessionID, "message_id", responseMessageID) + captureSession := c.capture.Start("deepseek_continue", DeepSeekContinueURL, a.AccountID, payload) + resp, err := c.streamPost(ctx, DeepSeekContinueURL, headers, payload) + if err != nil { + return nil, err + } + if captureSession != nil { + resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode) + } + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + return nil, errors.New("continue failed") + } + return resp, nil +} + +// newAutoContinueBody returns a new ReadCloser that transparently pumps +// continuation rounds via an io.Pipe. +func newAutoContinueBody(ctx context.Context, initial io.ReadCloser, sessionID string, maxRounds int, openContinue continueOpenFunc) io.ReadCloser { + if initial == nil || strings.TrimSpace(sessionID) == "" || openContinue == nil { + return initial + } + if maxRounds <= 0 { + maxRounds = defaultAutoContinueLimit + } + pr, pw := io.Pipe() + go pumpAutoContinue(ctx, pw, initial, continueState{sessionID: sessionID}, maxRounds, openContinue) + return pr +} + +// pumpAutoContinue is the goroutine that drives the auto-continue loop. +// It reads the initial SSE body, checks whether a continue is required, +// and if so opens a new continue stream and splices it onto the pipe writer. +func pumpAutoContinue(ctx context.Context, pw *io.PipeWriter, initial io.ReadCloser, state continueState, maxRounds int, openContinue continueOpenFunc) { + defer func() { _ = pw.Close() }() + current := initial + rounds := 0 + for { + hadDone, err := streamBodyWithContinueState(ctx, pw, current, &state) + _ = current.Close() + if err != nil { + _ = pw.CloseWithError(err) + return + } + if state.shouldContinue() && rounds < maxRounds { + rounds++ + config.Logger.Info("[auto_continue] continuing", "round", rounds, "session_id", state.sessionID, "message_id", state.responseMessageID, "status", state.lastStatus) + nextResp, err := openContinue(ctx, state.sessionID, state.responseMessageID) + if err != nil { + config.Logger.Warn("[auto_continue] continue request failed", "round", rounds, "error", err) + _ = pw.CloseWithError(err) + return + } + current = nextResp.Body + state.prepareForNextRound() + continue + } + // Emit the final [DONE] sentinel if the upstream had one. + if hadDone { + if _, err := io.Copy(pw, bytes.NewBufferString("data: [DONE]\n")); err != nil { + _ = pw.CloseWithError(err) + } + } + return + } +} + +// streamBodyWithContinueState scans an SSE body line-by-line, writing each +// line through to pw while observing state signals. Intermediate [DONE] +// sentinels are consumed (not forwarded) so that the downstream only sees +// one final [DONE] at the very end. +func streamBodyWithContinueState(ctx context.Context, pw *io.PipeWriter, body io.Reader, state *continueState) (bool, error) { + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + hadDone := false + for scanner.Scan() { + select { + case <-ctx.Done(): + return hadDone, ctx.Err() + default: + } + line := append([]byte{}, scanner.Bytes()...) + trimmed := strings.TrimSpace(string(line)) + if trimmed == "" { + continue + } + if strings.HasPrefix(trimmed, "data:") { + data := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + if data == "[DONE]" { + hadDone = true + continue + } + state.observe(data) + } + if _, err := io.Copy(pw, bytes.NewReader(append(line, '\n'))); err != nil { + return hadDone, err + } + } + return hadDone, scanner.Err() +} + +// observe extracts continue-relevant signals from an SSE JSON chunk. +func (s *continueState) observe(data string) { + if s == nil || strings.TrimSpace(data) == "" { + return + } + var chunk map[string]any + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + return + } + // Top-level response_message_id + if id := intFrom(chunk["response_message_id"]); id > 0 { + s.responseMessageID = id + } + // Path-based status: {"p": "response/status", "v": "FINISHED"} + if p, _ := chunk["p"].(string); p == "response/status" { + if status, _ := chunk["v"].(string); status != "" { + s.lastStatus = strings.TrimSpace(status) + if strings.EqualFold(s.lastStatus, "FINISHED") { + s.finished = true + } + } + } + // Nested v.response + v, _ := chunk["v"].(map[string]any) + if response, _ := v["response"].(map[string]any); response != nil { + if id := intFrom(response["message_id"]); id > 0 { + s.responseMessageID = id + } + if status, _ := response["status"].(string); status != "" { + s.lastStatus = strings.TrimSpace(status) + if strings.EqualFold(s.lastStatus, "FINISHED") { + s.finished = true + } + } + if autoContinue, ok := response["auto_continue"].(bool); ok && autoContinue { + s.lastStatus = "AUTO_CONTINUE" + } + } + // Nested message.response + if message, _ := chunk["message"].(map[string]any); message != nil { + if response, _ := message["response"].(map[string]any); response != nil { + if id := intFrom(response["message_id"]); id > 0 { + s.responseMessageID = id + } + if status, _ := response["status"].(string); status != "" { + s.lastStatus = strings.TrimSpace(status) + if strings.EqualFold(s.lastStatus, "FINISHED") { + s.finished = true + } + } + } + } +} + +// shouldContinue returns true when the upstream indicates the response is +// not yet finished and we have enough information to issue a continue request. +func (s *continueState) shouldContinue() bool { + if s == nil { + return false + } + if s.finished || s.responseMessageID <= 0 || strings.TrimSpace(s.sessionID) == "" { + return false + } + switch strings.ToUpper(strings.TrimSpace(s.lastStatus)) { + case "WIP", "INCOMPLETE", "AUTO_CONTINUE": + return true + default: + return false + } +} + +// prepareForNextRound resets ephemeral state before processing the next +// continuation stream. +func (s *continueState) prepareForNextRound() { + if s == nil { + return + } + s.finished = false + s.lastStatus = "" +} diff --git a/internal/deepseek/constants.go b/internal/deepseek/constants.go index f35332a..bd7c858 100644 --- a/internal/deepseek/constants.go +++ b/internal/deepseek/constants.go @@ -11,6 +11,7 @@ const ( DeepSeekCreateSessionURL = "https://chat.deepseek.com/api/v0/chat_session/create" DeepSeekCreatePowURL = "https://chat.deepseek.com/api/v0/chat/create_pow_challenge" DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion" + DeepSeekContinueURL = "https://chat.deepseek.com/api/v0/chat/continue" DeepSeekFetchSessionURL = "https://chat.deepseek.com/api/v0/chat_session/fetch_page" DeepSeekDeleteSessionURL = "https://chat.deepseek.com/api/v0/chat_session/delete" DeepSeekDeleteAllSessionsURL = "https://chat.deepseek.com/api/v0/chat_session/delete_all"