mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 00:45:29 +08:00
feat: propagate Proof-of-Work header to auto-continue requests and ensure remote session deletion ignores parent context cancellation
This commit is contained in:
@@ -87,7 +87,8 @@ func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAu
|
||||
return
|
||||
}
|
||||
|
||||
deleteCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
deleteBaseCtx := context.WithoutCancel(ctx)
|
||||
deleteCtx, cancel := context.WithTimeout(deleteBaseCtx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch mode {
|
||||
|
||||
@@ -16,6 +16,7 @@ type autoDeleteModeDSStub struct {
|
||||
singleCalls int
|
||||
allCalls int
|
||||
lastSessionID string
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (m *autoDeleteModeDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||
@@ -41,6 +42,13 @@ func (m *autoDeleteModeDSStub) DeleteAllSessionsForToken(_ context.Context, _ st
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *autoDeleteModeDSStub) DeleteSessionForTokenCtx(ctx context.Context, _ string, sessionID string) (*deepseek.DeleteSessionResult, error) {
|
||||
m.singleCalls++
|
||||
m.lastSessionID = sessionID
|
||||
m.lastCtxErr = ctx.Err()
|
||||
return &deepseek.DeleteSessionResult{SessionID: sessionID, Success: true}, nil
|
||||
}
|
||||
|
||||
func TestChatCompletionsAutoDeleteModes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -93,3 +101,39 @@ func TestChatCompletionsAutoDeleteModes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type autoDeleteCtxDSStub struct {
|
||||
autoDeleteModeDSStub
|
||||
}
|
||||
|
||||
func (m *autoDeleteCtxDSStub) DeleteSessionForToken(ctx context.Context, token string, sessionID string) (*deepseek.DeleteSessionResult, error) {
|
||||
return m.autoDeleteModeDSStub.DeleteSessionForTokenCtx(ctx, token, sessionID)
|
||||
}
|
||||
|
||||
func (m *autoDeleteCtxDSStub) DeleteAllSessionsForToken(_ context.Context, _ string) error {
|
||||
m.allCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAutoDeleteRemoteSessionIgnoresCanceledParentContext(t *testing.T) {
|
||||
ds := &autoDeleteCtxDSStub{}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{
|
||||
wideInput: true,
|
||||
autoDeleteMode: "single",
|
||||
},
|
||||
DS: ds,
|
||||
}
|
||||
a := &auth.RequestAuth{DeepSeekToken: "token", AccountID: "acct"}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
h.autoDeleteRemoteSession(ctx, a, "session-id")
|
||||
|
||||
if ds.singleCalls != 1 {
|
||||
t.Fatalf("single delete calls=%d want=1", ds.singleCalls)
|
||||
}
|
||||
if ds.lastCtxErr != nil {
|
||||
t.Fatalf("delete ctx should not inherit cancellation, got %v", ds.lastCtxErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payloa
|
||||
if captureSession != nil {
|
||||
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
|
||||
}
|
||||
resp = c.wrapCompletionWithAutoContinue(ctx, a, payload, resp)
|
||||
resp = c.wrapCompletionWithAutoContinue(ctx, a, payload, powResp, resp)
|
||||
return resp, nil
|
||||
}
|
||||
if captureSession != nil {
|
||||
@@ -61,7 +61,7 @@ func (c *Client) streamPost(ctx context.Context, url string, headers map[string]
|
||||
config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err)
|
||||
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if reqErr != nil {
|
||||
return nil, err
|
||||
return nil, reqErr
|
||||
}
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
|
||||
@@ -30,7 +30,7 @@ type continueState struct {
|
||||
// AUTO_CONTINUE), ds2api will automatically call the DeepSeek continue
|
||||
// endpoint and splice the continuation SSE stream onto the original.
|
||||
// The caller sees a single, seamless SSE stream.
|
||||
func (c *Client) wrapCompletionWithAutoContinue(ctx context.Context, a *auth.RequestAuth, payload map[string]any, resp *http.Response) *http.Response {
|
||||
func (c *Client) wrapCompletionWithAutoContinue(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, resp *http.Response) *http.Response {
|
||||
if resp == nil || resp.Body == nil {
|
||||
return resp
|
||||
}
|
||||
@@ -41,17 +41,18 @@ func (c *Client) wrapCompletionWithAutoContinue(ctx context.Context, a *auth.Req
|
||||
}
|
||||
config.Logger.Debug("[auto_continue] wrapping completion response", "session_id", sessionID)
|
||||
resp.Body = newAutoContinueBody(ctx, resp.Body, sessionID, defaultAutoContinueLimit, func(ctx context.Context, sessionID string, responseMessageID int) (*http.Response, error) {
|
||||
return c.callContinue(ctx, a, sessionID, responseMessageID)
|
||||
return c.callContinue(ctx, a, sessionID, responseMessageID, powResp)
|
||||
})
|
||||
return resp
|
||||
}
|
||||
|
||||
// callContinue sends a continue request to DeepSeek to resume generation.
|
||||
func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionID string, responseMessageID int) (*http.Response, error) {
|
||||
func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionID string, responseMessageID int, powResp string) (*http.Response, error) {
|
||||
if strings.TrimSpace(sessionID) == "" || responseMessageID <= 0 {
|
||||
return nil, errors.New("missing continue identifiers")
|
||||
}
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
headers["x-ds-pow-response"] = powResp
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"message_id": responseMessageID,
|
||||
|
||||
137
internal/deepseek/client_continue_test.go
Normal file
137
internal/deepseek/client_continue_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
)
|
||||
|
||||
type failingDoer struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (d failingDoer) Do(_ *http.Request) (*http.Response, error) {
|
||||
return nil, d.err
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestCallContinuePropagatesPowHeaderToFallbackRequest(t *testing.T) {
|
||||
var seenPow string
|
||||
var seenURL string
|
||||
|
||||
client := &Client{
|
||||
stream: failingDoer{err: errors.New("stream transport failed")},
|
||||
fallbackS: &http.Client{
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
seenPow = req.Header.Get("x-ds-pow-response")
|
||||
seenURL = req.URL.String()
|
||||
body := io.NopCloser(strings.NewReader("data: {\"p\":\"response/content\",\"v\":\"continued\"}\n" + "data: [DONE]\n"))
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: body,
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.callContinue(context.Background(), &auth.RequestAuth{
|
||||
DeepSeekToken: "token",
|
||||
AccountID: "acct",
|
||||
}, "session-123", 99, "pow-response-abc")
|
||||
if err != nil {
|
||||
t.Fatalf("callContinue returned error: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if seenPow != "pow-response-abc" {
|
||||
t.Fatalf("continue request pow header=%q want=%q", seenPow, "pow-response-abc")
|
||||
}
|
||||
if seenURL != DeepSeekContinueURL {
|
||||
t.Fatalf("continue request url=%q want=%q", seenURL, DeepSeekContinueURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallCompletionAutoContinueThreadsPowHeader(t *testing.T) {
|
||||
var seenPow string
|
||||
var seenContinueURL string
|
||||
|
||||
initialBody := strings.Join([]string{
|
||||
`data: {"response_message_id":321,"v":{"response":{"message_id":321,"status":"WIP","auto_continue":true}}}`,
|
||||
`data: [DONE]`,
|
||||
}, "\n") + "\n"
|
||||
|
||||
client := &Client{
|
||||
stream: failingOrCompletionDoer{
|
||||
completionResp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(initialBody)),
|
||||
},
|
||||
},
|
||||
fallbackS: &http.Client{
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
seenPow = req.Header.Get("x-ds-pow-response")
|
||||
seenContinueURL = req.URL.String()
|
||||
body := io.NopCloser(strings.NewReader("data: {\"response_message_id\":322,\"v\":{\"response\":{\"message_id\":322,\"status\":\"FINISHED\"}}}\n" + "data: [DONE]\n"))
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: body,
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.CallCompletion(context.Background(), &auth.RequestAuth{
|
||||
DeepSeekToken: "token",
|
||||
AccountID: "acct",
|
||||
}, map[string]any{
|
||||
"chat_session_id": "session-123",
|
||||
}, "pow-response-xyz", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("CallCompletion returned error: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
out, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read auto-continued body failed: %v", err)
|
||||
}
|
||||
if seenPow != "pow-response-xyz" {
|
||||
t.Fatalf("threaded continue pow header=%q want=%q", seenPow, "pow-response-xyz")
|
||||
}
|
||||
if seenContinueURL != DeepSeekContinueURL {
|
||||
t.Fatalf("continue url=%q want=%q", seenContinueURL, DeepSeekContinueURL)
|
||||
}
|
||||
if !bytes.Contains(out, []byte(`"status":"WIP"`)) {
|
||||
t.Fatalf("expected initial stream content in body, got=%s", string(out))
|
||||
}
|
||||
if !bytes.Contains(out, []byte(`data: [DONE]`)) {
|
||||
t.Fatalf("expected final DONE sentinel in body, got=%s", string(out))
|
||||
}
|
||||
}
|
||||
|
||||
type failingOrCompletionDoer struct {
|
||||
completionResp *http.Response
|
||||
}
|
||||
|
||||
func (d failingOrCompletionDoer) Do(req *http.Request) (*http.Response, error) {
|
||||
if strings.Contains(req.URL.Path, "/chat/completion") {
|
||||
return d.completionResp, nil
|
||||
}
|
||||
return nil, errors.New("forced stream failure")
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url st
|
||||
config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err)
|
||||
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if reqErr != nil {
|
||||
return nil, 0, err
|
||||
return nil, 0, reqErr
|
||||
}
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
@@ -76,7 +76,7 @@ func (c *Client) getJSONWithStatus(ctx context.Context, doer trans.Doer, url str
|
||||
config.Logger.Warn("[deepseek] fingerprint GET request failed, fallback to std transport", "url", url, "error", err)
|
||||
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if reqErr != nil {
|
||||
return nil, 0, err
|
||||
return nil, 0, reqErr
|
||||
}
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
|
||||
Reference in New Issue
Block a user