Merge pull request #184 from CJackHwang/codex/refactor-acquire-to-handle-empty-token-accounts

auth: retry other managed accounts when token ensure fails
This commit is contained in:
CJACK.
2026-04-02 13:00:18 +08:00
committed by GitHub
3 changed files with 188 additions and 23 deletions

View File

@@ -204,6 +204,45 @@ func TestSwitchAccountNilTriedAccounts(t *testing.T) {
r.Release(a)
}
func TestSwitchAccountSkipsLoginFailureAndContinues(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[
{"email":"acc1@test.com","password":"pwd","token":"t1"},
{"email":"acc2@test.com","password":"pwd"},
{"email":"acc3@test.com","password":"pwd","token":"t3"}
]
}`)
store := config.LoadStore()
pool := account.NewPool(store)
r := NewResolver(store, pool, func(_ context.Context, acc config.Account) (string, error) {
if acc.Email == "acc2@test.com" {
return "", errors.New("login failed")
}
return "new-token", nil
})
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Bearer managed-key")
a, err := r.Determine(req)
if err != nil {
t.Fatalf("determine failed: %v", err)
}
defer r.Release(a)
if a.AccountID != "acc1@test.com" {
t.Fatalf("expected first account, got %q", a.AccountID)
}
if !r.SwitchAccount(context.Background(), a) {
t.Fatal("expected switch to succeed after skipping failed account")
}
if a.AccountID != "acc3@test.com" {
t.Fatalf("expected fallback to third account, got %q", a.AccountID)
}
if !a.TriedAccounts["acc2@test.com"] {
t.Fatalf("expected failed account to be marked as tried")
}
}
// ─── Release edge cases ─────────────────────────────────────────────
func TestReleaseNilAuth(t *testing.T) {

View File

@@ -70,25 +70,53 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
}, nil
}
target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account"))
acc, ok := r.Pool.AcquireWait(ctx, target, nil)
if !ok {
return nil, ErrNoAccount
}
a := &RequestAuth{
UseConfigToken: true,
CallerID: callerID,
AccountID: acc.Identifier(),
Account: acc,
TriedAccounts: map[string]bool{},
resolver: r,
}
if err := r.ensureManagedToken(ctx, a); err != nil {
r.Pool.Release(a.AccountID)
a, err := r.acquireManagedRequestAuth(ctx, callerID, target)
if err != nil {
return nil, err
}
return a, nil
}
func (r *Resolver) acquireManagedRequestAuth(ctx context.Context, callerID, target string) (*RequestAuth, error) {
tried := map[string]bool{}
var lastEnsureErr error
for {
if target == "" && len(tried) >= len(r.Store.Accounts()) {
if lastEnsureErr != nil {
return nil, lastEnsureErr
}
return nil, ErrNoAccount
}
acc, ok := r.Pool.AcquireWait(ctx, target, tried)
if !ok {
if lastEnsureErr != nil {
return nil, lastEnsureErr
}
return nil, ErrNoAccount
}
a := &RequestAuth{
UseConfigToken: true,
CallerID: callerID,
AccountID: acc.Identifier(),
Account: acc,
TriedAccounts: tried,
resolver: r,
}
if err := r.ensureManagedToken(ctx, a); err != nil {
lastEnsureErr = err
tried[a.AccountID] = true
r.Pool.Release(a.AccountID)
if target != "" {
return nil, err
}
continue
}
return a, nil
}
}
// DetermineCaller resolves caller identity without acquiring any pooled account.
// Use this for local-cache lookup routes that only need tenant isolation.
func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) {
@@ -164,16 +192,20 @@ func (r *Resolver) SwitchAccount(ctx context.Context, a *RequestAuth) bool {
a.TriedAccounts[a.AccountID] = true
r.Pool.Release(a.AccountID)
}
acc, ok := r.Pool.Acquire("", a.TriedAccounts)
if !ok {
return false
for {
acc, ok := r.Pool.Acquire("", a.TriedAccounts)
if !ok {
return false
}
a.Account = acc
a.AccountID = acc.Identifier()
if err := r.ensureManagedToken(ctx, a); err != nil {
a.TriedAccounts[a.AccountID] = true
r.Pool.Release(a.AccountID)
continue
}
return true
}
a.Account = acc
a.AccountID = acc.Identifier()
if err := r.ensureManagedToken(ctx, a); err != nil {
return false
}
return true
}
func (r *Resolver) Release(a *RequestAuth) {

View File

@@ -2,6 +2,7 @@ package auth
import (
"context"
"errors"
"net/http"
"sync/atomic"
"testing"
@@ -301,3 +302,96 @@ func TestDetermineManagedAccountUsesUpdatedRefreshInterval(t *testing.T) {
t.Fatalf("expected exactly one login after runtime update, got %d", got)
}
}
func TestDetermineManagedAccountRetriesOtherAccountOnLoginFailure(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[
{"email":"bad@example.com","password":"pwd"},
{"email":"good@example.com","password":"pwd","token":"good-token"}
]
}`)
store := config.LoadStore()
pool := account.NewPool(store)
resolver := NewResolver(store, pool, func(_ context.Context, acc config.Account) (string, error) {
if acc.Email == "bad@example.com" {
return "", errors.New("stale account")
}
return "fresh-good-token", nil
})
req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
req.Header.Set("x-api-key", "managed-key")
a, err := resolver.Determine(req)
if err != nil {
t.Fatalf("determine failed: %v", err)
}
defer resolver.Release(a)
if a.AccountID != "good@example.com" {
t.Fatalf("expected fallback to good account, got %q", a.AccountID)
}
if a.DeepSeekToken == "" {
t.Fatal("expected non-empty token from fallback account")
}
if !a.TriedAccounts["bad@example.com"] {
t.Fatalf("expected bad account to be tracked as tried")
}
}
func TestDetermineTargetAccountDoesNotFallbackOnLoginFailure(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[
{"email":"bad@example.com","password":"pwd"},
{"email":"good@example.com","password":"pwd","token":"good-token"}
]
}`)
store := config.LoadStore()
pool := account.NewPool(store)
resolver := NewResolver(store, pool, func(_ context.Context, acc config.Account) (string, error) {
if acc.Email == "bad@example.com" {
return "", errors.New("stale account")
}
return "fresh-good-token", nil
})
req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
req.Header.Set("x-api-key", "managed-key")
req.Header.Set("X-Ds2-Target-Account", "bad@example.com")
_, err := resolver.Determine(req)
if err == nil {
t.Fatal("expected determine to fail for broken target account")
}
}
func TestDetermineManagedAccountReturnsLastEnsureErrorWhenAllFail(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[
{"email":"bad1@example.com","password":"pwd"},
{"email":"bad2@example.com","password":"pwd"}
]
}`)
store := config.LoadStore()
pool := account.NewPool(store)
ensureErr := errors.New("all credentials stale")
resolver := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
return "", ensureErr
})
req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
req.Header.Set("x-api-key", "managed-key")
_, err := resolver.Determine(req)
if err == nil {
t.Fatal("expected determine to fail")
}
if !errors.Is(err, ensureErr) {
t.Fatalf("expected ensure error, got %v", err)
}
if errors.Is(err, ErrNoAccount) {
t.Fatalf("expected auth-style ensure error, got ErrNoAccount")
}
}