refactor: replace WASM-based PoW solver with a native Go implementation in the pow package

This commit is contained in:
CJACK
2026-04-07 00:10:36 +08:00
parent 13687ce787
commit e7d561694a
18 changed files with 465 additions and 350 deletions

View File

@@ -33,10 +33,6 @@ func ConfigPath() string {
return ResolvePath("DS2API_CONFIG_PATH", "config.json")
}
func WASMPath() string {
return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm")
}
func RawStreamSampleRoot() string {
return ResolvePath("DS2API_RAW_STREAM_SAMPLE_ROOT", "tests/raw_stream_samples")
}

View File

@@ -109,7 +109,7 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
data, _ := resp["data"].(map[string]any)
bizData, _ := data["biz_data"].(map[string]any)
challenge, _ := bizData["challenge"].(map[string]any)
answer, err := c.powSolver.Compute(ctx, challenge)
answer, err := ComputePow(challenge)
if err != nil {
attempts++
continue

View File

@@ -23,7 +23,6 @@ type Client struct {
stream trans.Doer
fallback *http.Client
fallbackS *http.Client
powSolver *PowSolver
maxRetries int
}
@@ -36,11 +35,11 @@ func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
stream: trans.New(0),
fallback: &http.Client{Timeout: 60 * time.Second},
fallbackS: &http.Client{Timeout: 0},
powSolver: NewPowSolver(config.WASMPath()),
maxRetries: 3,
}
}
func (c *Client) PreloadPow(ctx context.Context) error {
return c.powSolver.init(ctx)
// PreloadPow 保留兼容接口,纯 Go 实现无需预加载。
func (c *Client) PreloadPow(_ context.Context) error {
return nil
}

View File

@@ -105,43 +105,16 @@ func TestBuildPowHeaderEmptyChallenge(t *testing.T) {
}
}
// ─── PowSolver pool size ─────────────────────────────────────────────
func TestPowPoolSizeFromEnvDefault(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "")
got := powPoolSizeFromEnv()
if got < 1 {
t.Fatalf("expected positive default pool size, got %d", got)
}
}
func TestPowPoolSizeFromEnvInvalid(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "abc")
got := powPoolSizeFromEnv()
if got < 1 {
t.Fatalf("expected positive default for invalid, got %d", got)
}
}
func TestPowPoolSizeFromEnvSpecificValue(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "5")
got := powPoolSizeFromEnv()
if got != 5 {
t.Fatalf("expected 5, got %d", got)
}
}
// ─── NewClient ───────────────────────────────────────────────────────
func TestNewClientInitialState(t *testing.T) {
client := NewClient(nil, nil)
if client.powSolver == nil {
t.Fatal("expected powSolver to be initialized")
if client == nil {
t.Fatal("expected non-nil client")
}
}
func TestNewClientPreloadPowIdempotent(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "1")
client := NewClient(nil, nil)
if err := client.PreloadPow(context.Background()); err != nil {
t.Fatalf("first preload failed: %v", err)
@@ -150,16 +123,3 @@ func TestNewClientPreloadPowIdempotent(t *testing.T) {
t.Fatalf("second preload failed: %v", err)
}
}
// ─── PowSolver init and module pool ──────────────────────────────────
func TestPowSolverPoolSizeMatchesEnv(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "2")
solver := NewPowSolver("test.wasm")
if err := solver.init(context.Background()); err != nil {
t.Fatalf("init failed: %v", err)
}
if cap(solver.pool) != 2 {
t.Fatalf("expected pool capacity 2, got %d", cap(solver.pool))
}
}

View File

@@ -1,6 +0,0 @@
package deepseek
import _ "embed"
//go:embed assets/sha3_wasm_bg.7b9ca65ddd.wasm
var embeddedWASM []byte

View File

