mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-12 04:07:42 +08:00
feat: Introduce a new Go-based DeepSeek API proxy with adapters for Claude and OpenAI, including SSE parsing and updated build configurations.
This commit is contained in:
113
internal/auth/admin.go
Normal file
113
internal/auth/admin.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func AdminKey() string {
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" {
|
||||
return v
|
||||
}
|
||||
return "your-admin-secret-key"
|
||||
}
|
||||
|
||||
func jwtSecret() string {
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_JWT_SECRET")); v != "" {
|
||||
return v
|
||||
}
|
||||
return AdminKey()
|
||||
}
|
||||
|
||||
func jwtExpireHours() int {
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 24
|
||||
}
|
||||
|
||||
func CreateJWT(expireHours int) (string, error) {
|
||||
if expireHours <= 0 {
|
||||
expireHours = jwtExpireHours()
|
||||
}
|
||||
header := map[string]any{"alg": "HS256", "typ": "JWT"}
|
||||
payload := map[string]any{"iat": time.Now().Unix(), "exp": time.Now().Add(time.Duration(expireHours) * time.Hour).Unix(), "role": "admin"}
|
||||
h, _ := json.Marshal(header)
|
||||
p, _ := json.Marshal(payload)
|
||||
headerB64 := rawB64Encode(h)
|
||||
payloadB64 := rawB64Encode(p)
|
||||
msg := headerB64 + "." + payloadB64
|
||||
sig := signHS256(msg)
|
||||
return msg + "." + rawB64Encode(sig), nil
|
||||
}
|
||||
|
||||
func VerifyJWT(token string) (map[string]any, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, errors.New("invalid token format")
|
||||
}
|
||||
msg := parts[0] + "." + parts[1]
|
||||
expected := signHS256(msg)
|
||||
actual, err := rawB64Decode(parts[2])
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid signature")
|
||||
}
|
||||
if !hmac.Equal(expected, actual) {
|
||||
return nil, errors.New("invalid signature")
|
||||
}
|
||||
payloadBytes, err := rawB64Decode(parts[1])
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid payload")
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
|
||||
return nil, errors.New("invalid payload")
|
||||
}
|
||||
exp, _ := payload["exp"].(float64)
|
||||
if int64(exp) < time.Now().Unix() {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func VerifyAdminRequest(r *http.Request) error {
|
||||
authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||
if !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
return errors.New("authentication required")
|
||||
}
|
||||
token := strings.TrimSpace(authHeader[7:])
|
||||
if token == "" {
|
||||
return errors.New("authentication required")
|
||||
}
|
||||
if token == AdminKey() {
|
||||
return nil
|
||||
}
|
||||
if _, err := VerifyJWT(token); err == nil {
|
||||
return nil
|
||||
}
|
||||
return errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
func signHS256(msg string) []byte {
|
||||
h := hmac.New(sha256.New, []byte(jwtSecret()))
|
||||
_, _ = h.Write([]byte(msg))
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func rawB64Encode(b []byte) string {
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func rawB64Decode(s string) ([]byte, error) {
|
||||
return base64.RawURLEncoding.DecodeString(s)
|
||||
}
|
||||
29
internal/auth/admin_test.go
Normal file
29
internal/auth/admin_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestJWTCreateVerify(t *testing.T) {
|
||||
token, err := CreateJWT(1)
|
||||
if err != nil {
|
||||
t.Fatalf("create jwt failed: %v", err)
|
||||
}
|
||||
payload, err := VerifyJWT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("verify jwt failed: %v", err)
|
||||
}
|
||||
if payload["role"] != "admin" {
|
||||
t.Fatalf("unexpected payload: %#v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAdminRequest(t *testing.T) {
|
||||
token, _ := CreateJWT(1)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/admin/config", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
if err := VerifyAdminRequest(req); err != nil {
|
||||
t.Fatalf("expected token accepted: %v", err)
|
||||
}
|
||||
}
|
||||
150
internal/auth/request.go
Normal file
150
internal/auth/request.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
type ctxKey string
|
||||
|
||||
const authCtxKey ctxKey = "auth_context"
|
||||
|
||||
var (
|
||||
ErrUnauthorized = errors.New("unauthorized: missing Bearer token")
|
||||
ErrNoAccount = errors.New("no accounts configured or all accounts are busy")
|
||||
)
|
||||
|
||||
type RequestAuth struct {
|
||||
UseConfigToken bool
|
||||
DeepSeekToken string
|
||||
AccountID string
|
||||
Account config.Account
|
||||
TriedAccounts map[string]bool
|
||||
resolver *Resolver
|
||||
}
|
||||
|
||||
type LoginFunc func(ctx context.Context, acc config.Account) (string, error)
|
||||
|
||||
type Resolver struct {
|
||||
Store *config.Store
|
||||
Pool *account.Pool
|
||||
Login LoginFunc
|
||||
}
|
||||
|
||||
func NewResolver(store *config.Store, pool *account.Pool, login LoginFunc) *Resolver {
|
||||
return &Resolver{Store: store, Pool: pool, Login: login}
|
||||
}
|
||||
|
||||
func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
callerKey := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
|
||||
ctx := req.Context()
|
||||
if !r.Store.HasAPIKey(callerKey) {
|
||||
return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil
|
||||
}
|
||||
target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account"))
|
||||
acc, ok := r.Pool.Acquire(target, nil)
|
||||
if !ok {
|
||||
return nil, ErrNoAccount
|
||||
}
|
||||
a := &RequestAuth{
|
||||
UseConfigToken: true,
|
||||
AccountID: acc.Identifier(),
|
||||
Account: acc,
|
||||
TriedAccounts: map[string]bool{},
|
||||
resolver: r,
|
||||
}
|
||||
if acc.Token == "" {
|
||||
if err := r.loginAndPersist(ctx, a); err != nil {
|
||||
r.Pool.Release(a.AccountID)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
a.DeepSeekToken = acc.Token
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func WithAuth(ctx context.Context, a *RequestAuth) context.Context {
|
||||
return context.WithValue(ctx, authCtxKey, a)
|
||||
}
|
||||
|
||||
func FromContext(ctx context.Context) (*RequestAuth, bool) {
|
||||
v := ctx.Value(authCtxKey)
|
||||
a, ok := v.(*RequestAuth)
|
||||
return a, ok
|
||||
}
|
||||
|
||||
func (r *Resolver) loginAndPersist(ctx context.Context, a *RequestAuth) error {
|
||||
token, err := r.Login(ctx, a.Account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.Account.Token = token
|
||||
a.DeepSeekToken = token
|
||||
return r.Store.UpdateAccountToken(a.AccountID, token)
|
||||
}
|
||||
|
||||
func (r *Resolver) RefreshToken(ctx context.Context, a *RequestAuth) bool {
|
||||
if !a.UseConfigToken || a.AccountID == "" {
|
||||
return false
|
||||
}
|
||||
_ = r.Store.UpdateAccountToken(a.AccountID, "")
|
||||
a.Account.Token = ""
|
||||
if err := r.loginAndPersist(ctx, a); err != nil {
|
||||
config.Logger.Error("[refresh_token] failed", "account", a.AccountID, "error", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *Resolver) MarkTokenInvalid(a *RequestAuth) {
|
||||
if !a.UseConfigToken || a.AccountID == "" {
|
||||
return
|
||||
}
|
||||
a.Account.Token = ""
|
||||
a.DeepSeekToken = ""
|
||||
_ = r.Store.UpdateAccountToken(a.AccountID, "")
|
||||
}
|
||||
|
||||
func (r *Resolver) SwitchAccount(ctx context.Context, a *RequestAuth) bool {
|
||||
if !a.UseConfigToken {
|
||||
return false
|
||||
}
|
||||
if a.TriedAccounts == nil {
|
||||
a.TriedAccounts = map[string]bool{}
|
||||
}
|
||||
if a.AccountID != "" {
|
||||
a.TriedAccounts[a.AccountID] = true
|
||||
r.Pool.Release(a.AccountID)
|
||||
}
|
||||
acc, ok := r.Pool.Acquire("", a.TriedAccounts)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
a.Account = acc
|
||||
a.AccountID = acc.Identifier()
|
||||
if acc.Token == "" {
|
||||
if err := r.loginAndPersist(ctx, a); err != nil {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
a.DeepSeekToken = acc.Token
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *Resolver) Release(a *RequestAuth) {
|
||||
if a == nil || !a.UseConfigToken || a.AccountID == "" {
|
||||
return
|
||||
}
|
||||
r.Pool.Release(a.AccountID)
|
||||
}
|
||||
Reference in New Issue
Block a user