DS2API
DeepSeek to OpenAI & Claude Compatible API
diff --git a/Dockerfile b/Dockerfile index 542feaa..ffb4fc4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,33 +1,24 @@ -# DS2API Docker 镜像 -# 采用极简、零侵入设计,所有配置通过环境变量传递 -# 主代码更新时只需重新构建镜像,无需修改 Dockerfile - FROM node:20 AS webui-builder WORKDIR /app/webui - COPY webui/package.json webui/package-lock.json ./ RUN npm ci - COPY webui ./ RUN npm run build -FROM python:3.11-slim - +FROM golang:1.25 AS go-builder WORKDIR /app - -# 安装依赖(利用 Docker 缓存层) -COPY requirements.txt . -RUN pip install --no-cache-dir -r requirements.txt - -# 复制整个项目(保留原始目录结构) +COPY go.mod go.sum* ./ +RUN go mod download COPY . . +RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o /out/ds2api ./cmd/ds2api -# 拷贝 WebUI 构建产物(非 Vercel / Docker 部署可直接使用) +FROM debian:bookworm-slim +WORKDIR /app +RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates wget && rm -rf /var/lib/apt/lists/* +COPY --from=go-builder /out/ds2api /usr/local/bin/ds2api +COPY --from=go-builder /app/sha3_wasm_bg.7b9ca65ddd.wasm /app/sha3_wasm_bg.7b9ca65ddd.wasm +COPY --from=go-builder /app/config.example.json /app/config.example.json COPY --from=webui-builder /app/static/admin /app/static/admin - -# 暴露服务端口 EXPOSE 5001 - -# 启动命令(依赖项目自身的启动逻辑) -CMD ["python", "app.py"] +CMD ["/usr/local/bin/ds2api"] diff --git a/README.MD b/README.MD index 98a28c3..21b7a10 100644 --- a/README.MD +++ b/README.MD @@ -70,15 +70,12 @@ git clone https://github.com/CJackHwang/ds2api.git cd ds2api -# 2. 安装依赖 -pip install -r requirements.txt - -# 3. 配置账号 +# 2. 准备配置 cp config.example.json config.json # 编辑 config.json,添加 DeepSeek 账号信息 -# 4. 启动服务 -python dev.py +# 3. 启动服务(Go 版本) +go run ./cmd/ds2api ``` 服务启动后访问 `http://localhost:5001` diff --git a/README.en.md b/README.en.md index f1b223c..2da100a 100644 --- a/README.en.md +++ b/README.en.md @@ -68,15 +68,12 @@ Convert DeepSeek Web into an **OpenAI & Claude compatible API**, with multi-acco git clone https://github.com/CJackHwang/ds2api.git cd ds2api -# 2. Install dependencies -pip install -r requirements.txt - -# 3. Configure accounts +# 2. Configure accounts cp config.example.json config.json # Edit config.json to add DeepSeek account info -# 4. Start the service -python dev.py +# 3. Start the service (Go runtime) +go run ./cmd/ds2api ``` Visit `http://localhost:5001` after startup. diff --git a/api/index.go b/api/index.go new file mode 100644 index 0000000..326e83f --- /dev/null +++ b/api/index.go @@ -0,0 +1,20 @@ +package handler + +import ( + "net/http" + "sync" + + "ds2api/internal/server" +) + +var ( + once sync.Once + app *server.App +) + +func Handler(w http.ResponseWriter, r *http.Request) { + once.Do(func() { + app = server.NewApp() + }) + app.Router.ServeHTTP(w, r) +} diff --git a/cmd/ds2api/main.go b/cmd/ds2api/main.go new file mode 100644 index 0000000..7c9ad68 --- /dev/null +++ b/cmd/ds2api/main.go @@ -0,0 +1,23 @@ +package main + +import ( + "net/http" + "os" + "strings" + + "ds2api/internal/config" + "ds2api/internal/server" +) + +func main() { + app := server.NewApp() + port := strings.TrimSpace(os.Getenv("PORT")) + if port == "" { + port = "5001" + } + config.Logger.Info("starting ds2api", "port", port) + if err := http.ListenAndServe("0.0.0.0:"+port, app.Router); err != nil { + config.Logger.Error("server stopped", "error", err) + os.Exit(1) + } +} diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index a329cf8..c378129 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -9,22 +9,12 @@ services: ds2api: - build: . + build: + context: . + target: go-builder image: ds2api:dev container_name: ds2api-dev - command: [ - "uvicorn", - "app:app", - "--host", - "0.0.0.0", - "--port", - "5001", - "--reload", - "--reload-dir", - "/app", - "--log-level", - "debug" - ] + command: ["go", "run", "./cmd/ds2api"] ports: - "${PORT:-5001}:5001" env_file: @@ -34,10 +24,7 @@ services: - LOG_LEVEL=DEBUG volumes: # 源代码挂载(开发时实时生效) - - ./app.py:/app/app.py:ro - - ./core:/app/core:ro - - ./routes:/app/routes:ro - - ./static:/app/static:ro + - ./:/app # 配置文件挂载(便于本地修改) - ./config.json:/app/config.json restart: "no" diff --git a/docker-compose.yml b/docker-compose.yml index 3842060..d984420 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,13 +1,3 @@ -# DS2API 生产环境配置 -# 使用说明: -# 1. 复制 .env.example 为 .env 并填写配置 -# 2. docker-compose up -d -# 3. 主代码更新后:docker-compose up -d --build -# -# 设计原则: -# - 零侵入:所有项目配置通过 .env 文件传递 -# - 易维护:主代码更新只需重新构建镜像 - services: ds2api: build: . @@ -18,11 +8,10 @@ services: env_file: - .env environment: - # 确保容器内使用正确的主机绑定 - HOST=0.0.0.0 restart: unless-stopped healthcheck: - test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:5001/v1/models')"] + test: ["CMD", "wget", "-qO-", "http://localhost:5001/healthz"] interval: 30s timeout: 10s retries: 3 diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..712f09b --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module ds2api + +go 1.25 + +require ( + github.com/go-chi/chi/v5 v5.2.3 + github.com/google/uuid v1.6.0 + github.com/refraction-networking/utls v1.8.1 + github.com/tetratelabs/wazero v1.9.0 +) + +require ( + github.com/andybalholm/brotli v1.0.6 // indirect + github.com/klauspost/compress v1.17.4 // indirect + golang.org/x/crypto v0.36.0 // indirect + golang.org/x/sys v0.31.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..03b6d07 --- /dev/null +++ b/go.sum @@ -0,0 +1,16 @@ +github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= +github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= +github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= +github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo= +github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= +github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= +github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= diff --git a/internal/account/pool.go b/internal/account/pool.go new file mode 100644 index 0000000..285299b --- /dev/null +++ b/internal/account/pool.go @@ -0,0 +1,127 @@ +package account + +import ( + "sort" + "sync" + + "ds2api/internal/config" +) + +type Pool struct { + store *config.Store + mu sync.Mutex + queue []string + inUse map[string]bool +} + +func NewPool(store *config.Store) *Pool { + p := &Pool{store: store, inUse: map[string]bool{}} + p.Reset() + return p +} + +func (p *Pool) Reset() { + accounts := p.store.Accounts() + sort.SliceStable(accounts, func(i, j int) bool { + iHas := accounts[i].Token != "" + jHas := accounts[j].Token != "" + if iHas == jHas { + return i < j + } + return iHas + }) + ids := make([]string, 0, len(accounts)) + for _, a := range accounts { + id := a.Identifier() + if id != "" { + ids = append(ids, id) + } + } + p.mu.Lock() + defer p.mu.Unlock() + p.queue = ids + p.inUse = map[string]bool{} + config.Logger.Info("[init_account_queue] initialized", "total", len(ids)) +} + +func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) { + p.mu.Lock() + defer p.mu.Unlock() + if target != "" { + for i, id := range p.queue { + if id != target { + continue + } + acc, ok := p.store.FindAccount(id) + if !ok { + return config.Account{}, false + } + p.queue = append(p.queue[:i], p.queue[i+1:]...) + p.inUse[id] = true + return acc, true + } + return config.Account{}, false + } + + for i := 0; i < len(p.queue); i++ { + id := p.queue[i] + if exclude[id] { + continue + } + acc, ok := p.store.FindAccount(id) + if !ok { + continue + } + if acc.Token == "" { + continue + } + p.queue = append(p.queue[:i], p.queue[i+1:]...) + p.inUse[id] = true + return acc, true + } + + for i := 0; i < len(p.queue); i++ { + id := p.queue[i] + if exclude[id] { + continue + } + acc, ok := p.store.FindAccount(id) + if !ok { + continue + } + p.queue = append(p.queue[:i], p.queue[i+1:]...) + p.inUse[id] = true + return acc, true + } + return config.Account{}, false +} + +func (p *Pool) Release(accountID string) { + if accountID == "" { + return + } + p.mu.Lock() + defer p.mu.Unlock() + if !p.inUse[accountID] { + return + } + delete(p.inUse, accountID) + p.queue = append(p.queue, accountID) +} + +func (p *Pool) Status() map[string]any { + p.mu.Lock() + defer p.mu.Unlock() + available := append([]string{}, p.queue...) + inUse := make([]string, 0, len(p.inUse)) + for id := range p.inUse { + inUse = append(inUse, id) + } + return map[string]any{ + "available": len(available), + "in_use": len(inUse), + "total": len(p.store.Accounts()), + "available_accounts": available, + "in_use_accounts": inUse, + } +} diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go new file mode 100644 index 0000000..391fd83 --- /dev/null +++ b/internal/adapter/claude/handler.go @@ -0,0 +1,403 @@ +package claude + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" + "ds2api/internal/sse" + "ds2api/internal/util" +) + +type Handler struct { + Store *config.Store + Auth *auth.Resolver + DS *deepseek.Client +} + +func RegisterRoutes(r chi.Router, h *Handler) { + r.Get("/anthropic/v1/models", h.ListModels) + r.Post("/anthropic/v1/messages", h.Messages) + r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens) +} + +func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, config.ClaudeModelsResponse()) +} + +func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeJSON(w, status, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": detail}}) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "invalid json"}}) + return + } + model, _ := req["model"].(string) + messagesRaw, _ := req["messages"].([]any) + if model == "" || len(messagesRaw) == 0 { + writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "Request must include 'model' and 'messages'."}}) + return + } + + normalized := normalizeClaudeMessages(messagesRaw) + payload := cloneMap(req) + payload["messages"] = normalized + toolsRequested, _ := req["tools"].([]any) + if len(toolsRequested) > 0 && !hasSystemMessage(normalized) { + payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...) + } + + dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store) + dsModel, _ := dsPayload["model"].(string) + thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel) + if !ok { + thinkingEnabled = false + searchEnabled = false + } + _ = searchEnabled + finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"])) + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "invalid token."}}) + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get PoW"}}) + return + } + requestPayload := map[string]any{ + "chat_session_id": sessionID, + "parent_message_id": nil, + "prompt": finalPrompt, + "ref_file_ids": []any{}, + "thinking_enabled": thinkingEnabled, + "search_enabled": searchEnabled, + } + resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get Claude response."}}) + return + } + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}}) + return + } + + fullText, fullThinking := collectDeepSeek(resp, thinkingEnabled) + toolNames := extractClaudeToolNames(toolsRequested) + detected := util.ParseToolCalls(fullText, toolNames) + if toBool(req["stream"]) { + h.writeClaudeStream(w, r, model, normalized, fullText, detected) + return + } + content := make([]map[string]any, 0, 4) + if fullThinking != "" { + content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking}) + } + stopReason := "end_turn" + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + content = append(content, map[string]any{ + "type": "tool_use", + "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), + "name": tc.Name, + "input": tc.Input, + }) + } + } else { + if fullText == "" { + fullText = "抱歉,没有生成有效的响应内容。" + } + content = append(content, map[string]any{"type": "text", "text": fullText}) + } + writeJSON(w, http.StatusOK, map[string]any{ + "id": fmt.Sprintf("msg_%d", time.Now().UnixNano()), + "type": "message", + "role": "assistant", + "model": model, + "content": content, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)), + "output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText), + }, + }) +} + +func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()}) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"}) + return + } + model, _ := req["model"].(string) + messages, _ := req["messages"].([]any) + if model == "" || len(messages) == 0 { + writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."}) + return + } + inputTokens := 0 + if sys, ok := req["system"].(string); ok { + inputTokens += util.EstimateTokens(sys) + } + for _, item := range messages { + msg, ok := item.(map[string]any) + if !ok { + continue + } + inputTokens += 2 + inputTokens += util.EstimateTokens(extractMessageContent(msg["content"])) + } + if tools, ok := req["tools"].([]any); ok { + for _, t := range tools { + b, _ := json.Marshal(t) + inputTokens += util.EstimateTokens(string(b)) + } + } + if inputTokens < 1 { + inputTokens = 1 + } + writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens}) +} + +func collectDeepSeek(resp *http.Response, thinkingEnabled bool) (string, string) { + defer resp.Body.Close() + text := strings.Builder{} + thinking := strings.Builder{} + currentType := "text" + if thinkingEnabled { + currentType = "thinking" + } + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 2*1024*1024) + for scanner.Scan() { + chunk, done, ok := sse.ParseDeepSeekSSELine(scanner.Bytes()) + if !ok { + continue + } + if done { + break + } + parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType) + currentType = newType + if finished { + break + } + for _, p := range parts { + if p.Type == "thinking" { + thinking.WriteString(p.Text) + } else { + text.WriteString(p.Text) + } + } + } + return text.String(), thinking.String() +} + +func (h *Handler) writeClaudeStream(w http.ResponseWriter, r *http.Request, model string, messages []any, fullText string, detected []util.ParsedToolCall) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + flusher, ok := w.(http.Flusher) + if !ok { + writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "streaming unsupported"}}) + return + } + send := func(v any) { + b, _ := json.Marshal(v) + _, _ = w.Write([]byte("data: ")) + _, _ = w.Write(b) + _, _ = w.Write([]byte("\n\n")) + flusher.Flush() + } + messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano()) + inputTokens := util.EstimateTokens(fmt.Sprintf("%v", messages)) + send(map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": messageID, + "type": "message", + "role": "assistant", + "model": model, + "content": []any{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0}, + }, + }) + outputTokens := 0 + stopReason := "end_turn" + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + send(map[string]any{"type": "content_block_start", "index": i, "content_block": map[string]any{"type": "tool_use", "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), "name": tc.Name, "input": tc.Input}}) + send(map[string]any{"type": "content_block_stop", "index": i}) + outputTokens += util.EstimateTokens(fmt.Sprintf("%v", tc.Input)) + } + } else { + if fullText != "" { + send(map[string]any{"type": "content_block_start", "index": 0, "content_block": map[string]any{"type": "text", "text": ""}}) + send(map[string]any{"type": "content_block_delta", "index": 0, "delta": map[string]any{"type": "text_delta", "text": fullText}}) + send(map[string]any{"type": "content_block_stop", "index": 0}) + outputTokens = util.EstimateTokens(fullText) + } + } + send(map[string]any{"type": "message_delta", "delta": map[string]any{"stop_reason": stopReason, "stop_sequence": nil}, "usage": map[string]any{"output_tokens": outputTokens}}) + send(map[string]any{"type": "message_stop"}) + _ = r +} + +func normalizeClaudeMessages(messages []any) []any { + out := make([]any, 0, len(messages)) + for _, m := range messages { + msg, ok := m.(map[string]any) + if !ok { + continue + } + copied := cloneMap(msg) + switch content := msg["content"].(type) { + case []any: + parts := make([]string, 0, len(content)) + for _, block := range content { + b, ok := block.(map[string]any) + if !ok { + continue + } + typeStr, _ := b["type"].(string) + if typeStr == "text" { + if t, ok := b["text"].(string); ok { + parts = append(parts, t) + } + } + if typeStr == "tool_result" { + parts = append(parts, fmt.Sprintf("%v", b["content"])) + } + } + copied["content"] = strings.Join(parts, "\n") + } + out = append(out, copied) + } + return out +} + +func buildClaudeToolPrompt(tools []any) string { + parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"} + for _, t := range tools { + m, ok := t.(map[string]any) + if !ok { + continue + } + name, _ := m["name"].(string) + desc, _ := m["description"].(string) + schema, _ := json.Marshal(m["input_schema"]) + parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema)) + } + parts = append(parts, "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}") + return strings.Join(parts, "\n\n") +} + +func hasSystemMessage(messages []any) bool { + for _, m := range messages { + msg, ok := m.(map[string]any) + if ok && msg["role"] == "system" { + return true + } + } + return false +} + +func extractClaudeToolNames(tools []any) []string { + out := make([]string, 0, len(tools)) + for _, t := range tools { + m, ok := t.(map[string]any) + if !ok { + continue + } + if name, ok := m["name"].(string); ok && name != "" { + out = append(out, name) + } + } + return out +} + +func toMessageMaps(v any) []map[string]any { + arr, ok := v.([]any) + if !ok { + return nil + } + out := make([]map[string]any, 0, len(arr)) + for _, item := range arr { + if m, ok := item.(map[string]any); ok { + out = append(out, m) + } + } + return out +} + +func extractMessageContent(v any) string { + switch x := v.(type) { + case string: + return x + case []any: + parts := make([]string, 0, len(x)) + for _, it := range x { + parts = append(parts, fmt.Sprintf("%v", it)) + } + return strings.Join(parts, "\n") + default: + return fmt.Sprintf("%v", x) + } +} + +func cloneMap(in map[string]any) map[string]any { + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func toBool(v any) bool { + b, _ := v.(bool) + return b +} + +func writeJSON(w http.ResponseWriter, status int, payload any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(payload) +} diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go new file mode 100644 index 0000000..e06c9c8 --- /dev/null +++ b/internal/adapter/openai/handler.go @@ -0,0 +1,413 @@ +package openai + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" + "ds2api/internal/sse" + "ds2api/internal/util" +) + +type Handler struct { + Store *config.Store + Auth *auth.Resolver + DS *deepseek.Client +} + +func RegisterRoutes(r chi.Router, h *Handler) { + r.Get("/v1/models", h.ListModels) + r.Post("/v1/chat/completions", h.ChatCompletions) +} + +func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, config.OpenAIModelsResponse()) +} + +func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeJSON(w, status, map[string]any{"error": detail}) + return + } + defer h.Auth.Release(a) + r = r.WithContext(auth.WithAuth(r.Context(), a)) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"}) + return + } + model, _ := req["model"].(string) + messagesRaw, _ := req["messages"].([]any) + if model == "" || len(messagesRaw) == 0 { + writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."}) + return + } + thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model) + if !ok { + writeJSON(w, http.StatusServiceUnavailable, map[string]any{"error": fmt.Sprintf("Model '%s' is not available.", model)}) + return + } + + messages := normalizeMessages(messagesRaw) + toolNames := []string{} + if tools, ok := req["tools"].([]any); ok && len(tools) > 0 { + messages, toolNames = injectToolPrompt(messages, tools) + } + finalPrompt := util.MessagesPrepare(messages) + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + writeJSON(w, http.StatusUnauthorized, map[string]any{"error": "invalid token."}) + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeJSON(w, http.StatusUnauthorized, map[string]any{"error": "Failed to get PoW (invalid token or unknown error)."}) + return + } + payload := map[string]any{ + "chat_session_id": sessionID, + "parent_message_id": nil, + "prompt": finalPrompt, + "ref_file_ids": []any{}, + "thinking_enabled": thinkingEnabled, + "search_enabled": searchEnabled, + } + resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"error": "Failed to get completion."}) + return + } + if toBool(req["stream"]) { + h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + return + } + h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) +} + +func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeJSON(w, resp.StatusCode, map[string]any{"error": string(body)}) + return + } + thinking := strings.Builder{} + text := strings.Builder{} + currentType := "text" + if thinkingEnabled { + currentType = "thinking" + } + _ = ctx + _ = deepseek.ScanSSELines(resp, func(line []byte) bool { + chunk, done, ok := sse.ParseDeepSeekSSELine(line) + if !ok { + return true + } + if done { + return false + } + if _, hasErr := chunk["error"]; hasErr { + return false + } + parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType) + currentType = newType + if finished { + return false + } + for _, p := range parts { + if searchEnabled && sse.IsCitation(p.Text) { + continue + } + if p.Type == "thinking" { + thinking.WriteString(p.Text) + } else { + text.WriteString(p.Text) + } + } + return true + }) + + finalThinking := thinking.String() + finalText := text.String() + detected := util.ParseToolCalls(finalText, toolNames) + finishReason := "stop" + messageObj := map[string]any{"role": "assistant", "content": finalText} + if thinkingEnabled && finalThinking != "" { + messageObj["reasoning_content"] = finalThinking + } + if len(detected) > 0 { + finishReason = "tool_calls" + messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected) + messageObj["content"] = nil + } + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + + writeJSON(w, http.StatusOK, map[string]any{ + "id": completionID, + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}}, + "usage": map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + "completion_tokens_details": map[string]any{ + "reasoning_tokens": reasoningTokens, + }, + }, + }) +} + +func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeJSON(w, resp.StatusCode, map[string]any{"error": string(body)}) + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + flusher, ok := w.(http.Flusher) + if !ok { + writeJSON(w, http.StatusInternalServerError, map[string]any{"error": "streaming unsupported"}) + return + } + + lines := make(chan []byte, 128) + done := make(chan error, 1) + go func() { + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 2*1024*1024) + for scanner.Scan() { + b := append([]byte{}, scanner.Bytes()...) + lines <- b + } + close(lines) + done <- scanner.Err() + }() + + created := time.Now().Unix() + firstChunkSent := false + currentType := "text" + if thinkingEnabled { + currentType = "thinking" + } + thinking := strings.Builder{} + text := strings.Builder{} + lastContent := time.Now() + hasContent := false + keepaliveTicker := time.NewTicker(time.Duration(deepseek.KeepAliveTimeout) * time.Second) + defer keepaliveTicker.Stop() + + sendChunk := func(v any) { + b, _ := json.Marshal(v) + _, _ = w.Write([]byte("data: ")) + _, _ = w.Write(b) + _, _ = w.Write([]byte("\n\n")) + flusher.Flush() + } + sendDone := func() { + _, _ = w.Write([]byte("data: [DONE]\n\n")) + flusher.Flush() + } + + finalize := func(finishReason string) { + finalThinking := thinking.String() + finalText := text.String() + detected := util.ParseToolCalls(finalText, toolNames) + if len(detected) > 0 { + finishReason = "tool_calls" + sendChunk(map[string]any{ + "id": completionID, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]any{{"delta": map[string]any{"tool_calls": util.FormatOpenAIToolCalls(detected)}, "index": 0}}, + }) + } + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + sendChunk(map[string]any{ + "id": completionID, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]any{{"delta": map[string]any{}, "index": 0, "finish_reason": finishReason}}, + "usage": map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + "completion_tokens_details": map[string]any{ + "reasoning_tokens": reasoningTokens, + }, + }, + }) + sendDone() + } + + for { + select { + case <-r.Context().Done(): + return + case <-keepaliveTicker.C: + if hasContent && time.Since(lastContent) > time.Duration(deepseek.StreamIdleTimeout)*time.Second { + finalize("stop") + return + } + _, _ = w.Write([]byte(": keep-alive\n\n")) + flusher.Flush() + case line, ok := <-lines: + if !ok { + finalize("stop") + return + } + chunk, doneSignal, parsed := sse.ParseDeepSeekSSELine(line) + if !parsed { + continue + } + if doneSignal { + finalize("stop") + return + } + if _, hasErr := chunk["error"]; hasErr || chunk["code"] == "content_filter" { + finalize("content_filter") + return + } + parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType) + currentType = newType + if finished { + finalize("stop") + return + } + newChoices := make([]map[string]any, 0, len(parts)) + for _, p := range parts { + if searchEnabled && sse.IsCitation(p.Text) { + continue + } + if p.Text == "" { + continue + } + hasContent = true + lastContent = time.Now() + delta := map[string]any{} + if !firstChunkSent { + delta["role"] = "assistant" + firstChunkSent = true + } + if p.Type == "thinking" { + if thinkingEnabled { + thinking.WriteString(p.Text) + delta["reasoning_content"] = p.Text + } + } else { + text.WriteString(p.Text) + delta["content"] = p.Text + } + if len(delta) > 0 { + newChoices = append(newChoices, map[string]any{"delta": delta, "index": 0}) + } + } + if len(newChoices) > 0 { + sendChunk(map[string]any{ + "id": completionID, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": newChoices, + }) + } + case <-done: + finalize("stop") + return + } + } +} + +func normalizeMessages(raw []any) []map[string]any { + out := make([]map[string]any, 0, len(raw)) + for _, item := range raw { + m, ok := item.(map[string]any) + if ok { + out = append(out, m) + } + } + return out +} + +func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) { + toolSchemas := make([]string, 0, len(tools)) + names := make([]string, 0, len(tools)) + for _, t := range tools { + tool, ok := t.(map[string]any) + if !ok { + continue + } + fn, _ := tool["function"].(map[string]any) + if len(fn) == 0 { + fn = tool + } + name, _ := fn["name"].(string) + desc, _ := fn["description"].(string) + schema, _ := fn["parameters"].(map[string]any) + if name == "" { + name = "unknown" + } + names = append(names, name) + if desc == "" { + desc = "No description available" + } + b, _ := json.Marshal(schema) + toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b))) + } + if len(toolSchemas) == 0 { + return messages, names + } + toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT: If calling tools, output ONLY the JSON. The response must start with { and end with }" + + for i := range messages { + if messages[i]["role"] == "system" { + old, _ := messages[i]["content"].(string) + messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt) + return messages, names + } + } + messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...) + return messages, names +} + +func toBool(v any) bool { + if b, ok := v.(bool); ok { + return b + } + return false +} + +func writeJSON(w http.ResponseWriter, status int, payload any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(payload) +} diff --git a/internal/admin/handler.go b/internal/admin/handler.go new file mode 100644 index 0000000..6985fba --- /dev/null +++ b/internal/admin/handler.go @@ -0,0 +1,890 @@ +package admin + +import ( + "bufio" + "bytes" + "context" + "crypto/md5" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "sort" + "strconv" + "strings" + "time" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/account" + authn "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" + "ds2api/internal/sse" +) + +type Handler struct { + Store *config.Store + Pool *account.Pool + DS *deepseek.Client +} + +func RegisterRoutes(r chi.Router, h *Handler) { + + r.Post("/login", h.login) + r.Get("/verify", h.verify) + r.Group(func(pr chi.Router) { + pr.Use(h.requireAdmin) + pr.Get("/vercel/config", h.getVercelConfig) + pr.Get("/config", h.getConfig) + pr.Post("/config", h.updateConfig) + pr.Post("/keys", h.addKey) + pr.Delete("/keys/{key}", h.deleteKey) + pr.Get("/accounts", h.listAccounts) + pr.Post("/accounts", h.addAccount) + pr.Delete("/accounts/{identifier}", h.deleteAccount) + pr.Get("/queue/status", h.queueStatus) + pr.Post("/accounts/test", h.testSingleAccount) + pr.Post("/accounts/test-all", h.testAllAccounts) + pr.Post("/import", h.batchImport) + pr.Post("/test", h.testAPI) + pr.Post("/vercel/sync", h.syncVercel) + pr.Get("/vercel/status", h.vercelStatus) + pr.Get("/export", h.exportConfig) + }) +} + +func (h *Handler) requireAdmin(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := authn.VerifyAdminRequest(r); err != nil { + writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()}) + return + } + next.ServeHTTP(w, r) + }) +} + +func (h *Handler) login(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + adminKey, _ := req["admin_key"].(string) + expireHours := intFrom(req["expire_hours"]) + if expireHours <= 0 { + expireHours = 24 + } + if adminKey != authn.AdminKey() { + writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "Invalid admin key"}) + return + } + token, err := authn.CreateJWT(expireHours) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]any{"success": true, "token": token, "expires_in": expireHours * 3600}) +} + +func (h *Handler) verify(w http.ResponseWriter, r *http.Request) { + header := strings.TrimSpace(r.Header.Get("Authorization")) + if !strings.HasPrefix(strings.ToLower(header), "bearer ") { + writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "No credentials provided"}) + return + } + token := strings.TrimSpace(header[7:]) + payload, err := authn.VerifyJWT(token) + if err != nil { + writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()}) + return + } + exp, _ := payload["exp"].(float64) + remaining := int64(exp) - time.Now().Unix() + if remaining < 0 { + remaining = 0 + } + writeJSON(w, http.StatusOK, map[string]any{"valid": true, "expires_at": int64(exp), "remaining_seconds": remaining}) +} + +func (h *Handler) getVercelConfig(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, map[string]any{ + "has_token": strings.TrimSpace(os.Getenv("VERCEL_TOKEN")) != "", + "project_id": strings.TrimSpace(os.Getenv("VERCEL_PROJECT_ID")), + "team_id": nilIfEmpty(strings.TrimSpace(os.Getenv("VERCEL_TEAM_ID"))), + }) +} + +func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() + safe := map[string]any{ + "keys": snap.Keys, + "accounts": []map[string]any{}, + "claude_mapping": func() map[string]string { + if len(snap.ClaudeMapping) > 0 { + return snap.ClaudeMapping + } + return snap.ClaudeModelMap + }(), + } + accounts := make([]map[string]any, 0, len(snap.Accounts)) + for _, acc := range snap.Accounts { + token := strings.TrimSpace(acc.Token) + preview := "" + if token != "" { + if len(token) > 20 { + preview = token[:20] + "..." + } else { + preview = token + } + } + accounts = append(accounts, map[string]any{ + "email": acc.Email, + "mobile": acc.Mobile, + "has_password": strings.TrimSpace(acc.Password) != "", + "has_token": token != "", + "token_preview": preview, + }) + } + safe["accounts"] = accounts + writeJSON(w, http.StatusOK, safe) +} + +func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + old := h.Store.Snapshot() + err := h.Store.Update(func(c *config.Config) error { + if keys, ok := toStringSlice(req["keys"]); ok { + c.Keys = keys + } + if accountsRaw, ok := req["accounts"].([]any); ok { + existing := map[string]config.Account{} + for _, a := range old.Accounts { + existing[a.Identifier()] = a + } + accounts := make([]config.Account, 0, len(accountsRaw)) + for _, item := range accountsRaw { + m, ok := item.(map[string]any) + if !ok { + continue + } + acc := toAccount(m) + id := acc.Identifier() + if prev, ok := existing[id]; ok { + if strings.TrimSpace(acc.Password) == "" { + acc.Password = prev.Password + } + if strings.TrimSpace(acc.Token) == "" { + acc.Token = prev.Token + } + } + accounts = append(accounts, acc) + } + c.Accounts = accounts + } + if m, ok := req["claude_mapping"].(map[string]any); ok { + newMap := map[string]string{} + for k, v := range m { + newMap[k] = fmt.Sprintf("%v", v) + } + c.ClaudeMapping = newMap + } + return nil + }) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "配置已更新"}) +} + +func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + key, _ := req["key"].(string) + key = strings.TrimSpace(key) + if key == "" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"}) + return + } + err := h.Store.Update(func(c *config.Config) error { + for _, k := range c.Keys { + if k == key { + return fmt.Errorf("Key 已存在") + } + } + c.Keys = append(c.Keys, key) + return nil + }) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)}) +} + +func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) { + key := chi.URLParam(r, "key") + err := h.Store.Update(func(c *config.Config) error { + idx := -1 + for i, k := range c.Keys { + if k == key { + idx = i + break + } + } + if idx < 0 { + return fmt.Errorf("Key 不存在") + } + c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...) + return nil + }) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)}) +} + +func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { + page := intFromQuery(r, "page", 1) + pageSize := intFromQuery(r, "page_size", 10) + if page < 1 { + page = 1 + } + if pageSize < 1 { + pageSize = 1 + } + if pageSize > 100 { + pageSize = 100 + } + accounts := h.Store.Snapshot().Accounts + total := len(accounts) + reverseAccounts(accounts) + totalPages := 1 + if total > 0 { + totalPages = (total + pageSize - 1) / pageSize + } + start := (page - 1) * pageSize + if start > total { + start = total + } + end := start + pageSize + if end > total { + end = total + } + items := make([]map[string]any, 0, end-start) + for _, acc := range accounts[start:end] { + token := strings.TrimSpace(acc.Token) + preview := "" + if token != "" { + if len(token) > 20 { + preview = token[:20] + "..." + } else { + preview = token + } + } + items = append(items, map[string]any{"email": acc.Email, "mobile": acc.Mobile, "has_password": acc.Password != "", "has_token": token != "", "token_preview": preview}) + } + writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages}) +} + +func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + acc := toAccount(req) + if acc.Identifier() == "" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"}) + return + } + err := h.Store.Update(func(c *config.Config) error { + for _, a := range c.Accounts { + if acc.Email != "" && a.Email == acc.Email { + return fmt.Errorf("邮箱已存在") + } + if acc.Mobile != "" && a.Mobile == acc.Mobile { + return fmt.Errorf("手机号已存在") + } + } + c.Accounts = append(c.Accounts, acc) + return nil + }) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) +} + +func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) { + identifier := chi.URLParam(r, "identifier") + err := h.Store.Update(func(c *config.Config) error { + idx := -1 + for i, a := range c.Accounts { + if a.Email == identifier || a.Mobile == identifier { + idx = i + break + } + } + if idx < 0 { + return fmt.Errorf("账号不存在") + } + c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...) + return nil + }) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) +} + +func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, h.Pool.Status()) +} + +func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + identifier, _ := req["identifier"].(string) + if strings.TrimSpace(identifier) == "" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(email 或 mobile)"}) + return + } + acc, ok := h.Store.FindAccount(identifier) + if !ok { + writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"}) + return + } + model, _ := req["model"].(string) + if model == "" { + model = "deepseek-chat" + } + message, _ := req["message"].(string) + result := h.testAccount(r.Context(), acc, model, message) + writeJSON(w, http.StatusOK, result) +} + +func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + model, _ := req["model"].(string) + if model == "" { + model = "deepseek-chat" + } + accounts := h.Store.Snapshot().Accounts + if len(accounts) == 0 { + writeJSON(w, http.StatusOK, map[string]any{"total": 0, "success": 0, "failed": 0, "results": []any{}}) + return + } + results := make([]map[string]any, 0, len(accounts)) + success := 0 + for _, acc := range accounts { + res := h.testAccount(r.Context(), acc, model, "") + if ok, _ := res["success"].(bool); ok { + success++ + } + results = append(results, res) + time.Sleep(time.Second) + } + writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results}) +} + +func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any { + start := time.Now() + result := map[string]any{"account": acc.Identifier(), "success": false, "response_time": 0, "message": "", "model": model} + token := strings.TrimSpace(acc.Token) + if token == "" { + newToken, err := h.DS.Login(ctx, acc) + if err != nil { + result["message"] = "登录失败: " + err.Error() + return result + } + token = newToken + _ = h.Store.UpdateAccountToken(acc.Identifier(), token) + } + authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token} + sessionID, err := h.DS.CreateSession(ctx, authCtx, 1) + if err != nil { + newToken, loginErr := h.DS.Login(ctx, acc) + if loginErr != nil { + result["message"] = "创建会话失败: " + err.Error() + return result + } + token = newToken + authCtx.DeepSeekToken = token + _ = h.Store.UpdateAccountToken(acc.Identifier(), token) + sessionID, err = h.DS.CreateSession(ctx, authCtx, 1) + if err != nil { + result["message"] = "创建会话失败: " + err.Error() + return result + } + } + if strings.TrimSpace(message) == "" { + result["success"] = true + result["message"] = "API 测试成功(仅会话创建)" + result["response_time"] = int(time.Since(start).Milliseconds()) + return result + } + thinking, search, ok := config.GetModelConfig(model) + if !ok { + thinking, search = false, false + } + pow, err := h.DS.GetPow(ctx, authCtx, 1) + if err != nil { + result["message"] = "获取 PoW 失败: " + err.Error() + return result + } + payload := map[string]any{"chat_session_id": sessionID, "prompt": "<|User|>" + message, "ref_file_ids": []any{}, "thinking_enabled": thinking, "search_enabled": search} + resp, err := h.DS.CallCompletion(ctx, authCtx, payload, pow, 1) + if err != nil { + result["message"] = "请求失败: " + err.Error() + return result + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + result["message"] = fmt.Sprintf("请求失败: HTTP %d", resp.StatusCode) + return result + } + text := strings.Builder{} + think := strings.Builder{} + currentType := "text" + if thinking { + currentType = "thinking" + } + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 2*1024*1024) + for scanner.Scan() { + chunk, done, parsed := sse.ParseDeepSeekSSELine(scanner.Bytes()) + if !parsed { + continue + } + if done { + break + } + parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinking, currentType) + currentType = newType + if finished { + break + } + for _, p := range parts { + if p.Type == "thinking" { + think.WriteString(p.Text) + } else { + text.WriteString(p.Text) + } + } + } + result["success"] = true + result["response_time"] = int(time.Since(start).Milliseconds()) + if text.Len() > 0 { + result["message"] = text.String() + } else { + result["message"] = "(无回复内容)" + } + if think.Len() > 0 { + result["thinking"] = think.String() + } + return result +} + +func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "无效的 JSON 格式"}) + return + } + importedKeys, importedAccounts := 0, 0 + err := h.Store.Update(func(c *config.Config) error { + if keys, ok := req["keys"].([]any); ok { + existing := map[string]bool{} + for _, k := range c.Keys { + existing[k] = true + } + for _, k := range keys { + key := strings.TrimSpace(fmt.Sprintf("%v", k)) + if key == "" || existing[key] { + continue + } + c.Keys = append(c.Keys, key) + existing[key] = true + importedKeys++ + } + } + if accounts, ok := req["accounts"].([]any); ok { + existing := map[string]bool{} + for _, a := range c.Accounts { + existing[a.Identifier()] = true + } + for _, item := range accounts { + m, ok := item.(map[string]any) + if !ok { + continue + } + acc := toAccount(m) + id := acc.Identifier() + if id == "" || existing[id] { + continue + } + c.Accounts = append(c.Accounts, acc) + existing[id] = true + importedAccounts++ + } + } + return nil + }) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts}) +} + +func (h *Handler) testAPI(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + model, _ := req["model"].(string) + message, _ := req["message"].(string) + apiKey, _ := req["api_key"].(string) + if model == "" { + model = "deepseek-chat" + } + if message == "" { + message = "你好" + } + if apiKey == "" { + keys := h.Store.Snapshot().Keys + if len(keys) == 0 { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "没有可用的 API Key"}) + return + } + apiKey = keys[0] + } + host := r.Host + scheme := "http" + if strings.Contains(strings.ToLower(host), "vercel") || strings.Contains(strings.ToLower(r.Header.Get("X-Forwarded-Proto")), "https") { + scheme = "https" + } + payload := map[string]any{"model": model, "messages": []map[string]any{{"role": "user", "content": message}}, "stream": false} + b, _ := json.Marshal(payload) + request, _ := http.NewRequestWithContext(r.Context(), http.MethodPost, fmt.Sprintf("%s://%s/v1/chat/completions", scheme, host), bytes.NewReader(b)) + request.Header.Set("Authorization", "Bearer "+apiKey) + request.Header.Set("Content-Type", "application/json") + resp, err := (&http.Client{Timeout: 60 * time.Second}).Do(request) + if err != nil { + writeJSON(w, http.StatusOK, map[string]any{"success": false, "error": err.Error()}) + return + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode == http.StatusOK { + var parsed any + _ = json.Unmarshal(body, &parsed) + writeJSON(w, http.StatusOK, map[string]any{"success": true, "status_code": resp.StatusCode, "response": parsed}) + return + } + writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": string(body)}) +} + +func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + vercelToken, _ := req["vercel_token"].(string) + projectID, _ := req["project_id"].(string) + teamID, _ := req["team_id"].(string) + autoValidate := true + if v, ok := req["auto_validate"].(bool); ok { + autoValidate = v + } + saveCreds := true + if v, ok := req["save_credentials"].(bool); ok { + saveCreds = v + } + usePreconfig := vercelToken == "__USE_PRECONFIG__" || strings.TrimSpace(vercelToken) == "" + if usePreconfig { + vercelToken = strings.TrimSpace(os.Getenv("VERCEL_TOKEN")) + } + if strings.TrimSpace(projectID) == "" { + projectID = strings.TrimSpace(os.Getenv("VERCEL_PROJECT_ID")) + } + if strings.TrimSpace(teamID) == "" { + teamID = strings.TrimSpace(os.Getenv("VERCEL_TEAM_ID")) + } + if vercelToken == "" || projectID == "" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 Vercel Token 和 Project ID"}) + return + } + validated, failed := 0, []string{} + if autoValidate { + for _, acc := range h.Store.Snapshot().Accounts { + if strings.TrimSpace(acc.Token) != "" { + continue + } + token, err := h.DS.Login(r.Context(), acc) + if err != nil { + failed = append(failed, acc.Identifier()) + } else { + validated++ + _ = h.Store.UpdateAccountToken(acc.Identifier(), token) + } + time.Sleep(500 * time.Millisecond) + } + } + + cfgJSON, _, err := h.Store.ExportJSONAndBase64() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + cfgB64 := base64.StdEncoding.EncodeToString([]byte(cfgJSON)) + client := &http.Client{Timeout: 30 * time.Second} + params := url.Values{} + if teamID != "" { + params.Set("teamId", teamID) + } + headers := map[string]string{"Authorization": "Bearer " + vercelToken} + envResp, status, err := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID+"/env", params, headers, nil) + if err != nil || status != http.StatusOK { + writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "获取环境变量失败"}) + return + } + envs, _ := envResp["envs"].([]any) + existingEnvID := findEnvID(envs, "DS2API_CONFIG_JSON") + if existingEnvID != "" { + _, status, err = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+existingEnvID, params, headers, map[string]any{"value": cfgB64}) + } else { + _, status, err = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": "DS2API_CONFIG_JSON", "value": cfgB64, "type": "encrypted", "target": []string{"production", "preview"}}) + } + if err != nil || (status != http.StatusOK && status != http.StatusCreated) { + writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "更新环境变量失败"}) + return + } + savedCreds := []string{} + if saveCreds && !usePreconfig { + creds := [][2]string{{"VERCEL_TOKEN", vercelToken}, {"VERCEL_PROJECT_ID", projectID}} + if teamID != "" { + creds = append(creds, [2]string{"VERCEL_TEAM_ID", teamID}) + } + for _, kv := range creds { + id := findEnvID(envs, kv[0]) + if id != "" { + _, status, _ = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+id, params, headers, map[string]any{"value": kv[1]}) + } else { + _, status, _ = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": kv[0], "value": kv[1], "type": "encrypted", "target": []string{"production", "preview"}}) + } + if status == http.StatusOK || status == http.StatusCreated { + savedCreds = append(savedCreds, kv[0]) + } + } + } + projectResp, status, _ := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID, params, headers, nil) + manual := true + deployURL := "" + if status == http.StatusOK { + if link, ok := projectResp["link"].(map[string]any); ok { + if linkType, _ := link["type"].(string); linkType == "github" { + repoID := intFrom(link["repoId"]) + ref, _ := link["productionBranch"].(string) + if ref == "" { + ref = "main" + } + depResp, depStatus, _ := vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v13/deployments", params, headers, map[string]any{"name": projectID, "project": projectID, "target": "production", "gitSource": map[string]any{"type": "github", "repoId": repoID, "ref": ref}}) + if depStatus == http.StatusOK || depStatus == http.StatusCreated { + deployURL, _ = depResp["url"].(string) + manual = false + } + } + } + } + _ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix()) + result := map[string]any{"success": true, "validated_accounts": validated} + if manual { + result["message"] = "配置已同步到 Vercel,请手动触发重新部署" + result["manual_deploy_required"] = true + } else { + result["message"] = "配置已同步,正在重新部署..." + result["deployment_url"] = deployURL + } + if len(failed) > 0 { + result["failed_accounts"] = failed + } + if len(savedCreds) > 0 { + result["saved_credentials"] = savedCreds + } + writeJSON(w, http.StatusOK, result) +} + +func (h *Handler) vercelStatus(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() + current := h.computeSyncHash() + synced := snap.VercelSyncHash != "" && snap.VercelSyncHash == current + writeJSON(w, http.StatusOK, map[string]any{"synced": synced, "last_sync_time": nilIfZero(snap.VercelSyncTime), "has_synced_before": snap.VercelSyncHash != ""}) +} + +func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) { + jsonStr, b64, err := h.Store.ExportJSONAndBase64() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]any{"json": jsonStr, "base64": b64}) +} + +func (h *Handler) computeSyncHash() string { + snap := h.Store.Snapshot() + syncable := map[string]any{"keys": snap.Keys, "accounts": []map[string]any{}} + accounts := make([]map[string]any, 0, len(snap.Accounts)) + for _, a := range snap.Accounts { + m := map[string]any{} + if a.Email != "" { + m["email"] = a.Email + } + if a.Mobile != "" { + m["mobile"] = a.Mobile + } + if a.Password != "" { + m["password"] = a.Password + } + accounts = append(accounts, m) + } + sort.Slice(accounts, func(i, j int) bool { + ai := fmt.Sprintf("%v%v", accounts[i]["email"], accounts[i]["mobile"]) + aj := fmt.Sprintf("%v%v", accounts[j]["email"], accounts[j]["mobile"]) + return ai < aj + }) + syncable["accounts"] = accounts + b, _ := json.Marshal(syncable) + sum := md5.Sum(b) + return fmt.Sprintf("%x", sum) +} + +func vercelRequest(ctx context.Context, client *http.Client, method, endpoint string, params url.Values, headers map[string]string, body any) (map[string]any, int, error) { + if len(params) > 0 { + endpoint += "?" + params.Encode() + } + var reader io.Reader + if body != nil { + b, _ := json.Marshal(body) + reader = bytes.NewReader(b) + } + req, err := http.NewRequestWithContext(ctx, method, endpoint, reader) + if err != nil { + return nil, 0, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) + if err != nil { + return nil, 0, err + } + defer resp.Body.Close() + b, _ := io.ReadAll(resp.Body) + parsed := map[string]any{} + _ = json.Unmarshal(b, &parsed) + if len(parsed) == 0 { + parsed["raw"] = string(b) + } + return parsed, resp.StatusCode, nil +} + +func findEnvID(envs []any, key string) string { + for _, item := range envs { + m, ok := item.(map[string]any) + if !ok { + continue + } + if k, _ := m["key"].(string); k == key { + id, _ := m["id"].(string) + return id + } + } + return "" +} + +func reverseAccounts(a []config.Account) { + for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 { + a[i], a[j] = a[j], a[i] + } +} + +func intFromQuery(r *http.Request, key string, d int) int { + v := r.URL.Query().Get(key) + if v == "" { + return d + } + n, err := strconv.Atoi(v) + if err != nil { + return d + } + return n +} + +func intFrom(v any) int { + switch n := v.(type) { + case float64: + return int(n) + case int: + return n + case int64: + return int(n) + default: + return 0 + } +} + +func nilIfEmpty(s string) any { + if s == "" { + return nil + } + return s +} + +func nilIfZero(v int64) any { + if v == 0 { + return nil + } + return v +} + +func toStringSlice(v any) ([]string, bool) { + arr, ok := v.([]any) + if !ok { + return nil, false + } + out := make([]string, 0, len(arr)) + for _, item := range arr { + out = append(out, strings.TrimSpace(fmt.Sprintf("%v", item))) + } + return out, true +} + +func toAccount(m map[string]any) config.Account { + return config.Account{Email: strings.TrimSpace(fmt.Sprintf("%v", m["email"])), Mobile: strings.TrimSpace(fmt.Sprintf("%v", m["mobile"])), Password: strings.TrimSpace(fmt.Sprintf("%v", m["password"])), Token: strings.TrimSpace(fmt.Sprintf("%v", m["token"]))} +} + +func statusOr(v int, d int) int { + if v == 0 { + return d + } + return v +} + +func writeJSON(w http.ResponseWriter, status int, payload any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(payload) +} diff --git a/internal/auth/admin.go b/internal/auth/admin.go new file mode 100644 index 0000000..739e79f --- /dev/null +++ b/internal/auth/admin.go @@ -0,0 +1,113 @@ +package auth + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "net/http" + "os" + "strconv" + "strings" + "time" +) + +func AdminKey() string { + if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" { + return v + } + return "your-admin-secret-key" +} + +func jwtSecret() string { + if v := strings.TrimSpace(os.Getenv("DS2API_JWT_SECRET")); v != "" { + return v + } + return AdminKey() +} + +func jwtExpireHours() int { + if v := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + return n + } + } + return 24 +} + +func CreateJWT(expireHours int) (string, error) { + if expireHours <= 0 { + expireHours = jwtExpireHours() + } + header := map[string]any{"alg": "HS256", "typ": "JWT"} + payload := map[string]any{"iat": time.Now().Unix(), "exp": time.Now().Add(time.Duration(expireHours) * time.Hour).Unix(), "role": "admin"} + h, _ := json.Marshal(header) + p, _ := json.Marshal(payload) + headerB64 := rawB64Encode(h) + payloadB64 := rawB64Encode(p) + msg := headerB64 + "." + payloadB64 + sig := signHS256(msg) + return msg + "." + rawB64Encode(sig), nil +} + +func VerifyJWT(token string) (map[string]any, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return nil, errors.New("invalid token format") + } + msg := parts[0] + "." + parts[1] + expected := signHS256(msg) + actual, err := rawB64Decode(parts[2]) + if err != nil { + return nil, errors.New("invalid signature") + } + if !hmac.Equal(expected, actual) { + return nil, errors.New("invalid signature") + } + payloadBytes, err := rawB64Decode(parts[1]) + if err != nil { + return nil, errors.New("invalid payload") + } + var payload map[string]any + if err := json.Unmarshal(payloadBytes, &payload); err != nil { + return nil, errors.New("invalid payload") + } + exp, _ := payload["exp"].(float64) + if int64(exp) < time.Now().Unix() { + return nil, errors.New("token expired") + } + return payload, nil +} + +func VerifyAdminRequest(r *http.Request) error { + authHeader := strings.TrimSpace(r.Header.Get("Authorization")) + if !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + return errors.New("authentication required") + } + token := strings.TrimSpace(authHeader[7:]) + if token == "" { + return errors.New("authentication required") + } + if token == AdminKey() { + return nil + } + if _, err := VerifyJWT(token); err == nil { + return nil + } + return errors.New("invalid credentials") +} + +func signHS256(msg string) []byte { + h := hmac.New(sha256.New, []byte(jwtSecret())) + _, _ = h.Write([]byte(msg)) + return h.Sum(nil) +} + +func rawB64Encode(b []byte) string { + return base64.RawURLEncoding.EncodeToString(b) +} + +func rawB64Decode(s string) ([]byte, error) { + return base64.RawURLEncoding.DecodeString(s) +} diff --git a/internal/auth/admin_test.go b/internal/auth/admin_test.go new file mode 100644 index 0000000..7489074 --- /dev/null +++ b/internal/auth/admin_test.go @@ -0,0 +1,29 @@ +package auth + +import ( + "net/http" + "testing" +) + +func TestJWTCreateVerify(t *testing.T) { + token, err := CreateJWT(1) + if err != nil { + t.Fatalf("create jwt failed: %v", err) + } + payload, err := VerifyJWT(token) + if err != nil { + t.Fatalf("verify jwt failed: %v", err) + } + if payload["role"] != "admin" { + t.Fatalf("unexpected payload: %#v", payload) + } +} + +func TestVerifyAdminRequest(t *testing.T) { + token, _ := CreateJWT(1) + req, _ := http.NewRequest(http.MethodGet, "/admin/config", nil) + req.Header.Set("Authorization", "Bearer "+token) + if err := VerifyAdminRequest(req); err != nil { + t.Fatalf("expected token accepted: %v", err) + } +} diff --git a/internal/auth/request.go b/internal/auth/request.go new file mode 100644 index 0000000..819665e --- /dev/null +++ b/internal/auth/request.go @@ -0,0 +1,150 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "strings" + + "ds2api/internal/account" + "ds2api/internal/config" +) + +type ctxKey string + +const authCtxKey ctxKey = "auth_context" + +var ( + ErrUnauthorized = errors.New("unauthorized: missing Bearer token") + ErrNoAccount = errors.New("no accounts configured or all accounts are busy") +) + +type RequestAuth struct { + UseConfigToken bool + DeepSeekToken string + AccountID string + Account config.Account + TriedAccounts map[string]bool + resolver *Resolver +} + +type LoginFunc func(ctx context.Context, acc config.Account) (string, error) + +type Resolver struct { + Store *config.Store + Pool *account.Pool + Login LoginFunc +} + +func NewResolver(store *config.Store, pool *account.Pool, login LoginFunc) *Resolver { + return &Resolver{Store: store, Pool: pool, Login: login} +} + +func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { + authHeader := req.Header.Get("Authorization") + if !strings.HasPrefix(authHeader, "Bearer ") { + return nil, ErrUnauthorized + } + callerKey := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer ")) + ctx := req.Context() + if !r.Store.HasAPIKey(callerKey) { + return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil + } + target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account")) + acc, ok := r.Pool.Acquire(target, nil) + if !ok { + return nil, ErrNoAccount + } + a := &RequestAuth{ + UseConfigToken: true, + AccountID: acc.Identifier(), + Account: acc, + TriedAccounts: map[string]bool{}, + resolver: r, + } + if acc.Token == "" { + if err := r.loginAndPersist(ctx, a); err != nil { + r.Pool.Release(a.AccountID) + return nil, err + } + } else { + a.DeepSeekToken = acc.Token + } + return a, nil +} + +func WithAuth(ctx context.Context, a *RequestAuth) context.Context { + return context.WithValue(ctx, authCtxKey, a) +} + +func FromContext(ctx context.Context) (*RequestAuth, bool) { + v := ctx.Value(authCtxKey) + a, ok := v.(*RequestAuth) + return a, ok +} + +func (r *Resolver) loginAndPersist(ctx context.Context, a *RequestAuth) error { + token, err := r.Login(ctx, a.Account) + if err != nil { + return err + } + a.Account.Token = token + a.DeepSeekToken = token + return r.Store.UpdateAccountToken(a.AccountID, token) +} + +func (r *Resolver) RefreshToken(ctx context.Context, a *RequestAuth) bool { + if !a.UseConfigToken || a.AccountID == "" { + return false + } + _ = r.Store.UpdateAccountToken(a.AccountID, "") + a.Account.Token = "" + if err := r.loginAndPersist(ctx, a); err != nil { + config.Logger.Error("[refresh_token] failed", "account", a.AccountID, "error", err) + return false + } + return true +} + +func (r *Resolver) MarkTokenInvalid(a *RequestAuth) { + if !a.UseConfigToken || a.AccountID == "" { + return + } + a.Account.Token = "" + a.DeepSeekToken = "" + _ = r.Store.UpdateAccountToken(a.AccountID, "") +} + +func (r *Resolver) SwitchAccount(ctx context.Context, a *RequestAuth) bool { + if !a.UseConfigToken { + return false + } + if a.TriedAccounts == nil { + a.TriedAccounts = map[string]bool{} + } + if a.AccountID != "" { + a.TriedAccounts[a.AccountID] = true + r.Pool.Release(a.AccountID) + } + acc, ok := r.Pool.Acquire("", a.TriedAccounts) + if !ok { + return false + } + a.Account = acc + a.AccountID = acc.Identifier() + if acc.Token == "" { + if err := r.loginAndPersist(ctx, a); err != nil { + return false + } + } else { + a.DeepSeekToken = acc.Token + } + return true +} + +func (r *Resolver) Release(a *RequestAuth) { + if a == nil || !a.UseConfigToken || a.AccountID == "" { + return + } + r.Pool.Release(a.AccountID) +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..af2fc6a --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,360 @@ +package config + +import ( + "encoding/base64" + "encoding/json" + "errors" + "log/slog" + "os" + "path/filepath" + "slices" + "strings" + "sync" +) + +var Logger = newLogger() + +func newLogger() *slog.Logger { + level := new(slog.LevelVar) + switch strings.ToUpper(strings.TrimSpace(os.Getenv("LOG_LEVEL"))) { + case "DEBUG": + level.Set(slog.LevelDebug) + case "WARN": + level.Set(slog.LevelWarn) + case "ERROR": + level.Set(slog.LevelError) + default: + level.Set(slog.LevelInfo) + } + h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level}) + return slog.New(h) +} + +type Account struct { + Email string `json:"email,omitempty"` + Mobile string `json:"mobile,omitempty"` + Password string `json:"password,omitempty"` + Token string `json:"token,omitempty"` +} + +func (a Account) Identifier() string { + if strings.TrimSpace(a.Email) != "" { + return strings.TrimSpace(a.Email) + } + return strings.TrimSpace(a.Mobile) +} + +type Config struct { + Keys []string `json:"keys,omitempty"` + Accounts []Account `json:"accounts,omitempty"` + ClaudeMapping map[string]string `json:"claude_mapping,omitempty"` + ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"` + VercelSyncHash string `json:"_vercel_sync_hash,omitempty"` + VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"` + AdditionalFields map[string]any `json:"-"` +} + +func (c Config) MarshalJSON() ([]byte, error) { + m := map[string]any{} + for k, v := range c.AdditionalFields { + m[k] = v + } + if len(c.Keys) > 0 { + m["keys"] = c.Keys + } + if len(c.Accounts) > 0 { + m["accounts"] = c.Accounts + } + if len(c.ClaudeMapping) > 0 { + m["claude_mapping"] = c.ClaudeMapping + } + if len(c.ClaudeModelMap) > 0 { + m["claude_model_mapping"] = c.ClaudeModelMap + } + if c.VercelSyncHash != "" { + m["_vercel_sync_hash"] = c.VercelSyncHash + } + if c.VercelSyncTime != 0 { + m["_vercel_sync_time"] = c.VercelSyncTime + } + return json.Marshal(m) +} + +func (c *Config) UnmarshalJSON(b []byte) error { + raw := map[string]json.RawMessage{} + if err := json.Unmarshal(b, &raw); err != nil { + return err + } + c.AdditionalFields = map[string]any{} + for k, v := range raw { + switch k { + case "keys": + _ = json.Unmarshal(v, &c.Keys) + case "accounts": + _ = json.Unmarshal(v, &c.Accounts) + case "claude_mapping": + _ = json.Unmarshal(v, &c.ClaudeMapping) + case "claude_model_mapping": + _ = json.Unmarshal(v, &c.ClaudeModelMap) + case "_vercel_sync_hash": + _ = json.Unmarshal(v, &c.VercelSyncHash) + case "_vercel_sync_time": + _ = json.Unmarshal(v, &c.VercelSyncTime) + default: + var anyVal any + if err := json.Unmarshal(v, &anyVal); err == nil { + c.AdditionalFields[k] = anyVal + } + } + } + return nil +} + +func (c Config) Clone() Config { + clone := Config{ + Keys: slices.Clone(c.Keys), + Accounts: slices.Clone(c.Accounts), + ClaudeMapping: cloneStringMap(c.ClaudeMapping), + ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), + VercelSyncHash: c.VercelSyncHash, + VercelSyncTime: c.VercelSyncTime, + AdditionalFields: map[string]any{}, + } + for k, v := range c.AdditionalFields { + clone.AdditionalFields[k] = v + } + return clone +} + +func cloneStringMap(in map[string]string) map[string]string { + if len(in) == 0 { + return nil + } + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +type Store struct { + mu sync.RWMutex + cfg Config + path string + fromEnv bool +} + +func BaseDir() string { + cwd, err := os.Getwd() + if err != nil { + return "." + } + return cwd +} + +func IsVercel() bool { + return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != "" +} + +func ResolvePath(envKey, defaultRel string) string { + raw := strings.TrimSpace(os.Getenv(envKey)) + if raw != "" { + if filepath.IsAbs(raw) { + return raw + } + return filepath.Join(BaseDir(), raw) + } + return filepath.Join(BaseDir(), defaultRel) +} + +func ConfigPath() string { + return ResolvePath("DS2API_CONFIG_PATH", "config.json") +} + +func WASMPath() string { + return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm") +} + +func StaticAdminDir() string { + return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin") +} + +func LoadStore() *Store { + cfg, fromEnv, err := loadConfig() + if err != nil { + Logger.Warn("[config] load failed", "error", err) + } + if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 { + Logger.Warn("[config] empty config loaded") + } + return &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv} +} + +func loadConfig() (Config, bool, error) { + rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON")) + if rawCfg == "" { + rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON")) + } + if rawCfg != "" { + cfg, err := parseConfigString(rawCfg) + return cfg, true, err + } + + content, err := os.ReadFile(ConfigPath()) + if err != nil { + return Config{}, false, err + } + var cfg Config + if err := json.Unmarshal(content, &cfg); err != nil { + return Config{}, false, err + } + return cfg, false, nil +} + +func parseConfigString(raw string) (Config, error) { + var cfg Config + if err := json.Unmarshal([]byte(raw), &cfg); err == nil { + return cfg, nil + } + decoded, err := base64.StdEncoding.DecodeString(raw) + if err != nil { + return Config{}, err + } + if err := json.Unmarshal(decoded, &cfg); err != nil { + return Config{}, err + } + return cfg, nil +} + +func (s *Store) Snapshot() Config { + s.mu.RLock() + defer s.mu.RUnlock() + return s.cfg.Clone() +} + +func (s *Store) HasAPIKey(k string) bool { + s.mu.RLock() + defer s.mu.RUnlock() + for _, key := range s.cfg.Keys { + if key == k { + return true + } + } + return false +} + +func (s *Store) Keys() []string { + s.mu.RLock() + defer s.mu.RUnlock() + return slices.Clone(s.cfg.Keys) +} + +func (s *Store) Accounts() []Account { + s.mu.RLock() + defer s.mu.RUnlock() + return slices.Clone(s.cfg.Accounts) +} + +func (s *Store) FindAccount(identifier string) (Account, bool) { + identifier = strings.TrimSpace(identifier) + s.mu.RLock() + defer s.mu.RUnlock() + for _, acc := range s.cfg.Accounts { + if acc.Identifier() == identifier { + return acc, true + } + } + return Account{}, false +} + +func (s *Store) UpdateAccountToken(identifier, token string) error { + s.mu.Lock() + defer s.mu.Unlock() + for i := range s.cfg.Accounts { + if s.cfg.Accounts[i].Identifier() == identifier { + s.cfg.Accounts[i].Token = token + return s.saveLocked() + } + } + return errors.New("account not found") +} + +func (s *Store) Replace(cfg Config) error { + s.mu.Lock() + defer s.mu.Unlock() + s.cfg = cfg.Clone() + return s.saveLocked() +} + +func (s *Store) Update(mutator func(*Config) error) error { + s.mu.Lock() + defer s.mu.Unlock() + cfg := s.cfg.Clone() + if err := mutator(&cfg); err != nil { + return err + } + s.cfg = cfg + return s.saveLocked() +} + +func (s *Store) Save() error { + s.mu.RLock() + defer s.mu.RUnlock() + if s.fromEnv { + Logger.Info("[save_config] source from env, skip write") + return nil + } + b, err := json.MarshalIndent(s.cfg, "", " ") + if err != nil { + return err + } + return os.WriteFile(s.path, b, 0o644) +} + +func (s *Store) saveLocked() error { + if s.fromEnv { + Logger.Info("[save_config] source from env, skip write") + return nil + } + b, err := json.MarshalIndent(s.cfg, "", " ") + if err != nil { + return err + } + return os.WriteFile(s.path, b, 0o644) +} + +func (s *Store) IsEnvBacked() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.fromEnv +} + +func (s *Store) SetVercelSync(hash string, ts int64) error { + return s.Update(func(c *Config) error { + c.VercelSyncHash = hash + c.VercelSyncTime = ts + return nil + }) +} + +func (s *Store) ExportJSONAndBase64() (string, string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + b, err := json.Marshal(s.cfg) + if err != nil { + return "", "", err + } + return string(b), base64.StdEncoding.EncodeToString(b), nil +} + +func (s *Store) ClaudeMapping() map[string]string { + s.mu.RLock() + defer s.mu.RUnlock() + if len(s.cfg.ClaudeModelMap) > 0 { + return cloneStringMap(s.cfg.ClaudeModelMap) + } + if len(s.cfg.ClaudeMapping) > 0 { + return cloneStringMap(s.cfg.ClaudeMapping) + } + return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-chat"} +} diff --git a/internal/config/models.go b/internal/config/models.go new file mode 100644 index 0000000..c6cbd5f --- /dev/null +++ b/internal/config/models.go @@ -0,0 +1,55 @@ +package config + +type ModelInfo struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` + Permission []any `json:"permission,omitempty"` +} + +var DeepSeekModels = []ModelInfo{ + {ID: "deepseek-chat", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}}, + {ID: "deepseek-reasoner", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}}, + {ID: "deepseek-chat-search", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}}, + {ID: "deepseek-reasoner-search", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}}, +} + +var ClaudeModels = []ModelInfo{ + {ID: "claude-sonnet-4-20250514", Object: "model", Created: 1715635200, OwnedBy: "anthropic"}, + {ID: "claude-sonnet-4-20250514-fast", Object: "model", Created: 1715635200, OwnedBy: "anthropic"}, + {ID: "claude-sonnet-4-20250514-slow", Object: "model", Created: 1715635200, OwnedBy: "anthropic"}, +} + +func GetModelConfig(model string) (thinking bool, search bool, ok bool) { + switch lower(model) { + case "deepseek-chat": + return false, false, true + case "deepseek-reasoner": + return true, false, true + case "deepseek-chat-search": + return false, true, true + case "deepseek-reasoner-search": + return true, true, true + default: + return false, false, false + } +} + +func lower(s string) string { + b := []byte(s) + for i, c := range b { + if c >= 'A' && c <= 'Z' { + b[i] = c + 32 + } + } + return string(b) +} + +func OpenAIModelsResponse() map[string]any { + return map[string]any{"object": "list", "data": DeepSeekModels} +} + +func ClaudeModelsResponse() map[string]any { + return map[string]any{"object": "list", "data": ClaudeModels} +} diff --git a/internal/deepseek/client.go b/internal/deepseek/client.go new file mode 100644 index 0000000..be596df --- /dev/null +++ b/internal/deepseek/client.go @@ -0,0 +1,342 @@ +package deepseek + +import ( + "bufio" + "bytes" + "compress/gzip" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "ds2api/internal/auth" + "ds2api/internal/config" + trans "ds2api/internal/deepseek/transport" + + "github.com/andybalholm/brotli" +) + +type Client struct { + Store *config.Store + Auth *auth.Resolver + regular trans.Doer + stream trans.Doer + fallback *http.Client + fallbackS *http.Client + powSolver *PowSolver + maxRetries int +} + +func NewClient(store *config.Store, resolver *auth.Resolver) *Client { + return &Client{ + Store: store, + Auth: resolver, + regular: trans.New(60 * time.Second), + stream: trans.New(0), + fallback: &http.Client{Timeout: 60 * time.Second}, + fallbackS: &http.Client{Timeout: 0}, + powSolver: NewPowSolver(config.WASMPath()), + maxRetries: 3, + } +} + +func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) { + payload := map[string]any{ + "password": strings.TrimSpace(acc.Password), + "device_id": "deepseek_to_api", + "os": "android", + } + if email := strings.TrimSpace(acc.Email); email != "" { + payload["email"] = email + } else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" { + payload["mobile"] = mobile + payload["area_code"] = nil + } else { + return "", errors.New("missing email/mobile") + } + resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload) + if err != nil { + return "", err + } + code := intFrom(resp["code"]) + if code != 0 { + return "", fmt.Errorf("login failed: %v", resp["msg"]) + } + data, _ := resp["data"].(map[string]any) + if intFrom(data["biz_code"]) != 0 { + return "", fmt.Errorf("login failed: %v", data["biz_msg"]) + } + bizData, _ := data["biz_data"].(map[string]any) + user, _ := bizData["user"].(map[string]any) + token, _ := user["token"].(string) + if strings.TrimSpace(token) == "" { + return "", errors.New("missing login token") + } + return token, nil +} + +func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) { + if maxAttempts <= 0 { + maxAttempts = c.maxRetries + } + attempts := 0 + refreshed := false + for attempts < maxAttempts { + headers := c.authHeaders(a.DeepSeekToken) + resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"}) + if err != nil { + config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID) + attempts++ + continue + } + code := intFrom(resp["code"]) + if status == http.StatusOK && code == 0 { + data, _ := resp["data"].(map[string]any) + bizData, _ := data["biz_data"].(map[string]any) + sessionID, _ := bizData["id"].(string) + if sessionID != "" { + return sessionID, nil + } + } + msg, _ := resp["msg"].(string) + config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID) + if a.UseConfigToken { + if isTokenInvalid(status, code, msg) && !refreshed { + if c.Auth.RefreshToken(ctx, a) { + refreshed = true + continue + } + } + if c.Auth.SwitchAccount(ctx, a) { + refreshed = false + attempts++ + continue + } + } + attempts++ + } + return "", errors.New("create session failed") +} + +func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) { + if maxAttempts <= 0 { + maxAttempts = c.maxRetries + } + attempts := 0 + for attempts < maxAttempts { + headers := c.authHeaders(a.DeepSeekToken) + resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"}) + if err != nil { + config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID) + attempts++ + continue + } + code := intFrom(resp["code"]) + if status == http.StatusOK && code == 0 { + data, _ := resp["data"].(map[string]any) + bizData, _ := data["biz_data"].(map[string]any) + challenge, _ := bizData["challenge"].(map[string]any) + answer, err := c.powSolver.Compute(ctx, challenge) + if err != nil { + attempts++ + continue + } + return BuildPowHeader(challenge, answer) + } + msg, _ := resp["msg"].(string) + config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID) + if a.UseConfigToken { + if isTokenInvalid(status, code, msg) { + if c.Auth.RefreshToken(ctx, a) { + continue + } + } + if c.Auth.SwitchAccount(ctx, a) { + attempts++ + continue + } + } + attempts++ + } + return "", errors.New("get pow failed") +} + +func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) { + if maxAttempts <= 0 { + maxAttempts = c.maxRetries + } + headers := c.authHeaders(a.DeepSeekToken) + headers["x-ds-pow-response"] = powResp + attempts := 0 + for attempts < maxAttempts { + resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload) + if err != nil { + attempts++ + time.Sleep(time.Second) + continue + } + if resp.StatusCode == http.StatusOK { + return resp, nil + } + _ = resp.Body.Close() + attempts++ + time.Sleep(time.Second) + } + return nil, errors.New("completion failed") +} + +func (c *Client) postJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, error) { + body, status, err := c.postJSONWithStatus(ctx, doer, url, headers, payload) + if err != nil { + return nil, err + } + if status == 0 { + return nil, errors.New("request failed") + } + return body, nil +} + +func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, int, error) { + b, err := json.Marshal(payload) + if err != nil { + return nil, 0, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) + if err != nil { + return nil, 0, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + resp, err := doer.Do(req) + if err != nil { + config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err) + req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) + if reqErr != nil { + return nil, 0, err + } + for k, v := range headers { + req2.Header.Set(k, v) + } + resp, err = c.fallback.Do(req2) + if err != nil { + return nil, 0, err + } + } + defer resp.Body.Close() + payloadBytes, err := readResponseBody(resp) + if err != nil { + return nil, resp.StatusCode, err + } + out := map[string]any{} + if len(payloadBytes) > 0 { + if err := json.Unmarshal(payloadBytes, &out); err != nil { + config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes)) + } + } + return out, resp.StatusCode, nil +} + +func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) { + b, err := json.Marshal(payload) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) + if err != nil { + return nil, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + resp, err := c.stream.Do(req) + if err != nil { + config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err) + req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) + if reqErr != nil { + return nil, err + } + for k, v := range headers { + req2.Header.Set(k, v) + } + return c.fallbackS.Do(req2) + } + return resp, nil +} + +func (c *Client) authHeaders(token string) map[string]string { + headers := make(map[string]string, len(BaseHeaders)+1) + for k, v := range BaseHeaders { + headers[k] = v + } + headers["authorization"] = "Bearer " + token + return headers +} + +func isTokenInvalid(status int, code int, msg string) bool { + msg = strings.ToLower(msg) + if status == http.StatusUnauthorized || status == http.StatusForbidden { + return true + } + if code == 40001 || code == 40002 || code == 40003 { + return true + } + return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized") +} + +func intFrom(v any) int { + switch n := v.(type) { + case float64: + return int(n) + case int: + return n + case int64: + return int(n) + default: + return 0 + } +} + +func readResponseBody(resp *http.Response) ([]byte, error) { + encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding"))) + var reader io.Reader = resp.Body + switch encoding { + case "gzip": + gz, err := gzip.NewReader(resp.Body) + if err != nil { + return nil, err + } + defer gz.Close() + reader = gz + case "br": + reader = brotli.NewReader(resp.Body) + } + return io.ReadAll(reader) +} + +func preview(b []byte) string { + s := strings.TrimSpace(string(b)) + if len(s) > 160 { + return s[:160] + } + return s +} + +func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error { + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 2*1024*1024) + for scanner.Scan() { + if !onLine(scanner.Bytes()) { + break + } + } + if err := scanner.Err(); err != nil { + return err + } + return nil +} diff --git a/internal/deepseek/constants.go b/internal/deepseek/constants.go new file mode 100644 index 0000000..1e7d25f --- /dev/null +++ b/internal/deepseek/constants.go @@ -0,0 +1,26 @@ +package deepseek + +const ( + DeepSeekHost = "chat.deepseek.com" + DeepSeekLoginURL = "https://chat.deepseek.com/api/v0/users/login" + DeepSeekCreateSessionURL = "https://chat.deepseek.com/api/v0/chat_session/create" + DeepSeekCreatePowURL = "https://chat.deepseek.com/api/v0/chat/create_pow_challenge" + DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion" +) + +var BaseHeaders = map[string]string{ + "Host": "chat.deepseek.com", + "User-Agent": "DeepSeek/1.6.11 Android/35", + "Accept": "application/json", + "Content-Type": "application/json", + "x-client-platform": "android", + "x-client-version": "1.6.11", + "x-client-locale": "zh_CN", + "accept-charset": "UTF-8", +} + +const ( + KeepAliveTimeout = 5 + StreamIdleTimeout = 30 + MaxKeepaliveCount = 10 +) diff --git a/internal/deepseek/pow.go b/internal/deepseek/pow.go new file mode 100644 index 0000000..f2d1982 --- /dev/null +++ b/internal/deepseek/pow.go @@ -0,0 +1,189 @@ +package deepseek + +import ( + "context" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "math" + "os" + "sync" + + "ds2api/internal/config" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" +) + +type PowSolver struct { + wasmPath string + once sync.Once + err error + + runtime wazero.Runtime + compiled wazero.CompiledModule +} + +func NewPowSolver(wasmPath string) *PowSolver { + return &PowSolver{wasmPath: wasmPath} +} + +func (p *PowSolver) init(ctx context.Context) error { + p.once.Do(func() { + wasmBytes, err := os.ReadFile(p.wasmPath) + if err != nil { + p.err = err + return + } + p.runtime = wazero.NewRuntime(ctx) + p.compiled, p.err = p.runtime.CompileModule(ctx, wasmBytes) + }) + return p.err +} + +func (p *PowSolver) Compute(ctx context.Context, challenge map[string]any) (int64, error) { + if err := p.init(ctx); err != nil { + return 0, err + } + algo, _ := challenge["algorithm"].(string) + if algo != "DeepSeekHashV1" { + return 0, errors.New("unsupported algorithm") + } + challengeStr, _ := challenge["challenge"].(string) + salt, _ := challenge["salt"].(string) + signature, _ := challenge["signature"].(string) + targetPath, _ := challenge["target_path"].(string) + _ = signature + _ = targetPath + + difficulty := toFloat64(challenge["difficulty"], 144000) + expireAt := toInt64(challenge["expire_at"], 1680000000) + prefix := salt + "_" + itoa(expireAt) + "_" + + mod, err := p.runtime.InstantiateModule(ctx, p.compiled, wazero.NewModuleConfig()) + if err != nil { + return 0, err + } + defer mod.Close(ctx) + + mem := mod.Memory() + if mem == nil { + return 0, errors.New("wasm memory missing") + } + stackFn := mod.ExportedFunction("__wbindgen_add_to_stack_pointer") + allocFn := mod.ExportedFunction("__wbindgen_export_0") + solveFn := mod.ExportedFunction("wasm_solve") + if stackFn == nil || allocFn == nil || solveFn == nil { + return 0, errors.New("required wasm exports missing") + } + + retPtrs, err := stackFn.Call(ctx, uint64(uint32(^uint32(15)))) // -16 i32 + if err != nil || len(retPtrs) == 0 { + return 0, errors.New("stack alloc failed") + } + retptr := uint32(retPtrs[0]) + defer stackFn.Call(ctx, 16) + + chPtr, chLen, err := writeUTF8(ctx, allocFn, mem, challengeStr) + if err != nil { + return 0, err + } + prefixPtr, prefixLen, err := writeUTF8(ctx, allocFn, mem, prefix) + if err != nil { + return 0, err + } + + if _, err := solveFn.Call(ctx, + uint64(retptr), + uint64(chPtr), uint64(chLen), + uint64(prefixPtr), uint64(prefixLen), + math.Float64bits(difficulty), + ); err != nil { + return 0, err + } + + statusBytes, ok := mem.Read(retptr, 4) + if !ok { + return 0, errors.New("read status failed") + } + status := int32(binary.LittleEndian.Uint32(statusBytes)) + valueBytes, ok := mem.Read(retptr+8, 8) + if !ok { + return 0, errors.New("read value failed") + } + value := math.Float64frombits(binary.LittleEndian.Uint64(valueBytes)) + if status == 0 { + return 0, errors.New("pow solve failed") + } + return int64(value), nil +} + +func writeUTF8(ctx context.Context, allocFn api.Function, mem api.Memory, text string) (uint32, uint32, error) { + data := []byte(text) + res, err := allocFn.Call(ctx, uint64(len(data)), 1) + if err != nil || len(res) == 0 { + return 0, 0, errors.New("alloc failed") + } + ptr := uint32(res[0]) + if !mem.Write(ptr, data) { + return 0, 0, errors.New("mem write failed") + } + return ptr, uint32(len(data)), nil +} + +func BuildPowHeader(challenge map[string]any, answer int64) (string, error) { + payload := map[string]any{ + "algorithm": challenge["algorithm"], + "challenge": challenge["challenge"], + "salt": challenge["salt"], + "answer": answer, + "signature": challenge["signature"], + "target_path": challenge["target_path"], + } + b, err := json.Marshal(payload) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(b), nil +} + +func toFloat64(v any, d float64) float64 { + switch n := v.(type) { + case float64: + return n + case int: + return float64(n) + case int64: + return float64(n) + default: + return d + } +} + +func toInt64(v any, d int64) int64 { + switch n := v.(type) { + case float64: + return int64(n) + case int: + return int64(n) + case int64: + return n + default: + return d + } +} + +func itoa(n int64) string { + b, _ := json.Marshal(n) + return string(b) +} + +func PreloadWASM(wasmPath string) { + solver := NewPowSolver(wasmPath) + if err := solver.init(context.Background()); err != nil { + config.Logger.Warn("[WASM] preload failed", "error", err) + return + } + config.Logger.Info("[WASM] module preloaded", "path", wasmPath) +} diff --git a/internal/deepseek/transport/transport.go b/internal/deepseek/transport/transport.go new file mode 100644 index 0000000..8dbae33 --- /dev/null +++ b/internal/deepseek/transport/transport.go @@ -0,0 +1,59 @@ +package transport + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "time" + + utls "github.com/refraction-networking/utls" +) + +type Doer interface { + Do(req *http.Request) (*http.Response, error) +} + +type Client struct { + http *http.Client +} + +func New(timeout time.Duration) *Client { + base := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + ForceAttemptHTTP2: false, + MaxIdleConns: 200, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: 90 * time.Second, + DialContext: (&net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}).DialContext, + DialTLSContext: safariTLSDialer(), + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + } + return &Client{http: &http.Client{Timeout: timeout, Transport: base}} +} + +func (c *Client) Do(req *http.Request) (*http.Response, error) { + return c.http.Do(req) +} + +func safariTLSDialer() func(ctx context.Context, network, addr string) (net.Conn, error) { + var dialer net.Dialer + return func(ctx context.Context, network, addr string) (net.Conn, error) { + plainConn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + host, _, _ := net.SplitHostPort(addr) + uCfg := &utls.Config{ + ServerName: host, + NextProtos: []string{"http/1.1"}, + } + uConn := utls.UClient(plainConn, uCfg, utls.HelloSafari_Auto) + err = uConn.HandshakeContext(ctx) + if err != nil { + _ = plainConn.Close() + return nil, err + } + return uConn, nil + } +} diff --git a/internal/server/router.go b/internal/server/router.go new file mode 100644 index 0000000..3b57392 --- /dev/null +++ b/internal/server/router.go @@ -0,0 +1,105 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + + "ds2api/internal/account" + "ds2api/internal/adapter/claude" + "ds2api/internal/adapter/openai" + "ds2api/internal/admin" + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" + "ds2api/internal/webui" +) + +type App struct { + Store *config.Store + Pool *account.Pool + Resolver *auth.Resolver + DS *deepseek.Client + Router http.Handler +} + +func NewApp() *App { + store := config.LoadStore() + pool := account.NewPool(store) + var dsClient *deepseek.Client + resolver := auth.NewResolver(store, pool, func(ctx context.Context, acc config.Account) (string, error) { + return dsClient.Login(ctx, acc) + }) + dsClient = deepseek.NewClient(store, resolver) + deepseek.PreloadWASM(config.WASMPath()) + + openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient} + claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient} + adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient} + webuiHandler := webui.NewHandler() + + r := chi.NewRouter() + r.Use(middleware.RequestID) + r.Use(middleware.RealIP) + r.Use(middleware.Logger) + r.Use(middleware.Recoverer) + r.Use(cors) + r.Use(timeout(0)) + + r.Get("/healthz", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status":"ok"}`)) + }) + r.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status":"ready"}`)) + }) + openai.RegisterRoutes(r, openaiHandler) + claude.RegisterRoutes(r, claudeHandler) + r.Route("/admin", func(ar chi.Router) { + admin.RegisterRoutes(ar, adminHandler) + }) + webui.RegisterRoutes(r, webuiHandler) + r.NotFound(func(w http.ResponseWriter, req *http.Request) { + if strings.HasPrefix(req.URL.Path, "/admin/") && webuiHandler.HandleAdminFallback(w, req) { + return + } + http.NotFound(w, req) + }) + + return &App{Store: store, Pool: pool, Resolver: resolver, DS: dsClient, Router: r} +} + +func timeout(d time.Duration) func(http.Handler) http.Handler { + if d <= 0 { + return func(next http.Handler) http.Handler { return next } + } + return middleware.Timeout(d) +} + +func cors(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + }) +} + +func WriteUnhandledError(w http.ResponseWriter, err error) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"type": "api_error", "message": "Internal Server Error", "detail": err.Error()}}) +} diff --git a/internal/sse/parser.go b/internal/sse/parser.go new file mode 100644 index 0000000..b5fd322 --- /dev/null +++ b/internal/sse/parser.go @@ -0,0 +1,224 @@ +package sse + +import ( + "bytes" + "encoding/json" + "strings" +) + +type ContentPart struct { + Text string + Type string +} + +var skipPatterns = []string{ + "quasi_status", "elapsed_secs", "token_usage", "pending_fragment", "conversation_mode", + "fragments/-1/status", "fragments/-2/status", "fragments/-3/status", +} + +func ParseDeepSeekSSELine(raw []byte) (map[string]any, bool, bool) { + line := strings.TrimSpace(string(raw)) + if line == "" || !strings.HasPrefix(line, "data:") { + return nil, false, false + } + dataStr := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if dataStr == "[DONE]" { + return nil, true, true + } + chunk := map[string]any{} + if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil { + return nil, false, false + } + return chunk, false, true +} + +func shouldSkipPath(path string) bool { + if path == "response/search_status" { + return true + } + for _, p := range skipPatterns { + if strings.Contains(path, p) { + return true + } + } + return false +} + +func ParseSSEChunkForContent(chunk map[string]any, thinkingEnabled bool, currentFragmentType string) ([]ContentPart, bool, string) { + v, ok := chunk["v"] + if !ok { + return nil, false, currentFragmentType + } + path, _ := chunk["p"].(string) + if shouldSkipPath(path) { + return nil, false, currentFragmentType + } + if path == "response/status" { + if s, ok := v.(string); ok && s == "FINISHED" { + return nil, true, currentFragmentType + } + } + newType := currentFragmentType + if path == "response" { + if arr, ok := v.([]any); ok { + for _, it := range arr { + m, ok := it.(map[string]any) + if !ok { + continue + } + if m["p"] == "fragments" && m["o"] == "APPEND" { + if frags, ok := m["v"].([]any); ok { + for _, frag := range frags { + fm, ok := frag.(map[string]any) + if !ok { + continue + } + t, _ := fm["type"].(string) + t = strings.ToUpper(t) + if t == "THINK" || t == "THINKING" { + newType = "thinking" + } else if t == "RESPONSE" { + newType = "text" + } + } + } + } + } + } + } + partType := "text" + switch { + case path == "response/thinking_content": + partType = "thinking" + case path == "response/content": + partType = "text" + case strings.Contains(path, "response/fragments") && strings.Contains(path, "/content"): + partType = newType + case path == "": + if thinkingEnabled { + partType = newType + } + } + parts := make([]ContentPart, 0, 8) + switch val := v.(type) { + case string: + if val == "FINISHED" && (path == "" || path == "status") { + return nil, true, newType + } + if val != "" { + parts = append(parts, ContentPart{Text: val, Type: partType}) + } + case []any: + pp, finished := extractContentRecursive(val, partType) + if finished { + return nil, true, newType + } + parts = append(parts, pp...) + case map[string]any: + resp := val + if wrapped, ok := val["response"].(map[string]any); ok { + resp = wrapped + } + if frags, ok := resp["fragments"].([]any); ok { + for _, item := range frags { + m, ok := item.(map[string]any) + if !ok { + continue + } + t, _ := m["type"].(string) + content, _ := m["content"].(string) + t = strings.ToUpper(t) + if t == "THINK" || t == "THINKING" { + newType = "thinking" + if content != "" { + parts = append(parts, ContentPart{Text: content, Type: "thinking"}) + } + } else if t == "RESPONSE" { + newType = "text" + if content != "" { + parts = append(parts, ContentPart{Text: content, Type: "text"}) + } + } else if content != "" { + parts = append(parts, ContentPart{Text: content, Type: partType}) + } + } + } + } + return parts, false, newType +} + +func extractContentRecursive(items []any, defaultType string) ([]ContentPart, bool) { + parts := make([]ContentPart, 0, len(items)) + for _, it := range items { + m, ok := it.(map[string]any) + if !ok { + continue + } + itemPath, _ := m["p"].(string) + itemV, hasV := m["v"] + if !hasV { + continue + } + if itemPath == "status" { + if s, ok := itemV.(string); ok && s == "FINISHED" { + return nil, true + } + } + if shouldSkipPath(itemPath) { + continue + } + if content, ok := m["content"].(string); ok && content != "" { + typeName, _ := m["type"].(string) + typeName = strings.ToUpper(typeName) + switch typeName { + case "THINK", "THINKING": + parts = append(parts, ContentPart{Text: content, Type: "thinking"}) + case "RESPONSE": + parts = append(parts, ContentPart{Text: content, Type: "text"}) + default: + parts = append(parts, ContentPart{Text: content, Type: defaultType}) + } + continue + } + partType := defaultType + if strings.Contains(itemPath, "thinking") { + partType = "thinking" + } else if strings.Contains(itemPath, "content") || itemPath == "response" || itemPath == "fragments" { + partType = "text" + } + switch v := itemV.(type) { + case string: + if v != "" && v != "FINISHED" { + parts = append(parts, ContentPart{Text: v, Type: partType}) + } + case []any: + for _, inner := range v { + switch x := inner.(type) { + case map[string]any: + ct, _ := x["content"].(string) + if ct == "" { + continue + } + typeName, _ := x["type"].(string) + typeName = strings.ToUpper(typeName) + if typeName == "THINK" || typeName == "THINKING" { + parts = append(parts, ContentPart{Text: ct, Type: "thinking"}) + } else if typeName == "RESPONSE" { + parts = append(parts, ContentPart{Text: ct, Type: "text"}) + } else { + parts = append(parts, ContentPart{Text: ct, Type: partType}) + } + case string: + if x != "" { + parts = append(parts, ContentPart{Text: x, Type: partType}) + } + } + } + } + } + return parts, false +} + +func IsCitation(text string) bool { + return bytes.HasPrefix([]byte(strings.TrimSpace(text)), []byte("[citation:")) +} diff --git a/internal/sse/parser_test.go b/internal/sse/parser_test.go new file mode 100644 index 0000000..63d4c08 --- /dev/null +++ b/internal/sse/parser_test.go @@ -0,0 +1,49 @@ +package sse + +import "testing" + +func TestParseDeepSeekSSELine(t *testing.T) { + chunk, done, ok := ParseDeepSeekSSELine([]byte(`data: {"v":"你好"}`)) + if !ok || done { + t.Fatalf("expected parsed chunk") + } + if chunk["v"] != "你好" { + t.Fatalf("unexpected chunk: %#v", chunk) + } +} + +func TestParseDeepSeekSSELineDone(t *testing.T) { + _, done, ok := ParseDeepSeekSSELine([]byte(`data: [DONE]`)) + if !ok || !done { + t.Fatalf("expected done signal") + } +} + +func TestParseSSEChunkForContentSimple(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{"v": "hello"}, false, "text") + if finished { + t.Fatal("expected unfinished") + } + if len(parts) != 1 || parts[0].Text != "hello" || parts[0].Type != "text" { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestParseSSEChunkForContentThinking(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{"p": "response/thinking_content", "v": "think"}, true, "thinking") + if finished { + t.Fatal("expected unfinished") + } + if len(parts) != 1 || parts[0].Type != "thinking" { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestIsCitation(t *testing.T) { + if !IsCitation("[citation:1] abc") { + t.Fatal("expected citation true") + } + if IsCitation("normal text") { + t.Fatal("expected citation false") + } +} diff --git a/internal/util/messages.go b/internal/util/messages.go new file mode 100644 index 0000000..86052a9 --- /dev/null +++ b/internal/util/messages.go @@ -0,0 +1,127 @@ +package util + +import ( + "regexp" + "strings" + + "ds2api/internal/config" +) + +var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) + +const ClaudeDefaultModel = "claude-sonnet-4-20250514" + +type Message struct { + Role string `json:"role"` + Content any `json:"content"` +} + +func MessagesPrepare(messages []map[string]any) string { + type block struct { + Role string + Text string + } + processed := make([]block, 0, len(messages)) + for _, m := range messages { + role, _ := m["role"].(string) + text := normalizeContent(m["content"]) + processed = append(processed, block{Role: role, Text: text}) + } + if len(processed) == 0 { + return "" + } + merged := make([]block, 0, len(processed)) + for _, msg := range processed { + if len(merged) > 0 && merged[len(merged)-1].Role == msg.Role { + merged[len(merged)-1].Text += "\n\n" + msg.Text + continue + } + merged = append(merged, msg) + } + parts := make([]string, 0, len(merged)) + for i, m := range merged { + switch m.Role { + case "assistant": + parts = append(parts, "<|Assistant|>"+m.Text+"<|end▁of▁sentence|>") + case "user", "system": + if i > 0 { + parts = append(parts, "<|User|>"+m.Text) + } else { + parts = append(parts, m.Text) + } + default: + parts = append(parts, m.Text) + } + } + out := strings.Join(parts, "") + return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`) +} + +func normalizeContent(v any) string { + switch x := v.(type) { + case string: + return x + case []any: + parts := make([]string, 0, len(x)) + for _, item := range x { + m, ok := item.(map[string]any) + if !ok { + continue + } + if m["type"] == "text" { + if txt, ok := m["text"].(string); ok { + parts = append(parts, txt) + } + } + } + return strings.Join(parts, "\n") + default: + return "" + } +} + +func ConvertClaudeToDeepSeek(claudeReq map[string]any, store *config.Store) map[string]any { + messages, _ := claudeReq["messages"].([]any) + model, _ := claudeReq["model"].(string) + if model == "" { + model = ClaudeDefaultModel + } + mapping := store.ClaudeMapping() + dsModel := mapping["fast"] + if dsModel == "" { + dsModel = "deepseek-chat" + } + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "opus") || strings.Contains(modelLower, "reasoner") || strings.Contains(modelLower, "slow") { + if slow := mapping["slow"]; slow != "" { + dsModel = slow + } + } + convertedMessages := make([]any, 0, len(messages)+1) + if system, ok := claudeReq["system"].(string); ok && system != "" { + convertedMessages = append(convertedMessages, map[string]any{"role": "system", "content": system}) + } + convertedMessages = append(convertedMessages, messages...) + + out := map[string]any{"model": dsModel, "messages": convertedMessages} + for _, k := range []string{"temperature", "top_p", "stream"} { + if v, ok := claudeReq[k]; ok { + out[k] = v + } + } + if stopSeq, ok := claudeReq["stop_sequences"]; ok { + out["stop"] = stopSeq + } + return out +} + +func EstimateTokens(text string) int { + if text == "" { + return 0 + } + n := len([]rune(text)) / 4 + if n < 1 { + return 1 + } + return n +} diff --git a/internal/util/messages_test.go b/internal/util/messages_test.go new file mode 100644 index 0000000..b8c1304 --- /dev/null +++ b/internal/util/messages_test.go @@ -0,0 +1,69 @@ +package util + +import ( + "testing" + + "ds2api/internal/config" +) + +func TestMessagesPrepareBasic(t *testing.T) { + messages := []map[string]any{{"role": "user", "content": "Hello"}} + got := MessagesPrepare(messages) + if got == "" { + t.Fatal("expected non-empty prompt") + } + if got != "Hello" { + t.Fatalf("unexpected prompt: %q", got) + } +} + +func TestMessagesPrepareRoles(t *testing.T) { + messages := []map[string]any{ + {"role": "system", "content": "You are helper"}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + {"role": "user", "content": "How are you"}, + } + got := MessagesPrepare(messages) + if !contains(got, "<|Assistant|>") { + t.Fatalf("expected assistant marker in %q", got) + } + if !contains(got, "<|User|>") { + t.Fatalf("expected user marker in %q", got) + } +} + +func TestConvertClaudeToDeepSeek(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-20250514-slow", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + "system": "You are helpful", + "stream": true, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["model"] == "" { + t.Fatal("expected mapped model") + } + msgs, ok := out["messages"].([]any) + if !ok || len(msgs) == 0 { + t.Fatal("expected messages") + } + first, _ := msgs[0].(map[string]any) + if first["role"] != "system" { + t.Fatalf("expected first message system, got %#v", first) + } +} + +func contains(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(sub) == 0 || (len(s) > 0 && (indexOf(s, sub) >= 0))) +} + +func indexOf(s, sub string) int { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return i + } + } + return -1 +} diff --git a/internal/util/toolcalls.go b/internal/util/toolcalls.go new file mode 100644 index 0000000..ec1b33e --- /dev/null +++ b/internal/util/toolcalls.go @@ -0,0 +1,69 @@ +package util + +import ( + "encoding/json" + "regexp" + "strings" + + "github.com/google/uuid" +) + +var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`) + +type ParsedToolCall struct { + Name string `json:"name"` + Input map[string]any `json:"input"` +} + +func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall { + if strings.TrimSpace(text) == "" { + return nil + } + m := toolCallPattern.FindStringSubmatch(text) + if len(m) < 2 { + return nil + } + payload := "{" + `"tool_calls":[` + m[1] + "]}" + var obj struct { + ToolCalls []ParsedToolCall `json:"tool_calls"` + } + if err := json.Unmarshal([]byte(payload), &obj); err != nil { + return nil + } + allowed := map[string]struct{}{} + for _, name := range availableToolNames { + allowed[name] = struct{}{} + } + out := make([]ParsedToolCall, 0, len(obj.ToolCalls)) + for _, tc := range obj.ToolCalls { + if tc.Name == "" { + continue + } + if len(allowed) > 0 { + if _, ok := allowed[tc.Name]; !ok { + continue + } + } + if tc.Input == nil { + tc.Input = map[string]any{} + } + out = append(out, tc) + } + return out +} + +func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any { + out := make([]map[string]any, 0, len(calls)) + for _, c := range calls { + args, _ := json.Marshal(c.Input) + out = append(out, map[string]any{ + "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "type": "function", + "function": map[string]any{ + "name": c.Name, + "arguments": string(args), + }, + }) + } + return out +} diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go new file mode 100644 index 0000000..87f0ae1 --- /dev/null +++ b/internal/util/toolcalls_test.go @@ -0,0 +1,33 @@ +package util + +import "testing" + +func TestParseToolCalls(t *testing.T) { + text := `prefix {"tool_calls":[{"name":"search","input":{"q":"golang"}}]} suffix` + calls := ParseToolCalls(text, []string{"search"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %d", len(calls)) + } + if calls[0].Name != "search" { + t.Fatalf("unexpected tool name: %s", calls[0].Name) + } +} + +func TestParseToolCallsRejectUnknown(t *testing.T) { + text := `{"tool_calls":[{"name":"unknown","input":{}}]}` + calls := ParseToolCalls(text, []string{"search"}) + if len(calls) != 0 { + t.Fatalf("expected 0 calls, got %d", len(calls)) + } +} + +func TestFormatOpenAIToolCalls(t *testing.T) { + formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}}) + if len(formatted) != 1 { + t.Fatalf("expected 1, got %d", len(formatted)) + } + fn, _ := formatted[0]["function"].(map[string]any) + if fn["name"] != "search" { + t.Fatalf("unexpected function name: %#v", fn) + } +} diff --git a/internal/webui/handler.go b/internal/webui/handler.go new file mode 100644 index 0000000..f8120a6 --- /dev/null +++ b/internal/webui/handler.go @@ -0,0 +1,81 @@ +package webui + +import ( + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/config" +) + +const welcomeHTML = ` +
DeepSeek to OpenAI & Claude Compatible API