package auth import ( "context" "crypto/sha256" "encoding/hex" "errors" "net/http" "strings" "ds2api/internal/account" "ds2api/internal/config" ) type ctxKey string const authCtxKey ctxKey = "auth_context" var ( ErrUnauthorized = errors.New("unauthorized: missing auth token") ErrNoAccount = errors.New("no accounts configured or all accounts are busy") ) type RequestAuth struct { UseConfigToken bool DeepSeekToken string CallerID string AccountID string Account config.Account TriedAccounts map[string]bool resolver *Resolver } type LoginFunc func(ctx context.Context, acc config.Account) (string, error) type Resolver struct { Store *config.Store Pool *account.Pool Login LoginFunc } func NewResolver(store *config.Store, pool *account.Pool, login LoginFunc) *Resolver { return &Resolver{Store: store, Pool: pool, Login: login} } func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { callerKey := extractCallerToken(req) if callerKey == "" { return nil, ErrUnauthorized } callerID := callerTokenID(callerKey) ctx := req.Context() if !r.Store.HasAPIKey(callerKey) { return &RequestAuth{ UseConfigToken: false, DeepSeekToken: callerKey, CallerID: callerID, resolver: r, TriedAccounts: map[string]bool{}, }, nil } target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account")) acc, ok := r.Pool.AcquireWait(ctx, target, nil) if !ok { return nil, ErrNoAccount } a := &RequestAuth{ UseConfigToken: true, CallerID: callerID, AccountID: acc.Identifier(), Account: acc, TriedAccounts: map[string]bool{}, resolver: r, } if acc.Token == "" { if err := r.loginAndPersist(ctx, a); err != nil { r.Pool.Release(a.AccountID) return nil, err } } else { a.DeepSeekToken = acc.Token } return a, nil } // DetermineCaller resolves caller identity without acquiring any pooled account. // Use this for local-cache lookup routes that only need tenant isolation. func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) { callerKey := extractCallerToken(req) if callerKey == "" { return nil, ErrUnauthorized } callerID := callerTokenID(callerKey) a := &RequestAuth{ UseConfigToken: false, CallerID: callerID, resolver: r, TriedAccounts: map[string]bool{}, } if r == nil || r.Store == nil || !r.Store.HasAPIKey(callerKey) { a.DeepSeekToken = callerKey } return a, nil } func WithAuth(ctx context.Context, a *RequestAuth) context.Context { return context.WithValue(ctx, authCtxKey, a) } func FromContext(ctx context.Context) (*RequestAuth, bool) { v := ctx.Value(authCtxKey) a, ok := v.(*RequestAuth) return a, ok } func (r *Resolver) loginAndPersist(ctx context.Context, a *RequestAuth) error { token, err := r.Login(ctx, a.Account) if err != nil { return err } a.Account.Token = token a.DeepSeekToken = token return r.Store.UpdateAccountToken(a.AccountID, token) } func (r *Resolver) RefreshToken(ctx context.Context, a *RequestAuth) bool { if !a.UseConfigToken || a.AccountID == "" { return false } _ = r.Store.UpdateAccountToken(a.AccountID, "") a.Account.Token = "" if err := r.loginAndPersist(ctx, a); err != nil { config.Logger.Error("[refresh_token] failed", "account", a.AccountID, "error", err) return false } return true } func (r *Resolver) MarkTokenInvalid(a *RequestAuth) { if !a.UseConfigToken || a.AccountID == "" { return } a.Account.Token = "" a.DeepSeekToken = "" _ = r.Store.UpdateAccountToken(a.AccountID, "") } func (r *Resolver) SwitchAccount(ctx context.Context, a *RequestAuth) bool { if !a.UseConfigToken { return false } if a.TriedAccounts == nil { a.TriedAccounts = map[string]bool{} } if a.AccountID != "" { a.TriedAccounts[a.AccountID] = true r.Pool.Release(a.AccountID) } acc, ok := r.Pool.Acquire("", a.TriedAccounts) if !ok { return false } a.Account = acc a.AccountID = acc.Identifier() if acc.Token == "" { if err := r.loginAndPersist(ctx, a); err != nil { return false } } else { a.DeepSeekToken = acc.Token } return true } func (r *Resolver) Release(a *RequestAuth) { if a == nil || !a.UseConfigToken || a.AccountID == "" { return } r.Pool.Release(a.AccountID) } func extractCallerToken(req *http.Request) string { authHeader := strings.TrimSpace(req.Header.Get("Authorization")) if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { token := strings.TrimSpace(authHeader[7:]) if token != "" { return token } } if key := strings.TrimSpace(req.Header.Get("x-api-key")); key != "" { return key } // Gemini/Google clients commonly send API key via x-goog-api-key. if key := strings.TrimSpace(req.Header.Get("x-goog-api-key")); key != "" { return key } // Gemini AI Studio compatibility: allow query key fallback only when no // header-based credential is present. if key := strings.TrimSpace(req.URL.Query().Get("key")); key != "" { return key } return strings.TrimSpace(req.URL.Query().Get("api_key")) } func callerTokenID(token string) string { token = strings.TrimSpace(token) if token == "" { return "" } sum := sha256.Sum256([]byte(token)) return "caller:" + hex.EncodeToString(sum[:8]) }