@@ -1,220 +1,28 @@
package deepseek
import (
"context"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"math"
"os"
stdruntime "runtime"
"strconv"
"sync"
"ds2api/internal/config"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"ds2api/pow"
)
type PowSolver struct {
wasmPath string
once sync.Once
err error
runtime wazero.Runtime
compiled wazero.CompiledModule
pool chan *pooledModule
poolSize int
}
type pooledModule struct {
mod api.Module
stackFn api.Function
allocFn api.Function
freeFn api.Function
solveFn api.Function
}
func NewPowSolver(wasmPath string) *PowSolver {
return &PowSolver{wasmPath: wasmPath}
}
func (p *PowSolver) init(ctx context.Context) error {
p.once.Do(func() {
wasmBytes, err := os.ReadFile(p.wasmPath)
if err != nil {
if len(embeddedWASM) == 0 {
p.err = err
return
}
wasmBytes = embeddedWASM
}
p.runtime = wazero.NewRuntime(ctx)
p.compiled, p.err = p.runtime.CompileModule(ctx, wasmBytes)
if p.err == nil {
p.poolSize = powPoolSizeFromEnv()
p.pool = make(chan *pooledModule, p.poolSize)
for range p.poolSize {
inst, err := p.createModule(ctx)
if err != nil {
p.err = err
return
}
p.pool <- inst
}
}
})
return p.err
}
func (p *PowSolver) Compute(ctx context.Context, challenge map[string]any) (int64, error) {
if err := p.init(ctx); err != nil {
return 0, err
}
// ComputePow 使用纯 Go 实现求解 PoW challenge (DeepSeekHashV1)。
func ComputePow(challenge map[string]any) (int64, error) {
algo, _ := challenge["algorithm"].(string)
if algo != "DeepSeekHashV1" {
return 0, errors.New("unsupported algorithm")
}
challengeStr, _ := challenge["challenge"].(string)
salt, _ := challenge["salt"].(string)
signature, _ := challenge["signature"].(string)
targetPath, _ := challenge["target_path"].(string)
_ = signature
_ = targetPath
difficulty := toFloat64(challenge["difficulty"], 144000)
expireAt := toInt64(challenge["expire_at"], 1680000000)
prefix := salt + "_" + itoa(expireAt) + "_"
difficulty := toInt64FromFloat(challenge["difficulty"], 144000)
pm, err := p.acquireModule(ctx)
if err != nil {
return 0, err
}
defer p.releaseModule(pm)
mem := pm.mod.Memory()
if mem == nil {
return 0, errors.New("wasm memory missing")
}
retPtrs, err := pm.stackFn.Call(ctx, uint64(uint32(^uint32(15)))) // -16 i32
if err != nil || len(retPtrs) == 0 {
return 0, errors.New("stack alloc failed")
}
retptr := uint32(retPtrs[0])
defer func() {
_, _ = pm.stackFn.Call(context.Background(), 16)
}()
chPtr, chLen, err := writeUTF8(ctx, pm.allocFn, mem, challengeStr)
if err != nil {
return 0, err
}
defer freeUTF8(pm.freeFn, chPtr, chLen)
prefixPtr, prefixLen, err := writeUTF8(ctx, pm.allocFn, mem, prefix)
if err != nil {
return 0, err
}
defer freeUTF8(pm.freeFn, prefixPtr, prefixLen)
if _, err := pm.solveFn.Call(ctx,
uint64(retptr),
uint64(chPtr), uint64(chLen),
uint64(prefixPtr), uint64(prefixLen),
math.Float64bits(difficulty),
); err != nil {
return 0, err
}
statusBytes, ok := mem.Read(retptr, 4)
if !ok {
return 0, errors.New("read status failed")
}
status := int32(binary.LittleEndian.Uint32(statusBytes))
valueBytes, ok := mem.Read(retptr+8, 8)
if !ok {
return 0, errors.New("read value failed")
}
value := math.Float64frombits(binary.LittleEndian.Uint64(valueBytes))
if status == 0 {
return 0, errors.New("pow solve failed")
}
return int64(value), nil
}
func (p *PowSolver) createModule(ctx context.Context) (*pooledModule, error) {
mod, err := p.runtime.InstantiateModule(ctx, p.compiled, wazero.NewModuleConfig())
if err != nil {
return nil, err
}
stackFn := mod.ExportedFunction("__wbindgen_add_to_stack_pointer")
allocFn := mod.ExportedFunction("__wbindgen_export_0")
solveFn := mod.ExportedFunction("wasm_solve")
if stackFn == nil || allocFn == nil || solveFn == nil {
_ = mod.Close(context.Background())
return nil, errors.New("required wasm exports missing")
}
return &pooledModule{
mod: mod,
stackFn: stackFn,
allocFn: allocFn,
freeFn: mod.ExportedFunction("__wbindgen_export_2"),
solveFn: solveFn,
}, nil
}
func (p *PowSolver) acquireModule(ctx context.Context) (*pooledModule, error) {
if p.pool != nil {
for {
select {
case pm := <-p.pool:
if pm != nil {
return pm, nil
}
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
return p.createModule(ctx)
}
func (p *PowSolver) releaseModule(pm *pooledModule) {
if pm == nil || pm.mod == nil {
return
}
if p.pool != nil {
select {
case p.pool <- pm:
return
default:
}
}
_ = pm.mod.Close(context.Background())
}
func writeUTF8(ctx context.Context, allocFn api.Function, mem api.Memory, text string) (uint32, uint32, error) {
data := []byte(text)
res, err := allocFn.Call(ctx, uint64(len(data)), 1)
if err != nil || len(res) == 0 {
return 0, 0, errors.New("alloc failed")
}
ptr := uint32(res[0])
if !mem.Write(ptr, data) {
return 0, 0, errors.New("mem write failed")
}
return ptr, uint32(len(data)), nil
}
func freeUTF8(freeFn api.Function, ptr, size uint32) {
if freeFn == nil || ptr == 0 || size == 0 {
return
}
_, _ = freeFn.Call(context.Background(), uint64(ptr), uint64(size), 1)
return pow.SolvePow(challengeStr, salt, expireAt, difficulty)
}
// BuildPowHeader 序列化 {algorithm,challenge,salt,answer,signature,target_path} 为 base64(JSON)。
func BuildPowHeader(challenge map[string]any, answer int64) (string, error) {
payload := map[string]any{
"algorithm": challenge["algorithm"],
@@ -257,32 +65,7 @@ func toInt64(v any, d int64) int64 {
}
}
func itoa(n int64) string {
return strconv.FormatInt(n, 10)
}
func powPoolSizeFromEnv() int {
const fallback = 4
n := fallback
if cpus := stdruntime.GOMAXPROCS(0); cpus > 0 {
n = cpus
}
if raw := os.Getenv("DS2API_POW_POOL_SIZE"); raw != "" {
if v, err := strconv.Atoi(raw); err == nil && v > 0 {
n = v
}
}
if n > 64 {
return 64
}
return n
}
func PreloadWASM(wasmPath string) {
solver := NewPowSolver(wasmPath)
if err := solver.init(context.Background()); err != nil {
config.Logger.Warn("[WASM] preload failed", "error", err)
return
}
config.Logger.Info("[WASM] module preloaded", "path", wasmPath)
// toInt64FromFloat 与 toInt64 等价,仅名称区分用途。
func toInt64FromFloat(v any, d int64) int64 {
return toInt64(v, d)
}

View File

@@ -3,66 +3,18 @@ package deepseek
import (
"context"
"testing"
"time"
)
func TestPowPoolSizeFromEnv(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "3")
if got := powPoolSizeFromEnv(); got != 3 {
t.Fatalf("expected pool size 3, got %d", got)
}
}
func TestPowSolverAcquireReleaseReusesModule(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "1")
solver := NewPowSolver("missing-file.wasm")
if err := solver.init(context.Background()); err != nil {
t.Fatalf("init failed: %v", err)
}
pm1, err := solver.acquireModule(context.Background())
if err != nil {
t.Fatalf("acquire first module failed: %v", err)
}
solver.releaseModule(pm1)
pm2, err := solver.acquireModule(context.Background())
if err != nil {
t.Fatalf("acquire second module failed: %v", err)
}
if pm1 != pm2 {
t.Fatalf("expected pooled module reuse, got different instances")
}
solver.releaseModule(pm2)
}
func TestPowSolverAcquireHonorsContextWhenPoolExhausted(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "1")
solver := NewPowSolver("missing-file.wasm")
if err := solver.init(context.Background()); err != nil {
t.Fatalf("init failed: %v", err)
}
held, err := solver.acquireModule(context.Background())
if err != nil {
t.Fatalf("acquire held module failed: %v", err)
}
defer solver.releaseModule(held)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
if _, err := solver.acquireModule(ctx); err == nil {
t.Fatalf("expected context cancellation while pool is exhausted")
}
}
func TestClientPreloadPowUsesClientSolver(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "1")
func TestPreloadPowNoOp(t *testing.T) {
client := NewClient(nil, nil)
if err := client.PreloadPow(context.Background()); err != nil {
t.Fatalf("preload failed: %v", err)
}
if client.powSolver.runtime == nil || client.powSolver.compiled == nil {
t.Fatalf("expected client pow solver to be initialized")
t.Fatalf("PreloadPow should be no-op, got error: %v", err)
}
}
func TestComputePowUnsupportedAlgorithm(t *testing.T) {
_, err := ComputePow(map[string]any{"algorithm": "unknown"})
if err == nil {
t.Fatal("expected error for unsupported algorithm")
}
}

View File

@@ -42,9 +42,9 @@ func NewApp() (*App, error) {
})
dsClient = deepseek.NewClient(store, resolver)
if err := dsClient.PreloadPow(context.Background()); err != nil {
config.Logger.Warn("[WASM] preload failed", "error", err)
config.Logger.Warn("[PoW] init failed", "error", err)
} else {
config.Logger.Info("[WASM] module preloaded", "path", config.WASMPath())
config.Logger.Info("[PoW] pure Go solver ready")
}
openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient}