mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-10 11:17:41 +08:00
feat: implement managed-account rotation on 429 empty-output completion retries
This commit is contained in:
@@ -104,6 +104,7 @@ func ExecuteNonStreamStartedWithRetry(ctx context.Context, ds DeepSeekCaller, a
|
||||
pow := start.Pow
|
||||
|
||||
attempts := 0
|
||||
accountSwitchAttempted := false
|
||||
currentResp := start.Response
|
||||
usagePrompt := stdReq.PromptTokenText
|
||||
accumulatedThinking := ""
|
||||
@@ -112,6 +113,24 @@ func ExecuteNonStreamStartedWithRetry(ctx context.Context, ds DeepSeekCaller, a
|
||||
for {
|
||||
turn, outErr := collectAttempt(currentResp, stdReq, usagePrompt, opts)
|
||||
if outErr != nil {
|
||||
if canRetryOnAlternateAccount(ctx, a, outErr, opts.RetryEnabled, &accountSwitchAttempted) {
|
||||
switched, switchErr := startStandardCompletionOnAlternateAccount(ctx, ds, a, stdReq, maxAttempts)
|
||||
if switchErr != nil {
|
||||
return NonStreamResult{SessionID: sessionID, Payload: payload, Attempts: attempts}, switchErr
|
||||
}
|
||||
if switched.Response != nil {
|
||||
config.Logger.Info("[completion_runtime_account_switch_retry] retrying after 429", "surface", stdReq.Surface, "stream", false, "account", a.AccountID)
|
||||
sessionID = switched.SessionID
|
||||
payload = switched.Payload
|
||||
pow = switched.Pow
|
||||
currentResp = switched.Response
|
||||
usagePrompt = stdReq.PromptTokenText
|
||||
accumulatedThinking = ""
|
||||
accumulatedRawThinking = ""
|
||||
accumulatedToolDetectionThinking = ""
|
||||
continue
|
||||
}
|
||||
}
|
||||
return NonStreamResult{SessionID: sessionID, Payload: payload, Attempts: attempts}, outErr
|
||||
}
|
||||
accumulatedThinking += sse.TrimContinuationOverlap(accumulatedThinking, turn.Thinking)
|
||||
@@ -134,6 +153,24 @@ func ExecuteNonStreamStartedWithRetry(ctx context.Context, ds DeepSeekCaller, a
|
||||
retryMax = shared.EmptyOutputRetryMaxAttempts()
|
||||
}
|
||||
if !opts.RetryEnabled || !assistantturn.ShouldRetryEmptyOutput(turn, attempts, retryMax) {
|
||||
if canRetryOnAlternateAccount(ctx, a, turn.Error, opts.RetryEnabled, &accountSwitchAttempted) {
|
||||
switched, switchErr := startStandardCompletionOnAlternateAccount(ctx, ds, a, stdReq, maxAttempts)
|
||||
if switchErr != nil {
|
||||
return NonStreamResult{SessionID: sessionID, Payload: payload, Turn: turn, Attempts: attempts}, switchErr
|
||||
}
|
||||
if switched.Response != nil {
|
||||
config.Logger.Info("[completion_runtime_account_switch_retry] retrying after 429", "surface", stdReq.Surface, "stream", false, "account", a.AccountID)
|
||||
sessionID = switched.SessionID
|
||||
payload = switched.Payload
|
||||
pow = switched.Pow
|
||||
currentResp = switched.Response
|
||||
usagePrompt = stdReq.PromptTokenText
|
||||
accumulatedThinking = ""
|
||||
accumulatedRawThinking = ""
|
||||
accumulatedToolDetectionThinking = ""
|
||||
continue
|
||||
}
|
||||
}
|
||||
return NonStreamResult{SessionID: sessionID, Payload: payload, Turn: turn, Attempts: attempts}, turn.Error
|
||||
}
|
||||
|
||||
@@ -154,6 +191,37 @@ func ExecuteNonStreamStartedWithRetry(ctx context.Context, ds DeepSeekCaller, a
|
||||
}
|
||||
}
|
||||
|
||||
func canRetryOnAlternateAccount(ctx context.Context, a *auth.RequestAuth, outErr *assistantturn.OutputError, retryEnabled bool, attempted *bool) bool {
|
||||
if outErr == nil || outErr.Status != http.StatusTooManyRequests {
|
||||
return false
|
||||
}
|
||||
if !retryEnabled || attempted == nil || *attempted {
|
||||
return false
|
||||
}
|
||||
if a == nil || !a.UseConfigToken {
|
||||
return false
|
||||
}
|
||||
*attempted = true
|
||||
return a.SwitchAccount(ctx)
|
||||
}
|
||||
|
||||
func startStandardCompletionOnAlternateAccount(ctx context.Context, ds DeepSeekCaller, a *auth.RequestAuth, stdReq promptcompat.StandardRequest, maxAttempts int) (StartResult, *assistantturn.OutputError) {
|
||||
sessionID, err := ds.CreateSession(ctx, a, maxAttempts)
|
||||
if err != nil {
|
||||
return StartResult{}, authOutputError(a)
|
||||
}
|
||||
pow, err := ds.GetPow(ctx, a, maxAttempts)
|
||||
if err != nil {
|
||||
return StartResult{SessionID: sessionID}, &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: "Failed to get PoW (invalid token or unknown error).", Code: "error"}
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := ds.CallCompletion(ctx, a, payload, pow, maxAttempts)
|
||||
if err != nil {
|
||||
return StartResult{SessionID: sessionID, Payload: payload, Pow: pow}, &assistantturn.OutputError{Status: http.StatusInternalServerError, Message: "Failed to get completion.", Code: "error"}
|
||||
}
|
||||
return StartResult{SessionID: sessionID, Payload: payload, Pow: pow, Response: resp, Request: stdReq}, nil
|
||||
}
|
||||
|
||||
func collectAttempt(resp *http.Response, stdReq promptcompat.StandardRequest, usagePrompt string, opts Options) (assistantturn.Turn, *assistantturn.OutputError) {
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
|
||||
@@ -7,15 +7,19 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
dsclient "ds2api/internal/deepseek/client"
|
||||
"ds2api/internal/promptcompat"
|
||||
)
|
||||
|
||||
type fakeDeepSeekCaller struct {
|
||||
responses []*http.Response
|
||||
payloads []map[string]any
|
||||
uploads []dsclient.UploadFileRequest
|
||||
responses []*http.Response
|
||||
payloads []map[string]any
|
||||
uploads []dsclient.UploadFileRequest
|
||||
completionAccounts []string
|
||||
sessionByAccount bool
|
||||
}
|
||||
|
||||
type currentInputRuntimeConfig struct{}
|
||||
@@ -23,7 +27,10 @@ type currentInputRuntimeConfig struct{}
|
||||
func (currentInputRuntimeConfig) CurrentInputFileEnabled() bool { return true }
|
||||
func (currentInputRuntimeConfig) CurrentInputFileMinChars() int { return 0 }
|
||||
|
||||
func (f *fakeDeepSeekCaller) CreateSession(context.Context, *auth.RequestAuth, int) (string, error) {
|
||||
func (f *fakeDeepSeekCaller) CreateSession(_ context.Context, a *auth.RequestAuth, _ int) (string, error) {
|
||||
if f.sessionByAccount && a != nil && a.AccountID != "" {
|
||||
return "session-" + a.AccountID, nil
|
||||
}
|
||||
return "session-1", nil
|
||||
}
|
||||
|
||||
@@ -36,8 +43,11 @@ func (f *fakeDeepSeekCaller) UploadFile(_ context.Context, _ *auth.RequestAuth,
|
||||
return &dsclient.UploadFileResult{ID: "file-runtime-1"}, nil
|
||||
}
|
||||
|
||||
func (f *fakeDeepSeekCaller) CallCompletion(_ context.Context, _ *auth.RequestAuth, payload map[string]any, _ string, _ int) (*http.Response, error) {
|
||||
func (f *fakeDeepSeekCaller) CallCompletion(_ context.Context, a *auth.RequestAuth, payload map[string]any, _ string, _ int) (*http.Response, error) {
|
||||
f.payloads = append(f.payloads, payload)
|
||||
if a != nil {
|
||||
f.completionAccounts = append(f.completionAccounts, a.AccountID)
|
||||
}
|
||||
if len(f.responses) == 0 {
|
||||
return sseHTTPResponse(http.StatusOK, `data: {"p":"response/content","v":"fallback"}`), nil
|
||||
}
|
||||
@@ -89,6 +99,69 @@ func TestExecuteNonStreamWithRetryBuildsCanonicalTurn(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteNonStreamWithRetrySwitchesManagedAccountBeforeFinal429(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["managed-key"],
|
||||
"accounts":[
|
||||
{"email":"acc1@test.com","password":"pwd"},
|
||||
{"email":"acc2@test.com","password":"pwd"}
|
||||
]
|
||||
}`)
|
||||
store := config.LoadStore()
|
||||
resolver := auth.NewResolver(store, account.NewPool(store), func(_ context.Context, acc config.Account) (string, error) {
|
||||
return "token-" + acc.Identifier(), nil
|
||||
})
|
||||
req, _ := http.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer managed-key")
|
||||
a, err := resolver.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine failed: %v", err)
|
||||
}
|
||||
defer resolver.Release(a)
|
||||
|
||||
ds := &fakeDeepSeekCaller{
|
||||
sessionByAccount: true,
|
||||
responses: []*http.Response{
|
||||
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":11,"p":"response/thinking_content","v":"first empty"}`),
|
||||
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":12,"p":"response/thinking_content","v":"retry empty"}`),
|
||||
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":21,"p":"response/content","v":"ok from second account"}`),
|
||||
},
|
||||
}
|
||||
stdReq := promptcompat.StandardRequest{
|
||||
Surface: "test",
|
||||
ResponseModel: "deepseek-v4-flash",
|
||||
PromptTokenText: "prompt",
|
||||
FinalPrompt: "final prompt",
|
||||
Thinking: true,
|
||||
}
|
||||
|
||||
result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, a, stdReq, Options{RetryEnabled: true})
|
||||
if outErr != nil {
|
||||
t.Fatalf("unexpected output error after account switch retry: %#v", outErr)
|
||||
}
|
||||
if result.Turn.Text != "ok from second account" {
|
||||
t.Fatalf("text mismatch after switch retry: %q", result.Turn.Text)
|
||||
}
|
||||
if result.SessionID != "session-acc2@test.com" {
|
||||
t.Fatalf("expected switched account session, got %q", result.SessionID)
|
||||
}
|
||||
wantAccounts := []string{"acc1@test.com", "acc1@test.com", "acc2@test.com"}
|
||||
if len(ds.completionAccounts) != len(wantAccounts) {
|
||||
t.Fatalf("completion account count mismatch: got %v want %v", ds.completionAccounts, wantAccounts)
|
||||
}
|
||||
for i, want := range wantAccounts {
|
||||
if ds.completionAccounts[i] != want {
|
||||
t.Fatalf("completion account %d = %q want %q (all=%v)", i, ds.completionAccounts[i], want, ds.completionAccounts)
|
||||
}
|
||||
}
|
||||
if got := ds.payloads[2]["chat_session_id"]; got != "session-acc2@test.com" {
|
||||
t.Fatalf("switched payload session mismatch: %#v", got)
|
||||
}
|
||||
if prompt, _ := ds.payloads[2]["prompt"].(string); strings.Contains(prompt, "Previous reply had no visible output") {
|
||||
t.Fatalf("expected fresh switched-account prompt without empty-output suffix, got %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteNonStreamWithRetryUsesParentMessageForEmptyRetry(t *testing.T) {
|
||||
ds := &fakeDeepSeekCaller{responses: []*http.Response{
|
||||
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":77,"p":"response/thinking_content","v":"plan"}`),
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/assistantturn"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/httpapi/openai/shared"
|
||||
@@ -27,6 +28,7 @@ type StreamRetryHooks struct {
|
||||
OnRetry func(attempts int)
|
||||
OnRetryPrompt func(prompt string)
|
||||
OnRetryFailure func(status int, message, code string)
|
||||
OnAccountSwitch func(sessionID string)
|
||||
OnTerminal func(attempts int)
|
||||
}
|
||||
|
||||
@@ -48,16 +50,48 @@ func ExecuteStreamWithRetry(ctx context.Context, ds DeepSeekCaller, a *auth.Requ
|
||||
}
|
||||
|
||||
attempts := 0
|
||||
accountSwitchAttempted := false
|
||||
currentResp := initialResp
|
||||
currentPayload := clonePayload(payload)
|
||||
for {
|
||||
terminalWritten, retryable := hooks.ConsumeAttempt(currentResp, opts.RetryEnabled && attempts < retryMax)
|
||||
allowAccountSwitch := opts.RetryEnabled && attempts >= retryMax && !accountSwitchAttempted && a != nil && a.UseConfigToken
|
||||
terminalWritten, retryable := hooks.ConsumeAttempt(currentResp, opts.RetryEnabled && (attempts < retryMax || allowAccountSwitch))
|
||||
if terminalWritten {
|
||||
if hooks.OnTerminal != nil {
|
||||
hooks.OnTerminal(attempts)
|
||||
}
|
||||
return
|
||||
}
|
||||
if !retryable || !opts.RetryEnabled || attempts >= retryMax {
|
||||
if !retryable || !opts.RetryEnabled {
|
||||
if hooks.Finalize != nil {
|
||||
hooks.Finalize(attempts)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if attempts >= retryMax {
|
||||
if canRetryOnAlternateAccount(ctx, a, &assistantturn.OutputError{Status: http.StatusTooManyRequests}, opts.RetryEnabled, &accountSwitchAttempted) {
|
||||
switched, switchErr := startPayloadCompletionOnAlternateAccount(ctx, ds, a, payload, maxAttempts)
|
||||
if switchErr != nil {
|
||||
if hooks.OnRetryFailure != nil {
|
||||
hooks.OnRetryFailure(switchErr.Status, switchErr.Message, switchErr.Code)
|
||||
}
|
||||
return
|
||||
}
|
||||
if switched.Response != nil {
|
||||
config.Logger.Info("[completion_runtime_account_switch_retry] retrying after 429", "surface", surface, "stream", opts.Stream, "account", a.AccountID)
|
||||
currentResp = switched.Response
|
||||
currentPayload = switched.Payload
|
||||
pow = switched.Pow
|
||||
if hooks.OnAccountSwitch != nil {
|
||||
hooks.OnAccountSwitch(switched.SessionID)
|
||||
}
|
||||
if hooks.OnRetryPrompt != nil {
|
||||
hooks.OnRetryPrompt(opts.UsagePrompt)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if hooks.Finalize != nil {
|
||||
hooks.Finalize(attempts)
|
||||
}
|
||||
@@ -75,7 +109,7 @@ func ExecuteStreamWithRetry(ctx context.Context, ds DeepSeekCaller, a *auth.Requ
|
||||
config.Logger.Warn("[completion_runtime_empty_retry] retry PoW fetch failed, falling back to original PoW", "surface", surface, "stream", opts.Stream, "retry_attempt", attempts, "error", powErr)
|
||||
retryPow = pow
|
||||
}
|
||||
nextResp, err := ds.CallCompletion(ctx, a, shared.ClonePayloadForEmptyOutputRetry(payload, parentMessageID), retryPow, maxAttempts)
|
||||
nextResp, err := ds.CallCompletion(ctx, a, shared.ClonePayloadForEmptyOutputRetry(currentPayload, parentMessageID), retryPow, maxAttempts)
|
||||
if err != nil {
|
||||
if hooks.OnRetryFailure != nil {
|
||||
hooks.OnRetryFailure(http.StatusInternalServerError, "Failed to get completion.", "error")
|
||||
@@ -108,6 +142,33 @@ func ExecuteStreamWithRetry(ctx context.Context, ds DeepSeekCaller, a *auth.Requ
|
||||
}
|
||||
}
|
||||
|
||||
func startPayloadCompletionOnAlternateAccount(ctx context.Context, ds DeepSeekCaller, a *auth.RequestAuth, payload map[string]any, maxAttempts int) (StartResult, *assistantturn.OutputError) {
|
||||
sessionID, err := ds.CreateSession(ctx, a, maxAttempts)
|
||||
if err != nil {
|
||||
return StartResult{}, authOutputError(a)
|
||||
}
|
||||
pow, err := ds.GetPow(ctx, a, maxAttempts)
|
||||
if err != nil {
|
||||
return StartResult{SessionID: sessionID}, &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: "Failed to get PoW (invalid token or unknown error).", Code: "error"}
|
||||
}
|
||||
nextPayload := clonePayload(payload)
|
||||
nextPayload["chat_session_id"] = sessionID
|
||||
delete(nextPayload, "parent_message_id")
|
||||
resp, err := ds.CallCompletion(ctx, a, nextPayload, pow, maxAttempts)
|
||||
if err != nil {
|
||||
return StartResult{SessionID: sessionID, Payload: nextPayload, Pow: pow}, &assistantturn.OutputError{Status: http.StatusInternalServerError, Message: "Failed to get completion.", Code: "error"}
|
||||
}
|
||||
return StartResult{SessionID: sessionID, Payload: nextPayload, Pow: pow, Response: resp}, nil
|
||||
}
|
||||
|
||||
func clonePayload(payload map[string]any) map[string]any {
|
||||
clone := make(map[string]any, len(payload))
|
||||
for k, v := range payload {
|
||||
clone[k] = v
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func closeRetryBody(surface string, body io.Closer) {
|
||||
if body == nil {
|
||||
return
|
||||
|
||||
@@ -7,7 +7,9 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/httpapi/openai/shared"
|
||||
)
|
||||
|
||||
@@ -60,3 +62,89 @@ func TestExecuteStreamWithRetryUsesSharedRetryPayloadAndUsagePrompt(t *testing.T
|
||||
t.Fatalf("expected retry suffix in usage prompt, got %q", retryPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithRetrySwitchesManagedAccountBeforeFinal429(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["managed-key"],
|
||||
"accounts":[
|
||||
{"email":"acc1@test.com","password":"pwd"},
|
||||
{"email":"acc2@test.com","password":"pwd"}
|
||||
]
|
||||
}`)
|
||||
store := config.LoadStore()
|
||||
resolver := auth.NewResolver(store, account.NewPool(store), func(_ context.Context, acc config.Account) (string, error) {
|
||||
return "token-" + acc.Identifier(), nil
|
||||
})
|
||||
req, _ := http.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer managed-key")
|
||||
a, err := resolver.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine failed: %v", err)
|
||||
}
|
||||
defer resolver.Release(a)
|
||||
|
||||
ds := &fakeDeepSeekCaller{
|
||||
sessionByAccount: true,
|
||||
responses: []*http.Response{
|
||||
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":12,"p":"response/thinking_content","v":"retry empty"}`),
|
||||
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":21,"p":"response/content","v":"ok from second account"}`),
|
||||
},
|
||||
}
|
||||
initial := sseHTTPResponse(http.StatusOK, `data: {"response_message_id":11,"p":"response/thinking_content","v":"first empty"}`)
|
||||
payload := map[string]any{"prompt": "original prompt", "chat_session_id": "session-acc1@test.com"}
|
||||
attemptsSeen := 0
|
||||
switchedSession := ""
|
||||
|
||||
ExecuteStreamWithRetry(context.Background(), ds, a, initial, payload, "pow", StreamRetryOptions{
|
||||
Surface: "test.stream",
|
||||
Stream: true,
|
||||
RetryEnabled: true,
|
||||
RetryMaxAttempts: 1,
|
||||
UsagePrompt: "original prompt",
|
||||
}, StreamRetryHooks{
|
||||
ConsumeAttempt: func(resp *http.Response, allowDeferEmpty bool) (bool, bool) {
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
t.Fatalf("close failed: %v", err)
|
||||
}
|
||||
}()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
attemptsSeen++
|
||||
if strings.Contains(string(body), "ok from second account") {
|
||||
return true, false
|
||||
}
|
||||
if !allowDeferEmpty {
|
||||
t.Fatalf("expected empty attempt %d to be deferred before final 429", attemptsSeen)
|
||||
}
|
||||
return false, true
|
||||
},
|
||||
ParentMessageID: func() int {
|
||||
return 11 + attemptsSeen
|
||||
},
|
||||
OnAccountSwitch: func(sessionID string) {
|
||||
switchedSession = sessionID
|
||||
},
|
||||
})
|
||||
|
||||
if attemptsSeen != 3 {
|
||||
t.Fatalf("expected three stream attempts, got %d", attemptsSeen)
|
||||
}
|
||||
if switchedSession != "session-acc2@test.com" {
|
||||
t.Fatalf("expected switched session id, got %q", switchedSession)
|
||||
}
|
||||
wantAccounts := []string{"acc1@test.com", "acc2@test.com"}
|
||||
if len(ds.completionAccounts) != len(wantAccounts) {
|
||||
t.Fatalf("completion accounts mismatch: got %v want %v", ds.completionAccounts, wantAccounts)
|
||||
}
|
||||
for i, want := range wantAccounts {
|
||||
if ds.completionAccounts[i] != want {
|
||||
t.Fatalf("completion account %d = %q want %q (all=%v)", i, ds.completionAccounts[i], want, ds.completionAccounts)
|
||||
}
|
||||
}
|
||||
if got := ds.payloads[1]["chat_session_id"]; got != "session-acc2@test.com" {
|
||||
t.Fatalf("switched payload session mismatch: %#v", got)
|
||||
}
|
||||
if prompt, _ := ds.payloads[1]["prompt"].(string); strings.Contains(prompt, shared.EmptyOutputRetrySuffix) {
|
||||
t.Fatalf("expected switched-account prompt without empty-output suffix, got %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user