diff --git a/OPTIMIZATION_REPORT.md b/OPTIMIZATION_REPORT.md index 1fda1a6..8054795 100644 --- a/OPTIMIZATION_REPORT.md +++ b/OPTIMIZATION_REPORT.md @@ -116,4 +116,4 @@ 1. **Phase 1 (Fix Critical) ✅ 已完成:** ~~修复 `Save()` 锁问题、WASM 重复创建、Admin 默认密码警告、Graceful Shutdown。删除无用大文件。~~ 同时修复了 `itoa` 低效实现。 2. **Phase 2 (Refactor) ✅ 已完成:** ~~统一 API Key/Account 的索引机制,重构 SSE 解析逻辑 (DRY),优化 `testAllAccounts` 并发。~~ 同时完成了重复工具函数的统一清理(`writeJSON`/`toBool`/`intFrom` → `internal/util`)。 -3. **Phase 3 (Cleanup):** 优化 CORS,改进 Token 估算等微小性能点。 +3. **Phase 3 (Cleanup) ✅ 已完成:** ~~优化 CORS,改进 Token 估算等微小性能点。~~ CORS 改为动态反射 Origin;Token 估算区分 ASCII/非 ASCII 字符。 diff --git a/cmd/ds2api/main.go b/cmd/ds2api/main.go index 2525fbb..8e83008 100644 --- a/cmd/ds2api/main.go +++ b/cmd/ds2api/main.go @@ -9,6 +9,7 @@ import ( "syscall" "time" + "ds2api/internal/auth" "ds2api/internal/config" "ds2api/internal/server" "ds2api/internal/webui" @@ -16,6 +17,7 @@ import ( func main() { webui.EnsureBuiltOnStartup() + _ = auth.AdminKey() app := server.NewApp() port := strings.TrimSpace(os.Getenv("PORT")) if port == "" { diff --git a/internal/deepseek/client.go b/internal/deepseek/client.go index c42443d..0523435 100644 --- a/internal/deepseek/client.go +++ b/internal/deepseek/client.go @@ -48,6 +48,10 @@ func NewClient(store *config.Store, resolver *auth.Resolver) *Client { } } +func (c *Client) PreloadPow(ctx context.Context) error { + return c.powSolver.init(ctx) +} + func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) { payload := map[string]any{ "password": strings.TrimSpace(acc.Password), diff --git a/internal/deepseek/pow.go b/internal/deepseek/pow.go index 49ef837..5dda8cf 100644 --- a/internal/deepseek/pow.go +++ b/internal/deepseek/pow.go @@ -8,6 +8,7 @@ import ( "errors" "math" "os" + stdruntime "runtime" "strconv" "sync" @@ -24,7 +25,16 @@ type PowSolver struct { runtime wazero.Runtime compiled wazero.CompiledModule - pool sync.Pool + pool chan *pooledModule + poolSize int +} + +type pooledModule struct { + mod api.Module + stackFn api.Function + allocFn api.Function + freeFn api.Function + solveFn api.Function } func NewPowSolver(wasmPath string) *PowSolver { @@ -44,14 +54,15 @@ func (p *PowSolver) init(ctx context.Context) error { p.runtime = wazero.NewRuntime(ctx) p.compiled, p.err = p.runtime.CompileModule(ctx, wasmBytes) if p.err == nil { - p.pool = sync.Pool{ - New: func() any { - mod, err := p.runtime.InstantiateModule(context.Background(), p.compiled, wazero.NewModuleConfig()) - if err != nil { - return nil - } - return mod - }, + p.poolSize = powPoolSizeFromEnv() + p.pool = make(chan *pooledModule, p.poolSize) + for range p.poolSize { + inst, err := p.createModule(ctx) + if err != nil { + p.err = err + return + } + p.pool <- inst } } }) @@ -77,49 +88,38 @@ func (p *PowSolver) Compute(ctx context.Context, challenge map[string]any) (int6 expireAt := toInt64(challenge["expire_at"], 1680000000) prefix := salt + "_" + itoa(expireAt) + "_" - // Try to get a pooled instance; fall back to creating a new one. - var mod api.Module - if pooled := p.pool.Get(); pooled != nil { - mod = pooled.(api.Module) - } else { - var err error - mod, err = p.runtime.InstantiateModule(ctx, p.compiled, wazero.NewModuleConfig()) - if err != nil { - return 0, err - } + pm, err := p.acquireModule(ctx) + if err != nil { + return 0, err } - // WASM instances may carry state; close after use rather than returning to pool. - // The pool's New func will create fresh instances as needed. - defer mod.Close(ctx) + defer p.releaseModule(pm) - mem := mod.Memory() + mem := pm.mod.Memory() if mem == nil { return 0, errors.New("wasm memory missing") } - stackFn := mod.ExportedFunction("__wbindgen_add_to_stack_pointer") - allocFn := mod.ExportedFunction("__wbindgen_export_0") - solveFn := mod.ExportedFunction("wasm_solve") - if stackFn == nil || allocFn == nil || solveFn == nil { - return 0, errors.New("required wasm exports missing") - } - - retPtrs, err := stackFn.Call(ctx, uint64(uint32(^uint32(15)))) // -16 i32 + retPtrs, err := pm.stackFn.Call(ctx, uint64(uint32(^uint32(15)))) // -16 i32 if err != nil || len(retPtrs) == 0 { return 0, errors.New("stack alloc failed") } retptr := uint32(retPtrs[0]) - defer stackFn.Call(ctx, 16) + defer func() { + _, _ = pm.stackFn.Call(context.Background(), 16) + }() - chPtr, chLen, err := writeUTF8(ctx, allocFn, mem, challengeStr) + chPtr, chLen, err := writeUTF8(ctx, pm.allocFn, mem, challengeStr) if err != nil { return 0, err } - prefixPtr, prefixLen, err := writeUTF8(ctx, allocFn, mem, prefix) + defer freeUTF8(pm.freeFn, chPtr, chLen) + + prefixPtr, prefixLen, err := writeUTF8(ctx, pm.allocFn, mem, prefix) if err != nil { return 0, err } + defer freeUTF8(pm.freeFn, prefixPtr, prefixLen) - if _, err := solveFn.Call(ctx, + if _, err := pm.solveFn.Call(ctx, uint64(retptr), uint64(chPtr), uint64(chLen), uint64(prefixPtr), uint64(prefixLen), @@ -144,6 +144,54 @@ func (p *PowSolver) Compute(ctx context.Context, challenge map[string]any) (int6 return int64(value), nil } +func (p *PowSolver) createModule(ctx context.Context) (*pooledModule, error) { + mod, err := p.runtime.InstantiateModule(ctx, p.compiled, wazero.NewModuleConfig()) + if err != nil { + return nil, err + } + stackFn := mod.ExportedFunction("__wbindgen_add_to_stack_pointer") + allocFn := mod.ExportedFunction("__wbindgen_export_0") + solveFn := mod.ExportedFunction("wasm_solve") + if stackFn == nil || allocFn == nil || solveFn == nil { + _ = mod.Close(context.Background()) + return nil, errors.New("required wasm exports missing") + } + return &pooledModule{ + mod: mod, + stackFn: stackFn, + allocFn: allocFn, + freeFn: mod.ExportedFunction("__wbindgen_export_2"), + solveFn: solveFn, + }, nil +} + +func (p *PowSolver) acquireModule(ctx context.Context) (*pooledModule, error) { + if p.pool != nil { + select { + case pm := <-p.pool: + if pm != nil { + return pm, nil + } + default: + } + } + return p.createModule(ctx) +} + +func (p *PowSolver) releaseModule(pm *pooledModule) { + if pm == nil || pm.mod == nil { + return + } + if p.pool != nil { + select { + case p.pool <- pm: + return + default: + } + } + _ = pm.mod.Close(context.Background()) +} + func writeUTF8(ctx context.Context, allocFn api.Function, mem api.Memory, text string) (uint32, uint32, error) { data := []byte(text) res, err := allocFn.Call(ctx, uint64(len(data)), 1) @@ -157,6 +205,13 @@ func writeUTF8(ctx context.Context, allocFn api.Function, mem api.Memory, text s return ptr, uint32(len(data)), nil } +func freeUTF8(freeFn api.Function, ptr, size uint32) { + if freeFn == nil || ptr == 0 || size == 0 { + return + } + _, _ = freeFn.Call(context.Background(), uint64(ptr), uint64(size), 1) +} + func BuildPowHeader(challenge map[string]any, answer int64) (string, error) { payload := map[string]any{ "algorithm": challenge["algorithm"], @@ -203,6 +258,23 @@ func itoa(n int64) string { return strconv.FormatInt(n, 10) } +func powPoolSizeFromEnv() int { + const fallback = 4 + n := fallback + if cpus := stdruntime.GOMAXPROCS(0); cpus > 0 { + n = cpus + } + if raw := os.Getenv("DS2API_POW_POOL_SIZE"); raw != "" { + if v, err := strconv.Atoi(raw); err == nil && v > 0 { + n = v + } + } + if n > 64 { + return 64 + } + return n +} + func PreloadWASM(wasmPath string) { solver := NewPowSolver(wasmPath) if err := solver.init(context.Background()); err != nil { diff --git a/internal/deepseek/pow_test.go b/internal/deepseek/pow_test.go new file mode 100644 index 0000000..3e8af6c --- /dev/null +++ b/internal/deepseek/pow_test.go @@ -0,0 +1,47 @@ +package deepseek + +import ( + "context" + "testing" +) + +func TestPowPoolSizeFromEnv(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "3") + if got := powPoolSizeFromEnv(); got != 3 { + t.Fatalf("expected pool size 3, got %d", got) + } +} + +func TestPowSolverAcquireReleaseReusesModule(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "1") + solver := NewPowSolver("missing-file.wasm") + if err := solver.init(context.Background()); err != nil { + t.Fatalf("init failed: %v", err) + } + + pm1, err := solver.acquireModule(context.Background()) + if err != nil { + t.Fatalf("acquire first module failed: %v", err) + } + solver.releaseModule(pm1) + + pm2, err := solver.acquireModule(context.Background()) + if err != nil { + t.Fatalf("acquire second module failed: %v", err) + } + if pm1 != pm2 { + t.Fatalf("expected pooled module reuse, got different instances") + } + solver.releaseModule(pm2) +} + +func TestClientPreloadPowUsesClientSolver(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "1") + client := NewClient(nil, nil) + if err := client.PreloadPow(context.Background()); err != nil { + t.Fatalf("preload failed: %v", err) + } + if client.powSolver.runtime == nil || client.powSolver.compiled == nil { + t.Fatalf("expected client pow solver to be initialized") + } +} diff --git a/internal/server/router.go b/internal/server/router.go index 3b57392..e1260ce 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -36,7 +36,11 @@ func NewApp() *App { return dsClient.Login(ctx, acc) }) dsClient = deepseek.NewClient(store, resolver) - deepseek.PreloadWASM(config.WASMPath()) + if err := dsClient.PreloadPow(context.Background()); err != nil { + config.Logger.Warn("[WASM] preload failed", "error", err) + } else { + config.Logger.Info("[WASM] module preloaded", "path", config.WASMPath()) + } openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient} claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient} @@ -86,8 +90,18 @@ func timeout(d time.Duration) func(http.Handler) http.Handler { func cors(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Credentials", "true") + origin := r.Header.Get("Origin") + if origin != "" { + // Dynamically reflect the request origin to allow credentials. + // Using "*" with Access-Control-Allow-Credentials: true is + // invalid per the CORS spec and will be rejected by browsers. + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Vary", "Origin") + } else { + // No Origin header (e.g. server-to-server requests); allow all. + w.Header().Set("Access-Control-Allow-Origin", "*") + } w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") if r.Method == http.MethodOptions { diff --git a/internal/util/messages.go b/internal/util/messages.go index 86052a9..69eaf87 100644 --- a/internal/util/messages.go +++ b/internal/util/messages.go @@ -115,11 +115,25 @@ func ConvertClaudeToDeepSeek(claudeReq map[string]any, store *config.Store) map[ return out } +// EstimateTokens provides a rough token count approximation. +// For ASCII text (English, code, etc.) we use ~4 chars per token. +// For non-ASCII text (Chinese, Japanese, Korean, etc.) we use ~1.3 chars per token, +// which better reflects typical BPE tokenizer behavior for CJK scripts. func EstimateTokens(text string) int { if text == "" { return 0 } - n := len([]rune(text)) / 4 + asciiChars := 0 + nonASCIIChars := 0 + for _, r := range text { + if r < 128 { + asciiChars++ + } else { + nonASCIIChars++ + } + } + // ASCII: ~4 chars per token; non-ASCII (CJK): ~1.3 chars per token + n := asciiChars/4 + (nonASCIIChars*10+7)/13 if n < 1 { return 1 }