mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-04 00:15:28 +08:00
refactor: Enhance WASM POW solver with channel-based pooling and configurable size, update token estimation, and fix CORS origin reflection.
This commit is contained in:
@@ -116,4 +116,4 @@
|
||||
|
||||
1. **Phase 1 (Fix Critical) ✅ 已完成:** ~~修复 `Save()` 锁问题、WASM 重复创建、Admin 默认密码警告、Graceful Shutdown。删除无用大文件。~~ 同时修复了 `itoa` 低效实现。
|
||||
2. **Phase 2 (Refactor) ✅ 已完成:** ~~统一 API Key/Account 的索引机制,重构 SSE 解析逻辑 (DRY),优化 `testAllAccounts` 并发。~~ 同时完成了重复工具函数的统一清理(`writeJSON`/`toBool`/`intFrom` → `internal/util`)。
|
||||
3. **Phase 3 (Cleanup):** 优化 CORS,改进 Token 估算等微小性能点。
|
||||
3. **Phase 3 (Cleanup) ✅ 已完成:** ~~优化 CORS,改进 Token 估算等微小性能点。~~ CORS 改为动态反射 Origin;Token 估算区分 ASCII/非 ASCII 字符。
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/server"
|
||||
"ds2api/internal/webui"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
|
||||
func main() {
|
||||
webui.EnsureBuiltOnStartup()
|
||||
_ = auth.AdminKey()
|
||||
app := server.NewApp()
|
||||
port := strings.TrimSpace(os.Getenv("PORT"))
|
||||
if port == "" {
|
||||
|
||||
@@ -48,6 +48,10 @@ func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) PreloadPow(ctx context.Context) error {
|
||||
return c.powSolver.init(ctx)
|
||||
}
|
||||
|
||||
func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) {
|
||||
payload := map[string]any{
|
||||
"password": strings.TrimSpace(acc.Password),
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"math"
|
||||
"os"
|
||||
stdruntime "runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
@@ -24,7 +25,16 @@ type PowSolver struct {
|
||||
|
||||
runtime wazero.Runtime
|
||||
compiled wazero.CompiledModule
|
||||
pool sync.Pool
|
||||
pool chan *pooledModule
|
||||
poolSize int
|
||||
}
|
||||
|
||||
type pooledModule struct {
|
||||
mod api.Module
|
||||
stackFn api.Function
|
||||
allocFn api.Function
|
||||
freeFn api.Function
|
||||
solveFn api.Function
|
||||
}
|
||||
|
||||
func NewPowSolver(wasmPath string) *PowSolver {
|
||||
@@ -44,14 +54,15 @@ func (p *PowSolver) init(ctx context.Context) error {
|
||||
p.runtime = wazero.NewRuntime(ctx)
|
||||
p.compiled, p.err = p.runtime.CompileModule(ctx, wasmBytes)
|
||||
if p.err == nil {
|
||||
p.pool = sync.Pool{
|
||||
New: func() any {
|
||||
mod, err := p.runtime.InstantiateModule(context.Background(), p.compiled, wazero.NewModuleConfig())
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return mod
|
||||
},
|
||||
p.poolSize = powPoolSizeFromEnv()
|
||||
p.pool = make(chan *pooledModule, p.poolSize)
|
||||
for range p.poolSize {
|
||||
inst, err := p.createModule(ctx)
|
||||
if err != nil {
|
||||
p.err = err
|
||||
return
|
||||
}
|
||||
p.pool <- inst
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -77,49 +88,38 @@ func (p *PowSolver) Compute(ctx context.Context, challenge map[string]any) (int6
|
||||
expireAt := toInt64(challenge["expire_at"], 1680000000)
|
||||
prefix := salt + "_" + itoa(expireAt) + "_"
|
||||
|
||||
// Try to get a pooled instance; fall back to creating a new one.
|
||||
var mod api.Module
|
||||
if pooled := p.pool.Get(); pooled != nil {
|
||||
mod = pooled.(api.Module)
|
||||
} else {
|
||||
var err error
|
||||
mod, err = p.runtime.InstantiateModule(ctx, p.compiled, wazero.NewModuleConfig())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
pm, err := p.acquireModule(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// WASM instances may carry state; close after use rather than returning to pool.
|
||||
// The pool's New func will create fresh instances as needed.
|
||||
defer mod.Close(ctx)
|
||||
defer p.releaseModule(pm)
|
||||
|
||||
mem := mod.Memory()
|
||||
mem := pm.mod.Memory()
|
||||
if mem == nil {
|
||||
return 0, errors.New("wasm memory missing")
|
||||
}
|
||||
stackFn := mod.ExportedFunction("__wbindgen_add_to_stack_pointer")
|
||||
allocFn := mod.ExportedFunction("__wbindgen_export_0")
|
||||
solveFn := mod.ExportedFunction("wasm_solve")
|
||||
if stackFn == nil || allocFn == nil || solveFn == nil {
|
||||
return 0, errors.New("required wasm exports missing")
|
||||
}
|
||||
|
||||
retPtrs, err := stackFn.Call(ctx, uint64(uint32(^uint32(15)))) // -16 i32
|
||||
retPtrs, err := pm.stackFn.Call(ctx, uint64(uint32(^uint32(15)))) // -16 i32
|
||||
if err != nil || len(retPtrs) == 0 {
|
||||
return 0, errors.New("stack alloc failed")
|
||||
}
|
||||
retptr := uint32(retPtrs[0])
|
||||
defer stackFn.Call(ctx, 16)
|
||||
defer func() {
|
||||
_, _ = pm.stackFn.Call(context.Background(), 16)
|
||||
}()
|
||||
|
||||
chPtr, chLen, err := writeUTF8(ctx, allocFn, mem, challengeStr)
|
||||
chPtr, chLen, err := writeUTF8(ctx, pm.allocFn, mem, challengeStr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
prefixPtr, prefixLen, err := writeUTF8(ctx, allocFn, mem, prefix)
|
||||
defer freeUTF8(pm.freeFn, chPtr, chLen)
|
||||
|
||||
prefixPtr, prefixLen, err := writeUTF8(ctx, pm.allocFn, mem, prefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer freeUTF8(pm.freeFn, prefixPtr, prefixLen)
|
||||
|
||||
if _, err := solveFn.Call(ctx,
|
||||
if _, err := pm.solveFn.Call(ctx,
|
||||
uint64(retptr),
|
||||
uint64(chPtr), uint64(chLen),
|
||||
uint64(prefixPtr), uint64(prefixLen),
|
||||
@@ -144,6 +144,54 @@ func (p *PowSolver) Compute(ctx context.Context, challenge map[string]any) (int6
|
||||
return int64(value), nil
|
||||
}
|
||||
|
||||
func (p *PowSolver) createModule(ctx context.Context) (*pooledModule, error) {
|
||||
mod, err := p.runtime.InstantiateModule(ctx, p.compiled, wazero.NewModuleConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stackFn := mod.ExportedFunction("__wbindgen_add_to_stack_pointer")
|
||||
allocFn := mod.ExportedFunction("__wbindgen_export_0")
|
||||
solveFn := mod.ExportedFunction("wasm_solve")
|
||||
if stackFn == nil || allocFn == nil || solveFn == nil {
|
||||
_ = mod.Close(context.Background())
|
||||
return nil, errors.New("required wasm exports missing")
|
||||
}
|
||||
return &pooledModule{
|
||||
mod: mod,
|
||||
stackFn: stackFn,
|
||||
allocFn: allocFn,
|
||||
freeFn: mod.ExportedFunction("__wbindgen_export_2"),
|
||||
solveFn: solveFn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PowSolver) acquireModule(ctx context.Context) (*pooledModule, error) {
|
||||
if p.pool != nil {
|
||||
select {
|
||||
case pm := <-p.pool:
|
||||
if pm != nil {
|
||||
return pm, nil
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
return p.createModule(ctx)
|
||||
}
|
||||
|
||||
func (p *PowSolver) releaseModule(pm *pooledModule) {
|
||||
if pm == nil || pm.mod == nil {
|
||||
return
|
||||
}
|
||||
if p.pool != nil {
|
||||
select {
|
||||
case p.pool <- pm:
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
_ = pm.mod.Close(context.Background())
|
||||
}
|
||||
|
||||
func writeUTF8(ctx context.Context, allocFn api.Function, mem api.Memory, text string) (uint32, uint32, error) {
|
||||
data := []byte(text)
|
||||
res, err := allocFn.Call(ctx, uint64(len(data)), 1)
|
||||
@@ -157,6 +205,13 @@ func writeUTF8(ctx context.Context, allocFn api.Function, mem api.Memory, text s
|
||||
return ptr, uint32(len(data)), nil
|
||||
}
|
||||
|
||||
func freeUTF8(freeFn api.Function, ptr, size uint32) {
|
||||
if freeFn == nil || ptr == 0 || size == 0 {
|
||||
return
|
||||
}
|
||||
_, _ = freeFn.Call(context.Background(), uint64(ptr), uint64(size), 1)
|
||||
}
|
||||
|
||||
func BuildPowHeader(challenge map[string]any, answer int64) (string, error) {
|
||||
payload := map[string]any{
|
||||
"algorithm": challenge["algorithm"],
|
||||
@@ -203,6 +258,23 @@ func itoa(n int64) string {
|
||||
return strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
func powPoolSizeFromEnv() int {
|
||||
const fallback = 4
|
||||
n := fallback
|
||||
if cpus := stdruntime.GOMAXPROCS(0); cpus > 0 {
|
||||
n = cpus
|
||||
}
|
||||
if raw := os.Getenv("DS2API_POW_POOL_SIZE"); raw != "" {
|
||||
if v, err := strconv.Atoi(raw); err == nil && v > 0 {
|
||||
n = v
|
||||
}
|
||||
}
|
||||
if n > 64 {
|
||||
return 64
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func PreloadWASM(wasmPath string) {
|
||||
solver := NewPowSolver(wasmPath)
|
||||
if err := solver.init(context.Background()); err != nil {
|
||||
|
||||
47
internal/deepseek/pow_test.go
Normal file
47
internal/deepseek/pow_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPowPoolSizeFromEnv(t *testing.T) {
|
||||
t.Setenv("DS2API_POW_POOL_SIZE", "3")
|
||||
if got := powPoolSizeFromEnv(); got != 3 {
|
||||
t.Fatalf("expected pool size 3, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPowSolverAcquireReleaseReusesModule(t *testing.T) {
|
||||
t.Setenv("DS2API_POW_POOL_SIZE", "1")
|
||||
solver := NewPowSolver("missing-file.wasm")
|
||||
if err := solver.init(context.Background()); err != nil {
|
||||
t.Fatalf("init failed: %v", err)
|
||||
}
|
||||
|
||||
pm1, err := solver.acquireModule(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("acquire first module failed: %v", err)
|
||||
}
|
||||
solver.releaseModule(pm1)
|
||||
|
||||
pm2, err := solver.acquireModule(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("acquire second module failed: %v", err)
|
||||
}
|
||||
if pm1 != pm2 {
|
||||
t.Fatalf("expected pooled module reuse, got different instances")
|
||||
}
|
||||
solver.releaseModule(pm2)
|
||||
}
|
||||
|
||||
func TestClientPreloadPowUsesClientSolver(t *testing.T) {
|
||||
t.Setenv("DS2API_POW_POOL_SIZE", "1")
|
||||
client := NewClient(nil, nil)
|
||||
if err := client.PreloadPow(context.Background()); err != nil {
|
||||
t.Fatalf("preload failed: %v", err)
|
||||
}
|
||||
if client.powSolver.runtime == nil || client.powSolver.compiled == nil {
|
||||
t.Fatalf("expected client pow solver to be initialized")
|
||||
}
|
||||
}
|
||||
@@ -36,7 +36,11 @@ func NewApp() *App {
|
||||
return dsClient.Login(ctx, acc)
|
||||
})
|
||||
dsClient = deepseek.NewClient(store, resolver)
|
||||
deepseek.PreloadWASM(config.WASMPath())
|
||||
if err := dsClient.PreloadPow(context.Background()); err != nil {
|
||||
config.Logger.Warn("[WASM] preload failed", "error", err)
|
||||
} else {
|
||||
config.Logger.Info("[WASM] module preloaded", "path", config.WASMPath())
|
||||
}
|
||||
|
||||
openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient}
|
||||
claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient}
|
||||
@@ -86,8 +90,18 @@ func timeout(d time.Duration) func(http.Handler) http.Handler {
|
||||
|
||||
func cors(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
// Dynamically reflect the request origin to allow credentials.
|
||||
// Using "*" with Access-Control-Allow-Credentials: true is
|
||||
// invalid per the CORS spec and will be rejected by browsers.
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Vary", "Origin")
|
||||
} else {
|
||||
// No Origin header (e.g. server-to-server requests); allow all.
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
if r.Method == http.MethodOptions {
|
||||
|
||||
@@ -115,11 +115,25 @@ func ConvertClaudeToDeepSeek(claudeReq map[string]any, store *config.Store) map[
|
||||
return out
|
||||
}
|
||||
|
||||
// EstimateTokens provides a rough token count approximation.
|
||||
// For ASCII text (English, code, etc.) we use ~4 chars per token.
|
||||
// For non-ASCII text (Chinese, Japanese, Korean, etc.) we use ~1.3 chars per token,
|
||||
// which better reflects typical BPE tokenizer behavior for CJK scripts.
|
||||
func EstimateTokens(text string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
n := len([]rune(text)) / 4
|
||||
asciiChars := 0
|
||||
nonASCIIChars := 0
|
||||
for _, r := range text {
|
||||
if r < 128 {
|
||||
asciiChars++
|
||||
} else {
|
||||
nonASCIIChars++
|
||||
}
|
||||
}
|
||||
// ASCII: ~4 chars per token; non-ASCII (CJK): ~1.3 chars per token
|
||||
n := asciiChars/4 + (nonASCIIChars*10+7)/13
|
||||
if n < 1 {
|
||||
return 1
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user