From 0b0cf60982be019c5a5b5b7eb19a34bacae40280 Mon Sep 17 00:00:00 2001 From: CJACK Date: Sun, 5 Apr 2026 14:33:09 +0800 Subject: [PATCH] feat: propagate Proof-of-Work header to auto-continue requests and ensure remote session deletion ignores parent context cancellation --- internal/adapter/openai/handler_chat.go | 3 +- .../openai/handler_chat_auto_delete_test.go | 44 ++++++ internal/deepseek/client_completion.go | 4 +- internal/deepseek/client_continue.go | 7 +- internal/deepseek/client_continue_test.go | 137 ++++++++++++++++++ internal/deepseek/client_http_json.go | 4 +- 6 files changed, 191 insertions(+), 8 deletions(-) create mode 100644 internal/deepseek/client_continue_test.go diff --git a/internal/adapter/openai/handler_chat.go b/internal/adapter/openai/handler_chat.go index 58a7cb0..9c2924f 100644 --- a/internal/adapter/openai/handler_chat.go +++ b/internal/adapter/openai/handler_chat.go @@ -87,7 +87,8 @@ func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAu return } - deleteCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + deleteBaseCtx := context.WithoutCancel(ctx) + deleteCtx, cancel := context.WithTimeout(deleteBaseCtx, 10*time.Second) defer cancel() switch mode { diff --git a/internal/adapter/openai/handler_chat_auto_delete_test.go b/internal/adapter/openai/handler_chat_auto_delete_test.go index fbeca15..0196db0 100644 --- a/internal/adapter/openai/handler_chat_auto_delete_test.go +++ b/internal/adapter/openai/handler_chat_auto_delete_test.go @@ -16,6 +16,7 @@ type autoDeleteModeDSStub struct { singleCalls int allCalls int lastSessionID string + lastCtxErr error } func (m *autoDeleteModeDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { @@ -41,6 +42,13 @@ func (m *autoDeleteModeDSStub) DeleteAllSessionsForToken(_ context.Context, _ st return nil } +func (m *autoDeleteModeDSStub) DeleteSessionForTokenCtx(ctx context.Context, _ string, sessionID string) (*deepseek.DeleteSessionResult, error) { + m.singleCalls++ + m.lastSessionID = sessionID + m.lastCtxErr = ctx.Err() + return &deepseek.DeleteSessionResult{SessionID: sessionID, Success: true}, nil +} + func TestChatCompletionsAutoDeleteModes(t *testing.T) { tests := []struct { name string @@ -93,3 +101,39 @@ func TestChatCompletionsAutoDeleteModes(t *testing.T) { }) } } + +type autoDeleteCtxDSStub struct { + autoDeleteModeDSStub +} + +func (m *autoDeleteCtxDSStub) DeleteSessionForToken(ctx context.Context, token string, sessionID string) (*deepseek.DeleteSessionResult, error) { + return m.autoDeleteModeDSStub.DeleteSessionForTokenCtx(ctx, token, sessionID) +} + +func (m *autoDeleteCtxDSStub) DeleteAllSessionsForToken(_ context.Context, _ string) error { + m.allCalls++ + return nil +} + +func TestAutoDeleteRemoteSessionIgnoresCanceledParentContext(t *testing.T) { + ds := &autoDeleteCtxDSStub{} + h := &Handler{ + Store: mockOpenAIConfig{ + wideInput: true, + autoDeleteMode: "single", + }, + DS: ds, + } + a := &auth.RequestAuth{DeepSeekToken: "token", AccountID: "acct"} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + h.autoDeleteRemoteSession(ctx, a, "session-id") + + if ds.singleCalls != 1 { + t.Fatalf("single delete calls=%d want=1", ds.singleCalls) + } + if ds.lastCtxErr != nil { + t.Fatalf("delete ctx should not inherit cancellation, got %v", ds.lastCtxErr) + } +} diff --git a/internal/deepseek/client_completion.go b/internal/deepseek/client_completion.go index d9b65b8..8f24cdd 100644 --- a/internal/deepseek/client_completion.go +++ b/internal/deepseek/client_completion.go @@ -31,7 +31,7 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payloa if captureSession != nil { resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode) } - resp = c.wrapCompletionWithAutoContinue(ctx, a, payload, resp) + resp = c.wrapCompletionWithAutoContinue(ctx, a, payload, powResp, resp) return resp, nil } if captureSession != nil { @@ -61,7 +61,7 @@ func (c *Client) streamPost(ctx context.Context, url string, headers map[string] config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err) req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) if reqErr != nil { - return nil, err + return nil, reqErr } for k, v := range headers { req2.Header.Set(k, v) diff --git a/internal/deepseek/client_continue.go b/internal/deepseek/client_continue.go index f5a33c6..605d9e5 100644 --- a/internal/deepseek/client_continue.go +++ b/internal/deepseek/client_continue.go @@ -30,7 +30,7 @@ type continueState struct { // AUTO_CONTINUE), ds2api will automatically call the DeepSeek continue // endpoint and splice the continuation SSE stream onto the original. // The caller sees a single, seamless SSE stream. -func (c *Client) wrapCompletionWithAutoContinue(ctx context.Context, a *auth.RequestAuth, payload map[string]any, resp *http.Response) *http.Response { +func (c *Client) wrapCompletionWithAutoContinue(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, resp *http.Response) *http.Response { if resp == nil || resp.Body == nil { return resp } @@ -41,17 +41,18 @@ func (c *Client) wrapCompletionWithAutoContinue(ctx context.Context, a *auth.Req } config.Logger.Debug("[auto_continue] wrapping completion response", "session_id", sessionID) resp.Body = newAutoContinueBody(ctx, resp.Body, sessionID, defaultAutoContinueLimit, func(ctx context.Context, sessionID string, responseMessageID int) (*http.Response, error) { - return c.callContinue(ctx, a, sessionID, responseMessageID) + return c.callContinue(ctx, a, sessionID, responseMessageID, powResp) }) return resp } // callContinue sends a continue request to DeepSeek to resume generation. -func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionID string, responseMessageID int) (*http.Response, error) { +func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionID string, responseMessageID int, powResp string) (*http.Response, error) { if strings.TrimSpace(sessionID) == "" || responseMessageID <= 0 { return nil, errors.New("missing continue identifiers") } headers := c.authHeaders(a.DeepSeekToken) + headers["x-ds-pow-response"] = powResp payload := map[string]any{ "chat_session_id": sessionID, "message_id": responseMessageID, diff --git a/internal/deepseek/client_continue_test.go b/internal/deepseek/client_continue_test.go new file mode 100644 index 0000000..68963e7 --- /dev/null +++ b/internal/deepseek/client_continue_test.go @@ -0,0 +1,137 @@ +package deepseek + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + + "ds2api/internal/auth" +) + +type failingDoer struct { + err error +} + +func (d failingDoer) Do(_ *http.Request) (*http.Response, error) { + return nil, d.err +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestCallContinuePropagatesPowHeaderToFallbackRequest(t *testing.T) { + var seenPow string + var seenURL string + + client := &Client{ + stream: failingDoer{err: errors.New("stream transport failed")}, + fallbackS: &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + seenPow = req.Header.Get("x-ds-pow-response") + seenURL = req.URL.String() + body := io.NopCloser(strings.NewReader("data: {\"p\":\"response/content\",\"v\":\"continued\"}\n" + "data: [DONE]\n")) + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: body, + Request: req, + }, nil + }), + }, + } + + resp, err := client.callContinue(context.Background(), &auth.RequestAuth{ + DeepSeekToken: "token", + AccountID: "acct", + }, "session-123", 99, "pow-response-abc") + if err != nil { + t.Fatalf("callContinue returned error: %v", err) + } + defer resp.Body.Close() + + if seenPow != "pow-response-abc" { + t.Fatalf("continue request pow header=%q want=%q", seenPow, "pow-response-abc") + } + if seenURL != DeepSeekContinueURL { + t.Fatalf("continue request url=%q want=%q", seenURL, DeepSeekContinueURL) + } +} + +func TestCallCompletionAutoContinueThreadsPowHeader(t *testing.T) { + var seenPow string + var seenContinueURL string + + initialBody := strings.Join([]string{ + `data: {"response_message_id":321,"v":{"response":{"message_id":321,"status":"WIP","auto_continue":true}}}`, + `data: [DONE]`, + }, "\n") + "\n" + + client := &Client{ + stream: failingOrCompletionDoer{ + completionResp: &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(initialBody)), + }, + }, + fallbackS: &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + seenPow = req.Header.Get("x-ds-pow-response") + seenContinueURL = req.URL.String() + body := io.NopCloser(strings.NewReader("data: {\"response_message_id\":322,\"v\":{\"response\":{\"message_id\":322,\"status\":\"FINISHED\"}}}\n" + "data: [DONE]\n")) + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: body, + Request: req, + }, nil + }), + }, + } + + resp, err := client.CallCompletion(context.Background(), &auth.RequestAuth{ + DeepSeekToken: "token", + AccountID: "acct", + }, map[string]any{ + "chat_session_id": "session-123", + }, "pow-response-xyz", 1) + if err != nil { + t.Fatalf("CallCompletion returned error: %v", err) + } + defer resp.Body.Close() + + out, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read auto-continued body failed: %v", err) + } + if seenPow != "pow-response-xyz" { + t.Fatalf("threaded continue pow header=%q want=%q", seenPow, "pow-response-xyz") + } + if seenContinueURL != DeepSeekContinueURL { + t.Fatalf("continue url=%q want=%q", seenContinueURL, DeepSeekContinueURL) + } + if !bytes.Contains(out, []byte(`"status":"WIP"`)) { + t.Fatalf("expected initial stream content in body, got=%s", string(out)) + } + if !bytes.Contains(out, []byte(`data: [DONE]`)) { + t.Fatalf("expected final DONE sentinel in body, got=%s", string(out)) + } +} + +type failingOrCompletionDoer struct { + completionResp *http.Response +} + +func (d failingOrCompletionDoer) Do(req *http.Request) (*http.Response, error) { + if strings.Contains(req.URL.Path, "/chat/completion") { + return d.completionResp, nil + } + return nil, errors.New("forced stream failure") +} diff --git a/internal/deepseek/client_http_json.go b/internal/deepseek/client_http_json.go index a35d736..9de0e57 100644 --- a/internal/deepseek/client_http_json.go +++ b/internal/deepseek/client_http_json.go @@ -39,7 +39,7 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url st config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err) req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) if reqErr != nil { - return nil, 0, err + return nil, 0, reqErr } for k, v := range headers { req2.Header.Set(k, v) @@ -76,7 +76,7 @@ func (c *Client) getJSONWithStatus(ctx context.Context, doer trans.Doer, url str config.Logger.Warn("[deepseek] fingerprint GET request failed, fallback to std transport", "url", url, "error", err) req2, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if reqErr != nil { - return nil, 0, err + return nil, 0, reqErr } for k, v := range headers { req2.Header.Set(k, v)