From f6cd541c6f3236512c00f8db168d40214de03764 Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Thu, 2 Apr 2026 02:04:58 +0800 Subject: [PATCH 1/2] auth: retry other managed accounts when token ensure fails --- internal/auth/auth_edge_test.go | 39 ++++++++++++++++++ internal/auth/request.go | 70 ++++++++++++++++++++++----------- internal/auth/request_test.go | 64 ++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 23 deletions(-) diff --git a/internal/auth/auth_edge_test.go b/internal/auth/auth_edge_test.go index 55c46ef..929b753 100644 --- a/internal/auth/auth_edge_test.go +++ b/internal/auth/auth_edge_test.go @@ -204,6 +204,45 @@ func TestSwitchAccountNilTriedAccounts(t *testing.T) { r.Release(a) } +func TestSwitchAccountSkipsLoginFailureAndContinues(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[ + {"email":"acc1@test.com","password":"pwd","token":"t1"}, + {"email":"acc2@test.com","password":"pwd"}, + {"email":"acc3@test.com","password":"pwd","token":"t3"} + ] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, acc config.Account) (string, error) { + if acc.Email == "acc2@test.com" { + return "", errors.New("login failed") + } + return "new-token", nil + }) + + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + if a.AccountID != "acc1@test.com" { + t.Fatalf("expected first account, got %q", a.AccountID) + } + if !r.SwitchAccount(context.Background(), a) { + t.Fatal("expected switch to succeed after skipping failed account") + } + if a.AccountID != "acc3@test.com" { + t.Fatalf("expected fallback to third account, got %q", a.AccountID) + } + if !a.TriedAccounts["acc2@test.com"] { + t.Fatalf("expected failed account to be marked as tried") + } +} + // ─── Release edge cases ───────────────────────────────────────────── func TestReleaseNilAuth(t *testing.T) { diff --git a/internal/auth/request.go b/internal/auth/request.go index fa39c61..9a147d2 100644 --- a/internal/auth/request.go +++ b/internal/auth/request.go @@ -70,25 +70,45 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { }, nil } target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account")) - acc, ok := r.Pool.AcquireWait(ctx, target, nil) - if !ok { - return nil, ErrNoAccount - } - a := &RequestAuth{ - UseConfigToken: true, - CallerID: callerID, - AccountID: acc.Identifier(), - Account: acc, - TriedAccounts: map[string]bool{}, - resolver: r, - } - if err := r.ensureManagedToken(ctx, a); err != nil { - r.Pool.Release(a.AccountID) + a, err := r.acquireManagedRequestAuth(ctx, callerID, target) + if err != nil { return nil, err } return a, nil } +func (r *Resolver) acquireManagedRequestAuth(ctx context.Context, callerID, target string) (*RequestAuth, error) { + tried := map[string]bool{} + for { + if target == "" && len(tried) >= len(r.Store.Accounts()) { + return nil, ErrNoAccount + } + acc, ok := r.Pool.AcquireWait(ctx, target, tried) + if !ok { + return nil, ErrNoAccount + } + + a := &RequestAuth{ + UseConfigToken: true, + CallerID: callerID, + AccountID: acc.Identifier(), + Account: acc, + TriedAccounts: tried, + resolver: r, + } + + if err := r.ensureManagedToken(ctx, a); err != nil { + tried[a.AccountID] = true + r.Pool.Release(a.AccountID) + if target != "" { + return nil, err + } + continue + } + return a, nil + } +} + // DetermineCaller resolves caller identity without acquiring any pooled account. // Use this for local-cache lookup routes that only need tenant isolation. func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) { @@ -164,16 +184,20 @@ func (r *Resolver) SwitchAccount(ctx context.Context, a *RequestAuth) bool { a.TriedAccounts[a.AccountID] = true r.Pool.Release(a.AccountID) } - acc, ok := r.Pool.Acquire("", a.TriedAccounts) - if !ok { - return false + for { + acc, ok := r.Pool.Acquire("", a.TriedAccounts) + if !ok { + return false + } + a.Account = acc + a.AccountID = acc.Identifier() + if err := r.ensureManagedToken(ctx, a); err != nil { + a.TriedAccounts[a.AccountID] = true + r.Pool.Release(a.AccountID) + continue + } + return true } - a.Account = acc - a.AccountID = acc.Identifier() - if err := r.ensureManagedToken(ctx, a); err != nil { - return false - } - return true } func (r *Resolver) Release(a *RequestAuth) { diff --git a/internal/auth/request_test.go b/internal/auth/request_test.go index eab97a4..d8f36b3 100644 --- a/internal/auth/request_test.go +++ b/internal/auth/request_test.go @@ -2,6 +2,7 @@ package auth import ( "context" + "errors" "net/http" "sync/atomic" "testing" @@ -301,3 +302,66 @@ func TestDetermineManagedAccountUsesUpdatedRefreshInterval(t *testing.T) { t.Fatalf("expected exactly one login after runtime update, got %d", got) } } + +func TestDetermineManagedAccountRetriesOtherAccountOnLoginFailure(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[ + {"email":"bad@example.com","password":"pwd"}, + {"email":"good@example.com","password":"pwd","token":"good-token"} + ] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + resolver := NewResolver(store, pool, func(_ context.Context, acc config.Account) (string, error) { + if acc.Email == "bad@example.com" { + return "", errors.New("stale account") + } + return "fresh-good-token", nil + }) + + req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + req.Header.Set("x-api-key", "managed-key") + + a, err := resolver.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer resolver.Release(a) + if a.AccountID != "good@example.com" { + t.Fatalf("expected fallback to good account, got %q", a.AccountID) + } + if a.DeepSeekToken == "" { + t.Fatal("expected non-empty token from fallback account") + } + if !a.TriedAccounts["bad@example.com"] { + t.Fatalf("expected bad account to be tracked as tried") + } +} + +func TestDetermineTargetAccountDoesNotFallbackOnLoginFailure(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[ + {"email":"bad@example.com","password":"pwd"}, + {"email":"good@example.com","password":"pwd","token":"good-token"} + ] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + resolver := NewResolver(store, pool, func(_ context.Context, acc config.Account) (string, error) { + if acc.Email == "bad@example.com" { + return "", errors.New("stale account") + } + return "fresh-good-token", nil + }) + + req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + req.Header.Set("x-api-key", "managed-key") + req.Header.Set("X-Ds2-Target-Account", "bad@example.com") + + _, err := resolver.Determine(req) + if err == nil { + t.Fatal("expected determine to fail for broken target account") + } +} From e60738b084a2a8c9a3e442a3946f6d95a121d828 Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Thu, 2 Apr 2026 12:58:09 +0800 Subject: [PATCH 2/2] auth: preserve ensure error after retry exhaustion --- internal/auth/request.go | 8 ++++++++ internal/auth/request_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/internal/auth/request.go b/internal/auth/request.go index 9a147d2..e6a0d88 100644 --- a/internal/auth/request.go +++ b/internal/auth/request.go @@ -79,12 +79,19 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { func (r *Resolver) acquireManagedRequestAuth(ctx context.Context, callerID, target string) (*RequestAuth, error) { tried := map[string]bool{} + var lastEnsureErr error for { if target == "" && len(tried) >= len(r.Store.Accounts()) { + if lastEnsureErr != nil { + return nil, lastEnsureErr + } return nil, ErrNoAccount } acc, ok := r.Pool.AcquireWait(ctx, target, tried) if !ok { + if lastEnsureErr != nil { + return nil, lastEnsureErr + } return nil, ErrNoAccount } @@ -98,6 +105,7 @@ func (r *Resolver) acquireManagedRequestAuth(ctx context.Context, callerID, targ } if err := r.ensureManagedToken(ctx, a); err != nil { + lastEnsureErr = err tried[a.AccountID] = true r.Pool.Release(a.AccountID) if target != "" { diff --git a/internal/auth/request_test.go b/internal/auth/request_test.go index d8f36b3..edf163d 100644 --- a/internal/auth/request_test.go +++ b/internal/auth/request_test.go @@ -365,3 +365,33 @@ func TestDetermineTargetAccountDoesNotFallbackOnLoginFailure(t *testing.T) { t.Fatal("expected determine to fail for broken target account") } } + +func TestDetermineManagedAccountReturnsLastEnsureErrorWhenAllFail(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[ + {"email":"bad1@example.com","password":"pwd"}, + {"email":"bad2@example.com","password":"pwd"} + ] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + ensureErr := errors.New("all credentials stale") + resolver := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "", ensureErr + }) + + req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + req.Header.Set("x-api-key", "managed-key") + + _, err := resolver.Determine(req) + if err == nil { + t.Fatal("expected determine to fail") + } + if !errors.Is(err, ensureErr) { + t.Fatalf("expected ensure error, got %v", err) + } + if errors.Is(err, ErrNoAccount) { + t.Fatalf("expected auth-style ensure error, got ErrNoAccount") + } +}