mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-01 23:15:27 +08:00
413 lines
13 KiB
Go
413 lines
13 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"testing"
|
|
|
|
"ds2api/internal/account"
|
|
"ds2api/internal/config"
|
|
)
|
|
|
|
// ─── extractCallerToken edge cases ───────────────────────────────────
|
|
|
|
func TestExtractCallerTokenBearerPrefix(t *testing.T) {
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer my-token")
|
|
if got := extractCallerToken(req); got != "my-token" {
|
|
t.Fatalf("expected my-token, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestExtractCallerTokenBearerCaseInsensitive(t *testing.T) {
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
req.Header.Set("Authorization", "BEARER My-Token")
|
|
if got := extractCallerToken(req); got != "My-Token" {
|
|
t.Fatalf("expected My-Token, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestExtractCallerTokenBearerEmpty(t *testing.T) {
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer ")
|
|
if got := extractCallerToken(req); got != "" {
|
|
t.Fatalf("expected empty for 'Bearer ', got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestExtractCallerTokenXAPIKey(t *testing.T) {
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
req.Header.Set("x-api-key", "x-api-key-token")
|
|
if got := extractCallerToken(req); got != "x-api-key-token" {
|
|
t.Fatalf("expected x-api-key-token, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestExtractCallerTokenBearerPreferredOverXAPIKey(t *testing.T) {
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer bearer-token")
|
|
req.Header.Set("x-api-key", "x-api-key-token")
|
|
if got := extractCallerToken(req); got != "bearer-token" {
|
|
t.Fatalf("expected bearer-token, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestExtractCallerTokenMissingHeaders(t *testing.T) {
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
if got := extractCallerToken(req); got != "" {
|
|
t.Fatalf("expected empty for missing headers, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestExtractCallerTokenNonBearerAuth(t *testing.T) {
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
req.Header.Set("Authorization", "Basic abc123")
|
|
if got := extractCallerToken(req); got != "" {
|
|
t.Fatalf("expected empty for Basic auth, got %q", got)
|
|
}
|
|
}
|
|
|
|
// ─── Context helpers ─────────────────────────────────────────────────
|
|
|
|
func TestWithAuthAndFromContext(t *testing.T) {
|
|
a := &RequestAuth{DeepSeekToken: "test-token"}
|
|
ctx := WithAuth(context.Background(), a)
|
|
got, ok := FromContext(ctx)
|
|
if !ok || got.DeepSeekToken != "test-token" {
|
|
t.Fatalf("expected token from context, got ok=%v token=%q", ok, got.DeepSeekToken)
|
|
}
|
|
}
|
|
|
|
func TestFromContextMissing(t *testing.T) {
|
|
_, ok := FromContext(context.Background())
|
|
if ok {
|
|
t.Fatal("expected not ok from empty context")
|
|
}
|
|
}
|
|
|
|
// ─── RefreshToken edge cases ─────────────────────────────────────────
|
|
|
|
func TestRefreshTokenNotConfigToken(t *testing.T) {
|
|
r := newTestResolver(t)
|
|
a := &RequestAuth{UseConfigToken: false, resolver: r}
|
|
if r.RefreshToken(context.Background(), a) {
|
|
t.Fatal("expected false for non-config token")
|
|
}
|
|
}
|
|
|
|
func TestRefreshTokenEmptyAccountID(t *testing.T) {
|
|
r := newTestResolver(t)
|
|
a := &RequestAuth{UseConfigToken: true, AccountID: "", resolver: r}
|
|
if r.RefreshToken(context.Background(), a) {
|
|
t.Fatal("expected false for empty account ID")
|
|
}
|
|
}
|
|
|
|
func TestRefreshTokenSuccess(t *testing.T) {
|
|
r := newTestResolver(t)
|
|
// First acquire an account
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer managed-key")
|
|
a, err := r.Determine(req)
|
|
if err != nil {
|
|
t.Fatalf("determine failed: %v", err)
|
|
}
|
|
defer r.Release(a)
|
|
|
|
if !r.RefreshToken(context.Background(), a) {
|
|
t.Fatal("expected refresh to succeed")
|
|
}
|
|
if a.DeepSeekToken != "fresh-token" {
|
|
t.Fatalf("expected fresh-token after refresh, got %q", a.DeepSeekToken)
|
|
}
|
|
}
|
|
|
|
// ─── MarkTokenInvalid edge cases ─────────────────────────────────────
|
|
|
|
func TestMarkTokenInvalidNotConfigToken(t *testing.T) {
|
|
r := newTestResolver(t)
|
|
a := &RequestAuth{UseConfigToken: false, DeepSeekToken: "direct", resolver: r}
|
|
r.MarkTokenInvalid(a)
|
|
// Should not panic, token should be unchanged for non-config
|
|
_ = a.DeepSeekToken // Actual behavior may clear it; this test only asserts no panic.
|
|
}
|
|
|
|
func TestMarkTokenInvalidEmptyAccountID(t *testing.T) {
|
|
r := newTestResolver(t)
|
|
a := &RequestAuth{UseConfigToken: true, AccountID: "", DeepSeekToken: "tok", resolver: r}
|
|
r.MarkTokenInvalid(a)
|
|
// Should not panic
|
|
}
|
|
|
|
func TestMarkTokenInvalidClearsToken(t *testing.T) {
|
|
r := newTestResolver(t)
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer managed-key")
|
|
a, err := r.Determine(req)
|
|
if err != nil {
|
|
t.Fatalf("determine failed: %v", err)
|
|
}
|
|
defer r.Release(a)
|
|
|
|
r.MarkTokenInvalid(a)
|
|
if a.DeepSeekToken != "" {
|
|
t.Fatalf("expected empty token after invalidation, got %q", a.DeepSeekToken)
|
|
}
|
|
if a.Account.Token != "" {
|
|
t.Fatalf("expected empty account token after invalidation, got %q", a.Account.Token)
|
|
}
|
|
}
|
|
|
|
// ─── SwitchAccount edge cases ────────────────────────────────────────
|
|
|
|
func TestSwitchAccountNotConfigToken(t *testing.T) {
|
|
r := newTestResolver(t)
|
|
a := &RequestAuth{UseConfigToken: false, resolver: r}
|
|
if r.SwitchAccount(context.Background(), a) {
|
|
t.Fatal("expected false for non-config token")
|
|
}
|
|
}
|
|
|
|
func TestSwitchAccountNilTriedAccounts(t *testing.T) {
|
|
t.Setenv("DS2API_CONFIG_JSON", `{
|
|
"keys":["managed-key"],
|
|
"accounts":[
|
|
{"email":"acc1@test.com","token":"t1"},
|
|
{"email":"acc2@test.com","token":"t2"}
|
|
]
|
|
}`)
|
|
store := config.LoadStore()
|
|
pool := account.NewPool(store)
|
|
r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
|
return "new-token", nil
|
|
})
|
|
|
|
// First acquire
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer managed-key")
|
|
a, err := r.Determine(req)
|
|
if err != nil {
|
|
t.Fatalf("determine failed: %v", err)
|
|
}
|
|
|
|
oldID := a.AccountID
|
|
a.TriedAccounts = nil // test nil initialization in SwitchAccount
|
|
if !r.SwitchAccount(context.Background(), a) {
|
|
t.Fatal("expected switch to succeed")
|
|
}
|
|
if a.AccountID == oldID {
|
|
t.Fatalf("expected different account after switch")
|
|
}
|
|
r.Release(a)
|
|
}
|
|
|
|
func TestSwitchAccountSkipsLoginFailureAndContinues(t *testing.T) {
|
|
t.Setenv("DS2API_CONFIG_JSON", `{
|
|
"keys":["managed-key"],
|
|
"accounts":[
|
|
{"email":"acc1@test.com","password":"pwd","token":"t1"},
|
|
{"email":"acc2@test.com","password":"pwd"},
|
|
{"email":"acc3@test.com","password":"pwd","token":"t3"}
|
|
]
|
|
}`)
|
|
store := config.LoadStore()
|
|
pool := account.NewPool(store)
|
|
r := NewResolver(store, pool, func(_ context.Context, acc config.Account) (string, error) {
|
|
if acc.Email == "acc2@test.com" {
|
|
return "", errors.New("login failed")
|
|
}
|
|
return "new-token", nil
|
|
})
|
|
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer managed-key")
|
|
a, err := r.Determine(req)
|
|
if err != nil {
|
|
t.Fatalf("determine failed: %v", err)
|
|
}
|
|
defer r.Release(a)
|
|
if a.AccountID != "acc1@test.com" {
|
|
t.Fatalf("expected first account, got %q", a.AccountID)
|
|
}
|
|
if !r.SwitchAccount(context.Background(), a) {
|
|
t.Fatal("expected switch to succeed after skipping failed account")
|
|
}
|
|
if a.AccountID != "acc3@test.com" {
|
|
t.Fatalf("expected fallback to third account, got %q", a.AccountID)
|
|
}
|
|
if !a.TriedAccounts["acc2@test.com"] {
|
|
t.Fatalf("expected failed account to be marked as tried")
|
|
}
|
|
}
|
|
|
|
// ─── Release edge cases ─────────────────────────────────────────────
|
|
|
|
func TestReleaseNilAuth(t *testing.T) {
|
|
r := newTestResolver(t)
|
|
r.Release(nil) // should not panic
|
|
}
|
|
|
|
func TestReleaseNonConfigToken(t *testing.T) {
|
|
r := newTestResolver(t)
|
|
a := &RequestAuth{UseConfigToken: false}
|
|
r.Release(a) // should not panic
|
|
}
|
|
|
|
func TestReleaseEmptyAccountID(t *testing.T) {
|
|
r := newTestResolver(t)
|
|
a := &RequestAuth{UseConfigToken: true, AccountID: ""}
|
|
r.Release(a) // should not panic
|
|
}
|
|
|
|
// ─── JWT edge cases ──────────────────────────────────────────────────
|
|
|
|
func TestVerifyJWTInvalidFormat(t *testing.T) {
|
|
_, err := VerifyJWT("not-a-jwt")
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid JWT format")
|
|
}
|
|
}
|
|
|
|
func TestVerifyJWTInvalidSignature(t *testing.T) {
|
|
token, _ := CreateJWT(1)
|
|
// Tamper with the signature
|
|
parts := splitJWT(token)
|
|
if len(parts) == 3 {
|
|
tampered := parts[0] + "." + parts[1] + ".invalid_signature"
|
|
_, err := VerifyJWT(tampered)
|
|
if err == nil {
|
|
t.Fatal("expected error for tampered signature")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestVerifyJWTExpired(t *testing.T) {
|
|
// Create a token with 0 hours expiry - will use default, so we can't easily test
|
|
// Instead test with bad payload
|
|
_, err := VerifyJWT("eyJhbGciOiJIUzI1NiJ9.eyJleHAiOjF9.invalid")
|
|
if err == nil {
|
|
t.Fatal("expected error for expired/invalid JWT")
|
|
}
|
|
}
|
|
|
|
func TestCreateJWTDefaultExpiry(t *testing.T) {
|
|
token, err := CreateJWT(0) // should use default
|
|
if err != nil {
|
|
t.Fatalf("create jwt failed: %v", err)
|
|
}
|
|
_, err = VerifyJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("verify jwt failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// ─── VerifyAdminRequest edge cases ───────────────────────────────────
|
|
|
|
func TestVerifyAdminRequestNoHeader(t *testing.T) {
|
|
req, _ := http.NewRequest("GET", "/admin/config", nil)
|
|
if err := VerifyAdminRequest(req); err == nil {
|
|
t.Fatal("expected error for missing auth")
|
|
}
|
|
}
|
|
|
|
func TestVerifyAdminRequestEmptyBearer(t *testing.T) {
|
|
req, _ := http.NewRequest("GET", "/admin/config", nil)
|
|
req.Header.Set("Authorization", "Bearer ")
|
|
if err := VerifyAdminRequest(req); err == nil {
|
|
t.Fatal("expected error for empty bearer")
|
|
}
|
|
}
|
|
|
|
func TestVerifyAdminRequestWithAdminKey(t *testing.T) {
|
|
t.Setenv("DS2API_ADMIN_KEY", "test-admin-key")
|
|
req, _ := http.NewRequest("GET", "/admin/config", nil)
|
|
req.Header.Set("Authorization", "Bearer test-admin-key")
|
|
if err := VerifyAdminRequest(req); err != nil {
|
|
t.Fatalf("expected admin key accepted: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestVerifyAdminRequestInvalidCredentials(t *testing.T) {
|
|
t.Setenv("DS2API_ADMIN_KEY", "correct-key")
|
|
req, _ := http.NewRequest("GET", "/admin/config", nil)
|
|
req.Header.Set("Authorization", "Bearer wrong-key")
|
|
if err := VerifyAdminRequest(req); err == nil {
|
|
t.Fatal("expected error for wrong key")
|
|
}
|
|
}
|
|
|
|
func TestVerifyAdminRequestBasicAuth(t *testing.T) {
|
|
req, _ := http.NewRequest("GET", "/admin/config", nil)
|
|
req.Header.Set("Authorization", "Basic abc123")
|
|
if err := VerifyAdminRequest(req); err == nil {
|
|
t.Fatal("expected error for Basic auth")
|
|
}
|
|
}
|
|
|
|
// ─── Determine with login failure ────────────────────────────────────
|
|
|
|
func TestDetermineWithLoginFailure(t *testing.T) {
|
|
t.Setenv("DS2API_CONFIG_JSON", `{
|
|
"keys":["managed-key"],
|
|
"accounts":[{"email":"acc@test.com","password":"pwd"}]
|
|
}`)
|
|
store := config.LoadStore()
|
|
pool := account.NewPool(store)
|
|
r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
|
return "", errors.New("login failed")
|
|
})
|
|
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer managed-key")
|
|
_, err := r.Determine(req)
|
|
if err == nil {
|
|
t.Fatal("expected error when login fails")
|
|
}
|
|
}
|
|
|
|
// ─── Determine with target account ───────────────────────────────────
|
|
|
|
func TestDetermineWithTargetAccount(t *testing.T) {
|
|
t.Setenv("DS2API_CONFIG_JSON", `{
|
|
"keys":["managed-key"],
|
|
"accounts":[
|
|
{"email":"acc1@test.com","token":"t1"},
|
|
{"email":"acc2@test.com","token":"t2"}
|
|
]
|
|
}`)
|
|
store := config.LoadStore()
|
|
pool := account.NewPool(store)
|
|
r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
|
return "fresh-token", nil
|
|
})
|
|
|
|
req, _ := http.NewRequest("POST", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer managed-key")
|
|
req.Header.Set("X-Ds2-Target-Account", "acc2@test.com")
|
|
a, err := r.Determine(req)
|
|
if err != nil {
|
|
t.Fatalf("determine failed: %v", err)
|
|
}
|
|
defer r.Release(a)
|
|
if a.AccountID != "acc2@test.com" {
|
|
t.Fatalf("expected target account acc2, got %q", a.AccountID)
|
|
}
|
|
}
|
|
|
|
// splitJWT is a test helper that splits token on every '.' and returns the
// segments in order. A token containing no '.' comes back as a one-element
// slice holding the whole input, so the empty string yields [""]; empty
// segments between adjacent dots are preserved.
func splitJWT(token string) []string {
	segments := make([]string, 0, 3) // the tests expect three segments per token
	start := 0
	for i := 0; i < len(token); i++ {
		if token[i] != '.' {
			continue
		}
		segments = append(segments, token[start:i])
		start = i + 1
	}
	// Whatever follows the last '.' (or the whole token) is the final segment.
	return append(segments, token[start:])
}
|