mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-01 23:15:27 +08:00
feat: Introduce a new Go-based DeepSeek API proxy with adapters for Claude and OpenAI, including SSE parsing and updated build configurations.
This commit is contained in:
31
Dockerfile
31
Dockerfile
@@ -1,33 +1,24 @@
|
||||
# DS2API Docker 镜像
|
||||
# 采用极简、零侵入设计,所有配置通过环境变量传递
|
||||
# 主代码更新时只需重新构建镜像,无需修改 Dockerfile
|
||||
|
||||
FROM node:20 AS webui-builder
|
||||
|
||||
WORKDIR /app/webui
|
||||
|
||||
COPY webui/package.json webui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY webui ./
|
||||
RUN npm run build
|
||||
|
||||
FROM python:3.11-slim
|
||||
|
||||
FROM golang:1.25 AS go-builder
|
||||
WORKDIR /app
|
||||
|
||||
# 安装依赖(利用 Docker 缓存层)
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制整个项目(保留原始目录结构)
|
||||
COPY go.mod go.sum* ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o /out/ds2api ./cmd/ds2api
|
||||
|
||||
# 拷贝 WebUI 构建产物(非 Vercel / Docker 部署可直接使用)
|
||||
FROM debian:bookworm-slim
|
||||
WORKDIR /app
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates wget && rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=go-builder /out/ds2api /usr/local/bin/ds2api
|
||||
COPY --from=go-builder /app/sha3_wasm_bg.7b9ca65ddd.wasm /app/sha3_wasm_bg.7b9ca65ddd.wasm
|
||||
COPY --from=go-builder /app/config.example.json /app/config.example.json
|
||||
COPY --from=webui-builder /app/static/admin /app/static/admin
|
||||
|
||||
# 暴露服务端口
|
||||
EXPOSE 5001
|
||||
|
||||
# 启动命令(依赖项目自身的启动逻辑)
|
||||
CMD ["python", "app.py"]
|
||||
CMD ["/usr/local/bin/ds2api"]
|
||||
|
||||
@@ -70,15 +70,12 @@
|
||||
git clone https://github.com/CJackHwang/ds2api.git
|
||||
cd ds2api
|
||||
|
||||
# 2. 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 3. 配置账号
|
||||
# 2. 准备配置
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json,添加 DeepSeek 账号信息
|
||||
|
||||
# 4. 启动服务
|
||||
python dev.py
|
||||
# 3. 启动服务(Go 版本)
|
||||
go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
服务启动后访问 `http://localhost:5001`
|
||||
|
||||
@@ -68,15 +68,12 @@ Convert DeepSeek Web into an **OpenAI & Claude compatible API**, with multi-acco
|
||||
git clone https://github.com/CJackHwang/ds2api.git
|
||||
cd ds2api
|
||||
|
||||
# 2. Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 3. Configure accounts
|
||||
# 2. Configure accounts
|
||||
cp config.example.json config.json
|
||||
# Edit config.json to add DeepSeek account info
|
||||
|
||||
# 4. Start the service
|
||||
python dev.py
|
||||
# 3. Start the service (Go runtime)
|
||||
go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
Visit `http://localhost:5001` after startup.
|
||||
|
||||
20
api/index.go
Normal file
20
api/index.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"ds2api/internal/server"
|
||||
)
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
app *server.App
|
||||
)
|
||||
|
||||
func Handler(w http.ResponseWriter, r *http.Request) {
|
||||
once.Do(func() {
|
||||
app = server.NewApp()
|
||||
})
|
||||
app.Router.ServeHTTP(w, r)
|
||||
}
|
||||
23
cmd/ds2api/main.go
Normal file
23
cmd/ds2api/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/server"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := server.NewApp()
|
||||
port := strings.TrimSpace(os.Getenv("PORT"))
|
||||
if port == "" {
|
||||
port = "5001"
|
||||
}
|
||||
config.Logger.Info("starting ds2api", "port", port)
|
||||
if err := http.ListenAndServe("0.0.0.0:"+port, app.Router); err != nil {
|
||||
config.Logger.Error("server stopped", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
@@ -9,22 +9,12 @@
|
||||
|
||||
services:
|
||||
ds2api:
|
||||
build: .
|
||||
build:
|
||||
context: .
|
||||
target: go-builder
|
||||
image: ds2api:dev
|
||||
container_name: ds2api-dev
|
||||
command: [
|
||||
"uvicorn",
|
||||
"app:app",
|
||||
"--host",
|
||||
"0.0.0.0",
|
||||
"--port",
|
||||
"5001",
|
||||
"--reload",
|
||||
"--reload-dir",
|
||||
"/app",
|
||||
"--log-level",
|
||||
"debug"
|
||||
]
|
||||
command: ["go", "run", "./cmd/ds2api"]
|
||||
ports:
|
||||
- "${PORT:-5001}:5001"
|
||||
env_file:
|
||||
@@ -34,10 +24,7 @@ services:
|
||||
- LOG_LEVEL=DEBUG
|
||||
volumes:
|
||||
# 源代码挂载(开发时实时生效)
|
||||
- ./app.py:/app/app.py:ro
|
||||
- ./core:/app/core:ro
|
||||
- ./routes:/app/routes:ro
|
||||
- ./static:/app/static:ro
|
||||
- ./:/app
|
||||
# 配置文件挂载(便于本地修改)
|
||||
- ./config.json:/app/config.json
|
||||
restart: "no"
|
||||
|
||||
@@ -1,13 +1,3 @@
|
||||
# DS2API 生产环境配置
|
||||
# 使用说明:
|
||||
# 1. 复制 .env.example 为 .env 并填写配置
|
||||
# 2. docker-compose up -d
|
||||
# 3. 主代码更新后:docker-compose up -d --build
|
||||
#
|
||||
# 设计原则:
|
||||
# - 零侵入:所有项目配置通过 .env 文件传递
|
||||
# - 易维护:主代码更新只需重新构建镜像
|
||||
|
||||
services:
|
||||
ds2api:
|
||||
build: .
|
||||
@@ -18,11 +8,10 @@ services:
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
# 确保容器内使用正确的主机绑定
|
||||
- HOST=0.0.0.0
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:5001/v1/models')"]
|
||||
test: ["CMD", "wget", "-qO-", "http://localhost:5001/healthz"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
17
go.mod
Normal file
17
go.mod
Normal file
@@ -0,0 +1,17 @@
|
||||
module ds2api
|
||||
|
||||
go 1.25
|
||||
|
||||
require (
|
||||
github.com/go-chi/chi/v5 v5.2.3
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/refraction-networking/utls v1.8.1
|
||||
github.com/tetratelabs/wazero v1.9.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.0.6 // indirect
|
||||
github.com/klauspost/compress v1.17.4 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
)
|
||||
16
go.sum
Normal file
16
go.sum
Normal file
@@ -0,0 +1,16 @@
|
||||
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
|
||||
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
|
||||
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
||||
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
|
||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I=
|
||||
github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
127
internal/account/pool.go
Normal file
127
internal/account/pool.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
type Pool struct {
|
||||
store *config.Store
|
||||
mu sync.Mutex
|
||||
queue []string
|
||||
inUse map[string]bool
|
||||
}
|
||||
|
||||
func NewPool(store *config.Store) *Pool {
|
||||
p := &Pool{store: store, inUse: map[string]bool{}}
|
||||
p.Reset()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Pool) Reset() {
|
||||
accounts := p.store.Accounts()
|
||||
sort.SliceStable(accounts, func(i, j int) bool {
|
||||
iHas := accounts[i].Token != ""
|
||||
jHas := accounts[j].Token != ""
|
||||
if iHas == jHas {
|
||||
return i < j
|
||||
}
|
||||
return iHas
|
||||
})
|
||||
ids := make([]string, 0, len(accounts))
|
||||
for _, a := range accounts {
|
||||
id := a.Identifier()
|
||||
if id != "" {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.queue = ids
|
||||
p.inUse = map[string]bool{}
|
||||
config.Logger.Info("[init_account_queue] initialized", "total", len(ids))
|
||||
}
|
||||
|
||||
func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if target != "" {
|
||||
for i, id := range p.queue {
|
||||
if id != target {
|
||||
continue
|
||||
}
|
||||
acc, ok := p.store.FindAccount(id)
|
||||
if !ok {
|
||||
return config.Account{}, false
|
||||
}
|
||||
p.queue = append(p.queue[:i], p.queue[i+1:]...)
|
||||
p.inUse[id] = true
|
||||
return acc, true
|
||||
}
|
||||
return config.Account{}, false
|
||||
}
|
||||
|
||||
for i := 0; i < len(p.queue); i++ {
|
||||
id := p.queue[i]
|
||||
if exclude[id] {
|
||||
continue
|
||||
}
|
||||
acc, ok := p.store.FindAccount(id)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if acc.Token == "" {
|
||||
continue
|
||||
}
|
||||
p.queue = append(p.queue[:i], p.queue[i+1:]...)
|
||||
p.inUse[id] = true
|
||||
return acc, true
|
||||
}
|
||||
|
||||
for i := 0; i < len(p.queue); i++ {
|
||||
id := p.queue[i]
|
||||
if exclude[id] {
|
||||
continue
|
||||
}
|
||||
acc, ok := p.store.FindAccount(id)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
p.queue = append(p.queue[:i], p.queue[i+1:]...)
|
||||
p.inUse[id] = true
|
||||
return acc, true
|
||||
}
|
||||
return config.Account{}, false
|
||||
}
|
||||
|
||||
func (p *Pool) Release(accountID string) {
|
||||
if accountID == "" {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if !p.inUse[accountID] {
|
||||
return
|
||||
}
|
||||
delete(p.inUse, accountID)
|
||||
p.queue = append(p.queue, accountID)
|
||||
}
|
||||
|
||||
func (p *Pool) Status() map[string]any {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
available := append([]string{}, p.queue...)
|
||||
inUse := make([]string, 0, len(p.inUse))
|
||||
for id := range p.inUse {
|
||||
inUse = append(inUse, id)
|
||||
}
|
||||
return map[string]any{
|
||||
"available": len(available),
|
||||
"in_use": len(inUse),
|
||||
"total": len(p.store.Accounts()),
|
||||
"available_accounts": available,
|
||||
"in_use_accounts": inUse,
|
||||
}
|
||||
}
|
||||
403
internal/adapter/claude/handler.go
Normal file
403
internal/adapter/claude/handler.go
Normal file
@@ -0,0 +1,403 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Store *config.Store
|
||||
Auth *auth.Resolver
|
||||
DS *deepseek.Client
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Get("/anthropic/v1/models", h.ListModels)
|
||||
r.Post("/anthropic/v1/messages", h.Messages)
|
||||
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
|
||||
}
|
||||
|
||||
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
|
||||
}
|
||||
|
||||
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeJSON(w, status, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": detail}})
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "invalid json"}})
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if model == "" || len(messagesRaw) == 0 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "Request must include 'model' and 'messages'."}})
|
||||
return
|
||||
}
|
||||
|
||||
normalized := normalizeClaudeMessages(messagesRaw)
|
||||
payload := cloneMap(req)
|
||||
payload["messages"] = normalized
|
||||
toolsRequested, _ := req["tools"].([]any)
|
||||
if len(toolsRequested) > 0 && !hasSystemMessage(normalized) {
|
||||
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...)
|
||||
}
|
||||
|
||||
dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store)
|
||||
dsModel, _ := dsPayload["model"].(string)
|
||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
|
||||
if !ok {
|
||||
thinkingEnabled = false
|
||||
searchEnabled = false
|
||||
}
|
||||
_ = searchEnabled
|
||||
finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "invalid token."}})
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get PoW"}})
|
||||
return
|
||||
}
|
||||
requestPayload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get Claude response."}})
|
||||
return
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
|
||||
return
|
||||
}
|
||||
|
||||
fullText, fullThinking := collectDeepSeek(resp, thinkingEnabled)
|
||||
toolNames := extractClaudeToolNames(toolsRequested)
|
||||
detected := util.ParseToolCalls(fullText, toolNames)
|
||||
if toBool(req["stream"]) {
|
||||
h.writeClaudeStream(w, r, model, normalized, fullText, detected)
|
||||
return
|
||||
}
|
||||
content := make([]map[string]any, 0, 4)
|
||||
if fullThinking != "" {
|
||||
content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking})
|
||||
}
|
||||
stopReason := "end_turn"
|
||||
if len(detected) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for i, tc := range detected {
|
||||
content = append(content, map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i),
|
||||
"name": tc.Name,
|
||||
"input": tc.Input,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
if fullText == "" {
|
||||
fullText = "抱歉,没有生成有效的响应内容。"
|
||||
}
|
||||
content = append(content, map[string]any{"type": "text", "text": fullText})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"id": fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": content,
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{
|
||||
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)),
|
||||
"output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"})
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messages, _ := req["messages"].([]any)
|
||||
if model == "" || len(messages) == 0 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."})
|
||||
return
|
||||
}
|
||||
inputTokens := 0
|
||||
if sys, ok := req["system"].(string); ok {
|
||||
inputTokens += util.EstimateTokens(sys)
|
||||
}
|
||||
for _, item := range messages {
|
||||
msg, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
inputTokens += 2
|
||||
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
|
||||
}
|
||||
if tools, ok := req["tools"].([]any); ok {
|
||||
for _, t := range tools {
|
||||
b, _ := json.Marshal(t)
|
||||
inputTokens += util.EstimateTokens(string(b))
|
||||
}
|
||||
}
|
||||
if inputTokens < 1 {
|
||||
inputTokens = 1
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
|
||||
}
|
||||
|
||||
func collectDeepSeek(resp *http.Response, thinkingEnabled bool) (string, string) {
|
||||
defer resp.Body.Close()
|
||||
text := strings.Builder{}
|
||||
thinking := strings.Builder{}
|
||||
currentType := "text"
|
||||
if thinkingEnabled {
|
||||
currentType = "thinking"
|
||||
}
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 2*1024*1024)
|
||||
for scanner.Scan() {
|
||||
chunk, done, ok := sse.ParseDeepSeekSSELine(scanner.Bytes())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
||||
currentType = newType
|
||||
if finished {
|
||||
break
|
||||
}
|
||||
for _, p := range parts {
|
||||
if p.Type == "thinking" {
|
||||
thinking.WriteString(p.Text)
|
||||
} else {
|
||||
text.WriteString(p.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
return text.String(), thinking.String()
|
||||
}
|
||||
|
||||
func (h *Handler) writeClaudeStream(w http.ResponseWriter, r *http.Request, model string, messages []any, fullText string, detected []util.ParsedToolCall) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "streaming unsupported"}})
|
||||
return
|
||||
}
|
||||
send := func(v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = w.Write([]byte("data: "))
|
||||
_, _ = w.Write(b)
|
||||
_, _ = w.Write([]byte("\n\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano())
|
||||
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", messages))
|
||||
send(map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": messageID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
|
||||
},
|
||||
})
|
||||
outputTokens := 0
|
||||
stopReason := "end_turn"
|
||||
if len(detected) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for i, tc := range detected {
|
||||
send(map[string]any{"type": "content_block_start", "index": i, "content_block": map[string]any{"type": "tool_use", "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), "name": tc.Name, "input": tc.Input}})
|
||||
send(map[string]any{"type": "content_block_stop", "index": i})
|
||||
outputTokens += util.EstimateTokens(fmt.Sprintf("%v", tc.Input))
|
||||
}
|
||||
} else {
|
||||
if fullText != "" {
|
||||
send(map[string]any{"type": "content_block_start", "index": 0, "content_block": map[string]any{"type": "text", "text": ""}})
|
||||
send(map[string]any{"type": "content_block_delta", "index": 0, "delta": map[string]any{"type": "text_delta", "text": fullText}})
|
||||
send(map[string]any{"type": "content_block_stop", "index": 0})
|
||||
outputTokens = util.EstimateTokens(fullText)
|
||||
}
|
||||
}
|
||||
send(map[string]any{"type": "message_delta", "delta": map[string]any{"stop_reason": stopReason, "stop_sequence": nil}, "usage": map[string]any{"output_tokens": outputTokens}})
|
||||
send(map[string]any{"type": "message_stop"})
|
||||
_ = r
|
||||
}
|
||||
|
||||
func normalizeClaudeMessages(messages []any) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
msg, ok := m.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
copied := cloneMap(msg)
|
||||
switch content := msg["content"].(type) {
|
||||
case []any:
|
||||
parts := make([]string, 0, len(content))
|
||||
for _, block := range content {
|
||||
b, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeStr, _ := b["type"].(string)
|
||||
if typeStr == "text" {
|
||||
if t, ok := b["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
}
|
||||
if typeStr == "tool_result" {
|
||||
parts = append(parts, fmt.Sprintf("%v", b["content"]))
|
||||
}
|
||||
}
|
||||
copied["content"] = strings.Join(parts, "\n")
|
||||
}
|
||||
out = append(out, copied)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildClaudeToolPrompt(tools []any) string {
|
||||
parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"}
|
||||
for _, t := range tools {
|
||||
m, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _ := m["name"].(string)
|
||||
desc, _ := m["description"].(string)
|
||||
schema, _ := json.Marshal(m["input_schema"])
|
||||
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
|
||||
}
|
||||
parts = append(parts, "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}")
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
func hasSystemMessage(messages []any) bool {
|
||||
for _, m := range messages {
|
||||
msg, ok := m.(map[string]any)
|
||||
if ok && msg["role"] == "system" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractClaudeToolNames(tools []any) []string {
|
||||
out := make([]string, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
m, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if name, ok := m["name"].(string); ok && name != "" {
|
||||
out = append(out, name)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toMessageMaps(v any) []map[string]any {
|
||||
arr, ok := v.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
out = append(out, m)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func extractMessageContent(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return x
|
||||
case []any:
|
||||
parts := make([]string, 0, len(x))
|
||||
for _, it := range x {
|
||||
parts = append(parts, fmt.Sprintf("%v", it))
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
default:
|
||||
return fmt.Sprintf("%v", x)
|
||||
}
|
||||
}
|
||||
|
||||
func cloneMap(in map[string]any) map[string]any {
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toBool(v any) bool {
|
||||
b, _ := v.(bool)
|
||||
return b
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
413
internal/adapter/openai/handler.go
Normal file
413
internal/adapter/openai/handler.go
Normal file
@@ -0,0 +1,413 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Store *config.Store
|
||||
Auth *auth.Resolver
|
||||
DS *deepseek.Client
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Get("/v1/models", h.ListModels)
|
||||
r.Post("/v1/chat/completions", h.ChatCompletions)
|
||||
}
|
||||
|
||||
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
|
||||
}
|
||||
|
||||
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeJSON(w, status, map[string]any{"error": detail})
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"})
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if model == "" || len(messagesRaw) == 0 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."})
|
||||
return
|
||||
}
|
||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"error": fmt.Sprintf("Model '%s' is not available.", model)})
|
||||
return
|
||||
}
|
||||
|
||||
messages := normalizeMessages(messagesRaw)
|
||||
toolNames := []string{}
|
||||
if tools, ok := req["tools"].([]any); ok && len(tools) > 0 {
|
||||
messages, toolNames = injectToolPrompt(messages, tools)
|
||||
}
|
||||
finalPrompt := util.MessagesPrepare(messages)
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": "invalid token."})
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": "Failed to get PoW (invalid token or unknown error)."})
|
||||
return
|
||||
}
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": "Failed to get completion."})
|
||||
return
|
||||
}
|
||||
if toBool(req["stream"]) {
|
||||
h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
return
|
||||
}
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeJSON(w, resp.StatusCode, map[string]any{"error": string(body)})
|
||||
return
|
||||
}
|
||||
thinking := strings.Builder{}
|
||||
text := strings.Builder{}
|
||||
currentType := "text"
|
||||
if thinkingEnabled {
|
||||
currentType = "thinking"
|
||||
}
|
||||
_ = ctx
|
||||
_ = deepseek.ScanSSELines(resp, func(line []byte) bool {
|
||||
chunk, done, ok := sse.ParseDeepSeekSSELine(line)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
if done {
|
||||
return false
|
||||
}
|
||||
if _, hasErr := chunk["error"]; hasErr {
|
||||
return false
|
||||
}
|
||||
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
||||
currentType = newType
|
||||
if finished {
|
||||
return false
|
||||
}
|
||||
for _, p := range parts {
|
||||
if searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
if p.Type == "thinking" {
|
||||
thinking.WriteString(p.Text)
|
||||
} else {
|
||||
text.WriteString(p.Text)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
finalThinking := thinking.String()
|
||||
finalText := text.String()
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if thinkingEnabled && finalThinking != "" {
|
||||
messageObj["reasoning_content"] = finalThinking
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeJSON(w, resp.StatusCode, map[string]any{"error": string(body)})
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": "streaming unsupported"})
|
||||
return
|
||||
}
|
||||
|
||||
lines := make(chan []byte, 128)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 2*1024*1024)
|
||||
for scanner.Scan() {
|
||||
b := append([]byte{}, scanner.Bytes()...)
|
||||
lines <- b
|
||||
}
|
||||
close(lines)
|
||||
done <- scanner.Err()
|
||||
}()
|
||||
|
||||
created := time.Now().Unix()
|
||||
firstChunkSent := false
|
||||
currentType := "text"
|
||||
if thinkingEnabled {
|
||||
currentType = "thinking"
|
||||
}
|
||||
thinking := strings.Builder{}
|
||||
text := strings.Builder{}
|
||||
lastContent := time.Now()
|
||||
hasContent := false
|
||||
keepaliveTicker := time.NewTicker(time.Duration(deepseek.KeepAliveTimeout) * time.Second)
|
||||
defer keepaliveTicker.Stop()
|
||||
|
||||
sendChunk := func(v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = w.Write([]byte("data: "))
|
||||
_, _ = w.Write(b)
|
||||
_, _ = w.Write([]byte("\n\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
sendDone := func() {
|
||||
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
finalize := func(finishReason string) {
|
||||
finalThinking := thinking.String()
|
||||
finalText := text.String()
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
if len(detected) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"delta": map[string]any{"tool_calls": util.FormatOpenAIToolCalls(detected)}, "index": 0}},
|
||||
})
|
||||
}
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"delta": map[string]any{}, "index": 0, "finish_reason": finishReason}},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
sendDone()
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-keepaliveTicker.C:
|
||||
if hasContent && time.Since(lastContent) > time.Duration(deepseek.StreamIdleTimeout)*time.Second {
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
_, _ = w.Write([]byte(": keep-alive\n\n"))
|
||||
flusher.Flush()
|
||||
case line, ok := <-lines:
|
||||
if !ok {
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
chunk, doneSignal, parsed := sse.ParseDeepSeekSSELine(line)
|
||||
if !parsed {
|
||||
continue
|
||||
}
|
||||
if doneSignal {
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
if _, hasErr := chunk["error"]; hasErr || chunk["code"] == "content_filter" {
|
||||
finalize("content_filter")
|
||||
return
|
||||
}
|
||||
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
||||
currentType = newType
|
||||
if finished {
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
newChoices := make([]map[string]any, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
if searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
hasContent = true
|
||||
lastContent = time.Now()
|
||||
delta := map[string]any{}
|
||||
if !firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
if p.Type == "thinking" {
|
||||
if thinkingEnabled {
|
||||
thinking.WriteString(p.Text)
|
||||
delta["reasoning_content"] = p.Text
|
||||
}
|
||||
} else {
|
||||
text.WriteString(p.Text)
|
||||
delta["content"] = p.Text
|
||||
}
|
||||
if len(delta) > 0 {
|
||||
newChoices = append(newChoices, map[string]any{"delta": delta, "index": 0})
|
||||
}
|
||||
}
|
||||
if len(newChoices) > 0 {
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": newChoices,
|
||||
})
|
||||
}
|
||||
case <-done:
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeMessages(raw []any) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
m, ok := item.(map[string]any)
|
||||
if ok {
|
||||
out = append(out, m)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
|
||||
toolSchemas := make([]string, 0, len(tools))
|
||||
names := make([]string, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
tool, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fn, _ := tool["function"].(map[string]any)
|
||||
if len(fn) == 0 {
|
||||
fn = tool
|
||||
}
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
schema, _ := fn["parameters"].(map[string]any)
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
names = append(names, name)
|
||||
if desc == "" {
|
||||
desc = "No description available"
|
||||
}
|
||||
b, _ := json.Marshal(schema)
|
||||
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
|
||||
}
|
||||
if len(toolSchemas) == 0 {
|
||||
return messages, names
|
||||
}
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT: If calling tools, output ONLY the JSON. The response must start with { and end with }"
|
||||
|
||||
for i := range messages {
|
||||
if messages[i]["role"] == "system" {
|
||||
old, _ := messages[i]["content"].(string)
|
||||
messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
|
||||
return messages, names
|
||||
}
|
||||
}
|
||||
messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
|
||||
return messages, names
|
||||
}
|
||||
|
||||
func toBool(v any) bool {
|
||||
if b, ok := v.(bool); ok {
|
||||
return b
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
890
internal/admin/handler.go
Normal file
890
internal/admin/handler.go
Normal file
@@ -0,0 +1,890 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/account"
|
||||
authn "ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/sse"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Store *config.Store
|
||||
Pool *account.Pool
|
||||
DS *deepseek.Client
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
|
||||
r.Post("/login", h.login)
|
||||
r.Get("/verify", h.verify)
|
||||
r.Group(func(pr chi.Router) {
|
||||
pr.Use(h.requireAdmin)
|
||||
pr.Get("/vercel/config", h.getVercelConfig)
|
||||
pr.Get("/config", h.getConfig)
|
||||
pr.Post("/config", h.updateConfig)
|
||||
pr.Post("/keys", h.addKey)
|
||||
pr.Delete("/keys/{key}", h.deleteKey)
|
||||
pr.Get("/accounts", h.listAccounts)
|
||||
pr.Post("/accounts", h.addAccount)
|
||||
pr.Delete("/accounts/{identifier}", h.deleteAccount)
|
||||
pr.Get("/queue/status", h.queueStatus)
|
||||
pr.Post("/accounts/test", h.testSingleAccount)
|
||||
pr.Post("/accounts/test-all", h.testAllAccounts)
|
||||
pr.Post("/import", h.batchImport)
|
||||
pr.Post("/test", h.testAPI)
|
||||
pr.Post("/vercel/sync", h.syncVercel)
|
||||
pr.Get("/vercel/status", h.vercelStatus)
|
||||
pr.Get("/export", h.exportConfig)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) requireAdmin(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := authn.VerifyAdminRequest(r); err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) login(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
adminKey, _ := req["admin_key"].(string)
|
||||
expireHours := intFrom(req["expire_hours"])
|
||||
if expireHours <= 0 {
|
||||
expireHours = 24
|
||||
}
|
||||
if adminKey != authn.AdminKey() {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "Invalid admin key"})
|
||||
return
|
||||
}
|
||||
token, err := authn.CreateJWT(expireHours)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "token": token, "expires_in": expireHours * 3600})
|
||||
}
|
||||
|
||||
func (h *Handler) verify(w http.ResponseWriter, r *http.Request) {
|
||||
header := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||
if !strings.HasPrefix(strings.ToLower(header), "bearer ") {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "No credentials provided"})
|
||||
return
|
||||
}
|
||||
token := strings.TrimSpace(header[7:])
|
||||
payload, err := authn.VerifyJWT(token)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
exp, _ := payload["exp"].(float64)
|
||||
remaining := int64(exp) - time.Now().Unix()
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"valid": true, "expires_at": int64(exp), "remaining_seconds": remaining})
|
||||
}
|
||||
|
||||
func (h *Handler) getVercelConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"has_token": strings.TrimSpace(os.Getenv("VERCEL_TOKEN")) != "",
|
||||
"project_id": strings.TrimSpace(os.Getenv("VERCEL_PROJECT_ID")),
|
||||
"team_id": nilIfEmpty(strings.TrimSpace(os.Getenv("VERCEL_TEAM_ID"))),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
safe := map[string]any{
|
||||
"keys": snap.Keys,
|
||||
"accounts": []map[string]any{},
|
||||
"claude_mapping": func() map[string]string {
|
||||
if len(snap.ClaudeMapping) > 0 {
|
||||
return snap.ClaudeMapping
|
||||
}
|
||||
return snap.ClaudeModelMap
|
||||
}(),
|
||||
}
|
||||
accounts := make([]map[string]any, 0, len(snap.Accounts))
|
||||
for _, acc := range snap.Accounts {
|
||||
token := strings.TrimSpace(acc.Token)
|
||||
preview := ""
|
||||
if token != "" {
|
||||
if len(token) > 20 {
|
||||
preview = token[:20] + "..."
|
||||
} else {
|
||||
preview = token
|
||||
}
|
||||
}
|
||||
accounts = append(accounts, map[string]any{
|
||||
"email": acc.Email,
|
||||
"mobile": acc.Mobile,
|
||||
"has_password": strings.TrimSpace(acc.Password) != "",
|
||||
"has_token": token != "",
|
||||
"token_preview": preview,
|
||||
})
|
||||
}
|
||||
safe["accounts"] = accounts
|
||||
writeJSON(w, http.StatusOK, safe)
|
||||
}
|
||||
|
||||
func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
old := h.Store.Snapshot()
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
if keys, ok := toStringSlice(req["keys"]); ok {
|
||||
c.Keys = keys
|
||||
}
|
||||
if accountsRaw, ok := req["accounts"].([]any); ok {
|
||||
existing := map[string]config.Account{}
|
||||
for _, a := range old.Accounts {
|
||||
existing[a.Identifier()] = a
|
||||
}
|
||||
accounts := make([]config.Account, 0, len(accountsRaw))
|
||||
for _, item := range accountsRaw {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
acc := toAccount(m)
|
||||
id := acc.Identifier()
|
||||
if prev, ok := existing[id]; ok {
|
||||
if strings.TrimSpace(acc.Password) == "" {
|
||||
acc.Password = prev.Password
|
||||
}
|
||||
if strings.TrimSpace(acc.Token) == "" {
|
||||
acc.Token = prev.Token
|
||||
}
|
||||
}
|
||||
accounts = append(accounts, acc)
|
||||
}
|
||||
c.Accounts = accounts
|
||||
}
|
||||
if m, ok := req["claude_mapping"].(map[string]any); ok {
|
||||
newMap := map[string]string{}
|
||||
for k, v := range m {
|
||||
newMap[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
c.ClaudeMapping = newMap
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "配置已更新"})
|
||||
}
|
||||
|
||||
func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
key, _ := req["key"].(string)
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"})
|
||||
return
|
||||
}
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
for _, k := range c.Keys {
|
||||
if k == key {
|
||||
return fmt.Errorf("Key 已存在")
|
||||
}
|
||||
}
|
||||
c.Keys = append(c.Keys, key)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
|
||||
key := chi.URLParam(r, "key")
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
idx := -1
|
||||
for i, k := range c.Keys {
|
||||
if k == key {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("Key 不存在")
|
||||
}
|
||||
c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
|
||||
}
|
||||
|
||||
func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
page := intFromQuery(r, "page", 1)
|
||||
pageSize := intFromQuery(r, "page_size", 10)
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = 1
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
accounts := h.Store.Snapshot().Accounts
|
||||
total := len(accounts)
|
||||
reverseAccounts(accounts)
|
||||
totalPages := 1
|
||||
if total > 0 {
|
||||
totalPages = (total + pageSize - 1) / pageSize
|
||||
}
|
||||
start := (page - 1) * pageSize
|
||||
if start > total {
|
||||
start = total
|
||||
}
|
||||
end := start + pageSize
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
items := make([]map[string]any, 0, end-start)
|
||||
for _, acc := range accounts[start:end] {
|
||||
token := strings.TrimSpace(acc.Token)
|
||||
preview := ""
|
||||
if token != "" {
|
||||
if len(token) > 20 {
|
||||
preview = token[:20] + "..."
|
||||
} else {
|
||||
preview = token
|
||||
}
|
||||
}
|
||||
items = append(items, map[string]any{"email": acc.Email, "mobile": acc.Mobile, "has_password": acc.Password != "", "has_token": token != "", "token_preview": preview})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
|
||||
}
|
||||
|
||||
func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
acc := toAccount(req)
|
||||
if acc.Identifier() == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
|
||||
return
|
||||
}
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
for _, a := range c.Accounts {
|
||||
if acc.Email != "" && a.Email == acc.Email {
|
||||
return fmt.Errorf("邮箱已存在")
|
||||
}
|
||||
if acc.Mobile != "" && a.Mobile == acc.Mobile {
|
||||
return fmt.Errorf("手机号已存在")
|
||||
}
|
||||
}
|
||||
c.Accounts = append(c.Accounts, acc)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
identifier := chi.URLParam(r, "identifier")
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
idx := -1
|
||||
for i, a := range c.Accounts {
|
||||
if a.Email == identifier || a.Mobile == identifier {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("账号不存在")
|
||||
}
|
||||
c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||
}
|
||||
|
||||
func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, h.Pool.Status())
|
||||
}
|
||||
|
||||
func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
identifier, _ := req["identifier"].(string)
|
||||
if strings.TrimSpace(identifier) == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(email 或 mobile)"})
|
||||
return
|
||||
}
|
||||
acc, ok := h.Store.FindAccount(identifier)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
if model == "" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
message, _ := req["message"].(string)
|
||||
result := h.testAccount(r.Context(), acc, model, message)
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
model, _ := req["model"].(string)
|
||||
if model == "" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
accounts := h.Store.Snapshot().Accounts
|
||||
if len(accounts) == 0 {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"total": 0, "success": 0, "failed": 0, "results": []any{}})
|
||||
return
|
||||
}
|
||||
results := make([]map[string]any, 0, len(accounts))
|
||||
success := 0
|
||||
for _, acc := range accounts {
|
||||
res := h.testAccount(r.Context(), acc, model, "")
|
||||
if ok, _ := res["success"].(bool); ok {
|
||||
success++
|
||||
}
|
||||
results = append(results, res)
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results})
|
||||
}
|
||||
|
||||
func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
|
||||
start := time.Now()
|
||||
result := map[string]any{"account": acc.Identifier(), "success": false, "response_time": 0, "message": "", "model": model}
|
||||
token := strings.TrimSpace(acc.Token)
|
||||
if token == "" {
|
||||
newToken, err := h.DS.Login(ctx, acc)
|
||||
if err != nil {
|
||||
result["message"] = "登录失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
token = newToken
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
}
|
||||
authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token}
|
||||
sessionID, err := h.DS.CreateSession(ctx, authCtx, 1)
|
||||
if err != nil {
|
||||
newToken, loginErr := h.DS.Login(ctx, acc)
|
||||
if loginErr != nil {
|
||||
result["message"] = "创建会话失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
token = newToken
|
||||
authCtx.DeepSeekToken = token
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
sessionID, err = h.DS.CreateSession(ctx, authCtx, 1)
|
||||
if err != nil {
|
||||
result["message"] = "创建会话失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(message) == "" {
|
||||
result["success"] = true
|
||||
result["message"] = "API 测试成功(仅会话创建)"
|
||||
result["response_time"] = int(time.Since(start).Milliseconds())
|
||||
return result
|
||||
}
|
||||
thinking, search, ok := config.GetModelConfig(model)
|
||||
if !ok {
|
||||
thinking, search = false, false
|
||||
}
|
||||
pow, err := h.DS.GetPow(ctx, authCtx, 1)
|
||||
if err != nil {
|
||||
result["message"] = "获取 PoW 失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
payload := map[string]any{"chat_session_id": sessionID, "prompt": "<|User|>" + message, "ref_file_ids": []any{}, "thinking_enabled": thinking, "search_enabled": search}
|
||||
resp, err := h.DS.CallCompletion(ctx, authCtx, payload, pow, 1)
|
||||
if err != nil {
|
||||
result["message"] = "请求失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
result["message"] = fmt.Sprintf("请求失败: HTTP %d", resp.StatusCode)
|
||||
return result
|
||||
}
|
||||
text := strings.Builder{}
|
||||
think := strings.Builder{}
|
||||
currentType := "text"
|
||||
if thinking {
|
||||
currentType = "thinking"
|
||||
}
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 2*1024*1024)
|
||||
for scanner.Scan() {
|
||||
chunk, done, parsed := sse.ParseDeepSeekSSELine(scanner.Bytes())
|
||||
if !parsed {
|
||||
continue
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinking, currentType)
|
||||
currentType = newType
|
||||
if finished {
|
||||
break
|
||||
}
|
||||
for _, p := range parts {
|
||||
if p.Type == "thinking" {
|
||||
think.WriteString(p.Text)
|
||||
} else {
|
||||
text.WriteString(p.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
result["success"] = true
|
||||
result["response_time"] = int(time.Since(start).Milliseconds())
|
||||
if text.Len() > 0 {
|
||||
result["message"] = text.String()
|
||||
} else {
|
||||
result["message"] = "(无回复内容)"
|
||||
}
|
||||
if think.Len() > 0 {
|
||||
result["thinking"] = think.String()
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "无效的 JSON 格式"})
|
||||
return
|
||||
}
|
||||
importedKeys, importedAccounts := 0, 0
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
if keys, ok := req["keys"].([]any); ok {
|
||||
existing := map[string]bool{}
|
||||
for _, k := range c.Keys {
|
||||
existing[k] = true
|
||||
}
|
||||
for _, k := range keys {
|
||||
key := strings.TrimSpace(fmt.Sprintf("%v", k))
|
||||
if key == "" || existing[key] {
|
||||
continue
|
||||
}
|
||||
c.Keys = append(c.Keys, key)
|
||||
existing[key] = true
|
||||
importedKeys++
|
||||
}
|
||||
}
|
||||
if accounts, ok := req["accounts"].([]any); ok {
|
||||
existing := map[string]bool{}
|
||||
for _, a := range c.Accounts {
|
||||
existing[a.Identifier()] = true
|
||||
}
|
||||
for _, item := range accounts {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
acc := toAccount(m)
|
||||
id := acc.Identifier()
|
||||
if id == "" || existing[id] {
|
||||
continue
|
||||
}
|
||||
c.Accounts = append(c.Accounts, acc)
|
||||
existing[id] = true
|
||||
importedAccounts++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts})
|
||||
}
|
||||
|
||||
func (h *Handler) testAPI(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
model, _ := req["model"].(string)
|
||||
message, _ := req["message"].(string)
|
||||
apiKey, _ := req["api_key"].(string)
|
||||
if model == "" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
if message == "" {
|
||||
message = "你好"
|
||||
}
|
||||
if apiKey == "" {
|
||||
keys := h.Store.Snapshot().Keys
|
||||
if len(keys) == 0 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "没有可用的 API Key"})
|
||||
return
|
||||
}
|
||||
apiKey = keys[0]
|
||||
}
|
||||
host := r.Host
|
||||
scheme := "http"
|
||||
if strings.Contains(strings.ToLower(host), "vercel") || strings.Contains(strings.ToLower(r.Header.Get("X-Forwarded-Proto")), "https") {
|
||||
scheme = "https"
|
||||
}
|
||||
payload := map[string]any{"model": model, "messages": []map[string]any{{"role": "user", "content": message}}, "stream": false}
|
||||
b, _ := json.Marshal(payload)
|
||||
request, _ := http.NewRequestWithContext(r.Context(), http.MethodPost, fmt.Sprintf("%s://%s/v1/chat/completions", scheme, host), bytes.NewReader(b))
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
resp, err := (&http.Client{Timeout: 60 * time.Second}).Do(request)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var parsed any
|
||||
_ = json.Unmarshal(body, &parsed)
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "status_code": resp.StatusCode, "response": parsed})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": string(body)})
|
||||
}
|
||||
|
||||
func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
vercelToken, _ := req["vercel_token"].(string)
|
||||
projectID, _ := req["project_id"].(string)
|
||||
teamID, _ := req["team_id"].(string)
|
||||
autoValidate := true
|
||||
if v, ok := req["auto_validate"].(bool); ok {
|
||||
autoValidate = v
|
||||
}
|
||||
saveCreds := true
|
||||
if v, ok := req["save_credentials"].(bool); ok {
|
||||
saveCreds = v
|
||||
}
|
||||
usePreconfig := vercelToken == "__USE_PRECONFIG__" || strings.TrimSpace(vercelToken) == ""
|
||||
if usePreconfig {
|
||||
vercelToken = strings.TrimSpace(os.Getenv("VERCEL_TOKEN"))
|
||||
}
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
projectID = strings.TrimSpace(os.Getenv("VERCEL_PROJECT_ID"))
|
||||
}
|
||||
if strings.TrimSpace(teamID) == "" {
|
||||
teamID = strings.TrimSpace(os.Getenv("VERCEL_TEAM_ID"))
|
||||
}
|
||||
if vercelToken == "" || projectID == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 Vercel Token 和 Project ID"})
|
||||
return
|
||||
}
|
||||
validated, failed := 0, []string{}
|
||||
if autoValidate {
|
||||
for _, acc := range h.Store.Snapshot().Accounts {
|
||||
if strings.TrimSpace(acc.Token) != "" {
|
||||
continue
|
||||
}
|
||||
token, err := h.DS.Login(r.Context(), acc)
|
||||
if err != nil {
|
||||
failed = append(failed, acc.Identifier())
|
||||
} else {
|
||||
validated++
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
cfgJSON, _, err := h.Store.ExportJSONAndBase64()
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
cfgB64 := base64.StdEncoding.EncodeToString([]byte(cfgJSON))
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
params := url.Values{}
|
||||
if teamID != "" {
|
||||
params.Set("teamId", teamID)
|
||||
}
|
||||
headers := map[string]string{"Authorization": "Bearer " + vercelToken}
|
||||
envResp, status, err := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID+"/env", params, headers, nil)
|
||||
if err != nil || status != http.StatusOK {
|
||||
writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "获取环境变量失败"})
|
||||
return
|
||||
}
|
||||
envs, _ := envResp["envs"].([]any)
|
||||
existingEnvID := findEnvID(envs, "DS2API_CONFIG_JSON")
|
||||
if existingEnvID != "" {
|
||||
_, status, err = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+existingEnvID, params, headers, map[string]any{"value": cfgB64})
|
||||
} else {
|
||||
_, status, err = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": "DS2API_CONFIG_JSON", "value": cfgB64, "type": "encrypted", "target": []string{"production", "preview"}})
|
||||
}
|
||||
if err != nil || (status != http.StatusOK && status != http.StatusCreated) {
|
||||
writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "更新环境变量失败"})
|
||||
return
|
||||
}
|
||||
savedCreds := []string{}
|
||||
if saveCreds && !usePreconfig {
|
||||
creds := [][2]string{{"VERCEL_TOKEN", vercelToken}, {"VERCEL_PROJECT_ID", projectID}}
|
||||
if teamID != "" {
|
||||
creds = append(creds, [2]string{"VERCEL_TEAM_ID", teamID})
|
||||
}
|
||||
for _, kv := range creds {
|
||||
id := findEnvID(envs, kv[0])
|
||||
if id != "" {
|
||||
_, status, _ = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+id, params, headers, map[string]any{"value": kv[1]})
|
||||
} else {
|
||||
_, status, _ = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": kv[0], "value": kv[1], "type": "encrypted", "target": []string{"production", "preview"}})
|
||||
}
|
||||
if status == http.StatusOK || status == http.StatusCreated {
|
||||
savedCreds = append(savedCreds, kv[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
projectResp, status, _ := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID, params, headers, nil)
|
||||
manual := true
|
||||
deployURL := ""
|
||||
if status == http.StatusOK {
|
||||
if link, ok := projectResp["link"].(map[string]any); ok {
|
||||
if linkType, _ := link["type"].(string); linkType == "github" {
|
||||
repoID := intFrom(link["repoId"])
|
||||
ref, _ := link["productionBranch"].(string)
|
||||
if ref == "" {
|
||||
ref = "main"
|
||||
}
|
||||
depResp, depStatus, _ := vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v13/deployments", params, headers, map[string]any{"name": projectID, "project": projectID, "target": "production", "gitSource": map[string]any{"type": "github", "repoId": repoID, "ref": ref}})
|
||||
if depStatus == http.StatusOK || depStatus == http.StatusCreated {
|
||||
deployURL, _ = depResp["url"].(string)
|
||||
manual = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix())
|
||||
result := map[string]any{"success": true, "validated_accounts": validated}
|
||||
if manual {
|
||||
result["message"] = "配置已同步到 Vercel,请手动触发重新部署"
|
||||
result["manual_deploy_required"] = true
|
||||
} else {
|
||||
result["message"] = "配置已同步,正在重新部署..."
|
||||
result["deployment_url"] = deployURL
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
result["failed_accounts"] = failed
|
||||
}
|
||||
if len(savedCreds) > 0 {
|
||||
result["saved_credentials"] = savedCreds
|
||||
}
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *Handler) vercelStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
current := h.computeSyncHash()
|
||||
synced := snap.VercelSyncHash != "" && snap.VercelSyncHash == current
|
||||
writeJSON(w, http.StatusOK, map[string]any{"synced": synced, "last_sync_time": nilIfZero(snap.VercelSyncTime), "has_synced_before": snap.VercelSyncHash != ""})
|
||||
}
|
||||
|
||||
func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
jsonStr, b64, err := h.Store.ExportJSONAndBase64()
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"json": jsonStr, "base64": b64})
|
||||
}
|
||||
|
||||
func (h *Handler) computeSyncHash() string {
|
||||
snap := h.Store.Snapshot()
|
||||
syncable := map[string]any{"keys": snap.Keys, "accounts": []map[string]any{}}
|
||||
accounts := make([]map[string]any, 0, len(snap.Accounts))
|
||||
for _, a := range snap.Accounts {
|
||||
m := map[string]any{}
|
||||
if a.Email != "" {
|
||||
m["email"] = a.Email
|
||||
}
|
||||
if a.Mobile != "" {
|
||||
m["mobile"] = a.Mobile
|
||||
}
|
||||
if a.Password != "" {
|
||||
m["password"] = a.Password
|
||||
}
|
||||
accounts = append(accounts, m)
|
||||
}
|
||||
sort.Slice(accounts, func(i, j int) bool {
|
||||
ai := fmt.Sprintf("%v%v", accounts[i]["email"], accounts[i]["mobile"])
|
||||
aj := fmt.Sprintf("%v%v", accounts[j]["email"], accounts[j]["mobile"])
|
||||
return ai < aj
|
||||
})
|
||||
syncable["accounts"] = accounts
|
||||
b, _ := json.Marshal(syncable)
|
||||
sum := md5.Sum(b)
|
||||
return fmt.Sprintf("%x", sum)
|
||||
}
|
||||
|
||||
func vercelRequest(ctx context.Context, client *http.Client, method, endpoint string, params url.Values, headers map[string]string, body any) (map[string]any, int, error) {
|
||||
if len(params) > 0 {
|
||||
endpoint += "?" + params.Encode()
|
||||
}
|
||||
var reader io.Reader
|
||||
if body != nil {
|
||||
b, _ := json.Marshal(body)
|
||||
reader = bytes.NewReader(b)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, endpoint, reader)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
parsed := map[string]any{}
|
||||
_ = json.Unmarshal(b, &parsed)
|
||||
if len(parsed) == 0 {
|
||||
parsed["raw"] = string(b)
|
||||
}
|
||||
return parsed, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func findEnvID(envs []any, key string) string {
|
||||
for _, item := range envs {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if k, _ := m["key"].(string); k == key {
|
||||
id, _ := m["id"].(string)
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func reverseAccounts(a []config.Account) {
|
||||
for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 {
|
||||
a[i], a[j] = a[j], a[i]
|
||||
}
|
||||
}
|
||||
|
||||
func intFromQuery(r *http.Request, key string, d int) int {
|
||||
v := r.URL.Query().Get(key)
|
||||
if v == "" {
|
||||
return d
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return d
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func intFrom(v any) int {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int(n)
|
||||
case int:
|
||||
return n
|
||||
case int64:
|
||||
return int(n)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func nilIfEmpty(s string) any {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func nilIfZero(v int64) any {
|
||||
if v == 0 {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func toStringSlice(v any) ([]string, bool) {
|
||||
arr, ok := v.([]any)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
out := make([]string, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
out = append(out, strings.TrimSpace(fmt.Sprintf("%v", item)))
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
func toAccount(m map[string]any) config.Account {
|
||||
return config.Account{Email: strings.TrimSpace(fmt.Sprintf("%v", m["email"])), Mobile: strings.TrimSpace(fmt.Sprintf("%v", m["mobile"])), Password: strings.TrimSpace(fmt.Sprintf("%v", m["password"])), Token: strings.TrimSpace(fmt.Sprintf("%v", m["token"]))}
|
||||
}
|
||||
|
||||
func statusOr(v int, d int) int {
|
||||
if v == 0 {
|
||||
return d
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
113
internal/auth/admin.go
Normal file
113
internal/auth/admin.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func AdminKey() string {
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" {
|
||||
return v
|
||||
}
|
||||
return "your-admin-secret-key"
|
||||
}
|
||||
|
||||
func jwtSecret() string {
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_JWT_SECRET")); v != "" {
|
||||
return v
|
||||
}
|
||||
return AdminKey()
|
||||
}
|
||||
|
||||
func jwtExpireHours() int {
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 24
|
||||
}
|
||||
|
||||
func CreateJWT(expireHours int) (string, error) {
|
||||
if expireHours <= 0 {
|
||||
expireHours = jwtExpireHours()
|
||||
}
|
||||
header := map[string]any{"alg": "HS256", "typ": "JWT"}
|
||||
payload := map[string]any{"iat": time.Now().Unix(), "exp": time.Now().Add(time.Duration(expireHours) * time.Hour).Unix(), "role": "admin"}
|
||||
h, _ := json.Marshal(header)
|
||||
p, _ := json.Marshal(payload)
|
||||
headerB64 := rawB64Encode(h)
|
||||
payloadB64 := rawB64Encode(p)
|
||||
msg := headerB64 + "." + payloadB64
|
||||
sig := signHS256(msg)
|
||||
return msg + "." + rawB64Encode(sig), nil
|
||||
}
|
||||
|
||||
func VerifyJWT(token string) (map[string]any, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, errors.New("invalid token format")
|
||||
}
|
||||
msg := parts[0] + "." + parts[1]
|
||||
expected := signHS256(msg)
|
||||
actual, err := rawB64Decode(parts[2])
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid signature")
|
||||
}
|
||||
if !hmac.Equal(expected, actual) {
|
||||
return nil, errors.New("invalid signature")
|
||||
}
|
||||
payloadBytes, err := rawB64Decode(parts[1])
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid payload")
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
|
||||
return nil, errors.New("invalid payload")
|
||||
}
|
||||
exp, _ := payload["exp"].(float64)
|
||||
if int64(exp) < time.Now().Unix() {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func VerifyAdminRequest(r *http.Request) error {
|
||||
authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||
if !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
return errors.New("authentication required")
|
||||
}
|
||||
token := strings.TrimSpace(authHeader[7:])
|
||||
if token == "" {
|
||||
return errors.New("authentication required")
|
||||
}
|
||||
if token == AdminKey() {
|
||||
return nil
|
||||
}
|
||||
if _, err := VerifyJWT(token); err == nil {
|
||||
return nil
|
||||
}
|
||||
return errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
func signHS256(msg string) []byte {
|
||||
h := hmac.New(sha256.New, []byte(jwtSecret()))
|
||||
_, _ = h.Write([]byte(msg))
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func rawB64Encode(b []byte) string {
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func rawB64Decode(s string) ([]byte, error) {
|
||||
return base64.RawURLEncoding.DecodeString(s)
|
||||
}
|
||||
29
internal/auth/admin_test.go
Normal file
29
internal/auth/admin_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestJWTCreateVerify(t *testing.T) {
|
||||
token, err := CreateJWT(1)
|
||||
if err != nil {
|
||||
t.Fatalf("create jwt failed: %v", err)
|
||||
}
|
||||
payload, err := VerifyJWT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("verify jwt failed: %v", err)
|
||||
}
|
||||
if payload["role"] != "admin" {
|
||||
t.Fatalf("unexpected payload: %#v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAdminRequest(t *testing.T) {
|
||||
token, _ := CreateJWT(1)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/admin/config", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
if err := VerifyAdminRequest(req); err != nil {
|
||||
t.Fatalf("expected token accepted: %v", err)
|
||||
}
|
||||
}
|
||||
150
internal/auth/request.go
Normal file
150
internal/auth/request.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
type ctxKey string
|
||||
|
||||
const authCtxKey ctxKey = "auth_context"
|
||||
|
||||
var (
|
||||
ErrUnauthorized = errors.New("unauthorized: missing Bearer token")
|
||||
ErrNoAccount = errors.New("no accounts configured or all accounts are busy")
|
||||
)
|
||||
|
||||
type RequestAuth struct {
|
||||
UseConfigToken bool
|
||||
DeepSeekToken string
|
||||
AccountID string
|
||||
Account config.Account
|
||||
TriedAccounts map[string]bool
|
||||
resolver *Resolver
|
||||
}
|
||||
|
||||
type LoginFunc func(ctx context.Context, acc config.Account) (string, error)
|
||||
|
||||
type Resolver struct {
|
||||
Store *config.Store
|
||||
Pool *account.Pool
|
||||
Login LoginFunc
|
||||
}
|
||||
|
||||
func NewResolver(store *config.Store, pool *account.Pool, login LoginFunc) *Resolver {
|
||||
return &Resolver{Store: store, Pool: pool, Login: login}
|
||||
}
|
||||
|
||||
func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
callerKey := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
|
||||
ctx := req.Context()
|
||||
if !r.Store.HasAPIKey(callerKey) {
|
||||
return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil
|
||||
}
|
||||
target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account"))
|
||||
acc, ok := r.Pool.Acquire(target, nil)
|
||||
if !ok {
|
||||
return nil, ErrNoAccount
|
||||
}
|
||||
a := &RequestAuth{
|
||||
UseConfigToken: true,
|
||||
AccountID: acc.Identifier(),
|
||||
Account: acc,
|
||||
TriedAccounts: map[string]bool{},
|
||||
resolver: r,
|
||||
}
|
||||
if acc.Token == "" {
|
||||
if err := r.loginAndPersist(ctx, a); err != nil {
|
||||
r.Pool.Release(a.AccountID)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
a.DeepSeekToken = acc.Token
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func WithAuth(ctx context.Context, a *RequestAuth) context.Context {
|
||||
return context.WithValue(ctx, authCtxKey, a)
|
||||
}
|
||||
|
||||
func FromContext(ctx context.Context) (*RequestAuth, bool) {
|
||||
v := ctx.Value(authCtxKey)
|
||||
a, ok := v.(*RequestAuth)
|
||||
return a, ok
|
||||
}
|
||||
|
||||
func (r *Resolver) loginAndPersist(ctx context.Context, a *RequestAuth) error {
|
||||
token, err := r.Login(ctx, a.Account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.Account.Token = token
|
||||
a.DeepSeekToken = token
|
||||
return r.Store.UpdateAccountToken(a.AccountID, token)
|
||||
}
|
||||
|
||||
func (r *Resolver) RefreshToken(ctx context.Context, a *RequestAuth) bool {
|
||||
if !a.UseConfigToken || a.AccountID == "" {
|
||||
return false
|
||||
}
|
||||
_ = r.Store.UpdateAccountToken(a.AccountID, "")
|
||||
a.Account.Token = ""
|
||||
if err := r.loginAndPersist(ctx, a); err != nil {
|
||||
config.Logger.Error("[refresh_token] failed", "account", a.AccountID, "error", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *Resolver) MarkTokenInvalid(a *RequestAuth) {
|
||||
if !a.UseConfigToken || a.AccountID == "" {
|
||||
return
|
||||
}
|
||||
a.Account.Token = ""
|
||||
a.DeepSeekToken = ""
|
||||
_ = r.Store.UpdateAccountToken(a.AccountID, "")
|
||||
}
|
||||
|
||||
func (r *Resolver) SwitchAccount(ctx context.Context, a *RequestAuth) bool {
|
||||
if !a.UseConfigToken {
|
||||
return false
|
||||
}
|
||||
if a.TriedAccounts == nil {
|
||||
a.TriedAccounts = map[string]bool{}
|
||||
}
|
||||
if a.AccountID != "" {
|
||||
a.TriedAccounts[a.AccountID] = true
|
||||
r.Pool.Release(a.AccountID)
|
||||
}
|
||||
acc, ok := r.Pool.Acquire("", a.TriedAccounts)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
a.Account = acc
|
||||
a.AccountID = acc.Identifier()
|
||||
if acc.Token == "" {
|
||||
if err := r.loginAndPersist(ctx, a); err != nil {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
a.DeepSeekToken = acc.Token
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *Resolver) Release(a *RequestAuth) {
|
||||
if a == nil || !a.UseConfigToken || a.AccountID == "" {
|
||||
return
|
||||
}
|
||||
r.Pool.Release(a.AccountID)
|
||||
}
|
||||
360
internal/config/config.go
Normal file
360
internal/config/config.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var Logger = newLogger()
|
||||
|
||||
func newLogger() *slog.Logger {
|
||||
level := new(slog.LevelVar)
|
||||
switch strings.ToUpper(strings.TrimSpace(os.Getenv("LOG_LEVEL"))) {
|
||||
case "DEBUG":
|
||||
level.Set(slog.LevelDebug)
|
||||
case "WARN":
|
||||
level.Set(slog.LevelWarn)
|
||||
case "ERROR":
|
||||
level.Set(slog.LevelError)
|
||||
default:
|
||||
level.Set(slog.LevelInfo)
|
||||
}
|
||||
h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})
|
||||
return slog.New(h)
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
Mobile string `json:"mobile,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
}
|
||||
|
||||
func (a Account) Identifier() string {
|
||||
if strings.TrimSpace(a.Email) != "" {
|
||||
return strings.TrimSpace(a.Email)
|
||||
}
|
||||
return strings.TrimSpace(a.Mobile)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Keys []string `json:"keys,omitempty"`
|
||||
Accounts []Account `json:"accounts,omitempty"`
|
||||
ClaudeMapping map[string]string `json:"claude_mapping,omitempty"`
|
||||
ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"`
|
||||
VercelSyncHash string `json:"_vercel_sync_hash,omitempty"`
|
||||
VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"`
|
||||
AdditionalFields map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
func (c Config) MarshalJSON() ([]byte, error) {
|
||||
m := map[string]any{}
|
||||
for k, v := range c.AdditionalFields {
|
||||
m[k] = v
|
||||
}
|
||||
if len(c.Keys) > 0 {
|
||||
m["keys"] = c.Keys
|
||||
}
|
||||
if len(c.Accounts) > 0 {
|
||||
m["accounts"] = c.Accounts
|
||||
}
|
||||
if len(c.ClaudeMapping) > 0 {
|
||||
m["claude_mapping"] = c.ClaudeMapping
|
||||
}
|
||||
if len(c.ClaudeModelMap) > 0 {
|
||||
m["claude_model_mapping"] = c.ClaudeModelMap
|
||||
}
|
||||
if c.VercelSyncHash != "" {
|
||||
m["_vercel_sync_hash"] = c.VercelSyncHash
|
||||
}
|
||||
if c.VercelSyncTime != 0 {
|
||||
m["_vercel_sync_time"] = c.VercelSyncTime
|
||||
}
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
raw := map[string]json.RawMessage{}
|
||||
if err := json.Unmarshal(b, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
c.AdditionalFields = map[string]any{}
|
||||
for k, v := range raw {
|
||||
switch k {
|
||||
case "keys":
|
||||
_ = json.Unmarshal(v, &c.Keys)
|
||||
case "accounts":
|
||||
_ = json.Unmarshal(v, &c.Accounts)
|
||||
case "claude_mapping":
|
||||
_ = json.Unmarshal(v, &c.ClaudeMapping)
|
||||
case "claude_model_mapping":
|
||||
_ = json.Unmarshal(v, &c.ClaudeModelMap)
|
||||
case "_vercel_sync_hash":
|
||||
_ = json.Unmarshal(v, &c.VercelSyncHash)
|
||||
case "_vercel_sync_time":
|
||||
_ = json.Unmarshal(v, &c.VercelSyncTime)
|
||||
default:
|
||||
var anyVal any
|
||||
if err := json.Unmarshal(v, &anyVal); err == nil {
|
||||
c.AdditionalFields[k] = anyVal
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Config) Clone() Config {
|
||||
clone := Config{
|
||||
Keys: slices.Clone(c.Keys),
|
||||
Accounts: slices.Clone(c.Accounts),
|
||||
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
|
||||
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
|
||||
VercelSyncHash: c.VercelSyncHash,
|
||||
VercelSyncTime: c.VercelSyncTime,
|
||||
AdditionalFields: map[string]any{},
|
||||
}
|
||||
for k, v := range c.AdditionalFields {
|
||||
clone.AdditionalFields[k] = v
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func cloneStringMap(in map[string]string) map[string]string {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]string, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
cfg Config
|
||||
path string
|
||||
fromEnv bool
|
||||
}
|
||||
|
||||
func BaseDir() string {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "."
|
||||
}
|
||||
return cwd
|
||||
}
|
||||
|
||||
func IsVercel() bool {
|
||||
return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != ""
|
||||
}
|
||||
|
||||
func ResolvePath(envKey, defaultRel string) string {
|
||||
raw := strings.TrimSpace(os.Getenv(envKey))
|
||||
if raw != "" {
|
||||
if filepath.IsAbs(raw) {
|
||||
return raw
|
||||
}
|
||||
return filepath.Join(BaseDir(), raw)
|
||||
}
|
||||
return filepath.Join(BaseDir(), defaultRel)
|
||||
}
|
||||
|
||||
func ConfigPath() string {
|
||||
return ResolvePath("DS2API_CONFIG_PATH", "config.json")
|
||||
}
|
||||
|
||||
func WASMPath() string {
|
||||
return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm")
|
||||
}
|
||||
|
||||
func StaticAdminDir() string {
|
||||
return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin")
|
||||
}
|
||||
|
||||
func LoadStore() *Store {
|
||||
cfg, fromEnv, err := loadConfig()
|
||||
if err != nil {
|
||||
Logger.Warn("[config] load failed", "error", err)
|
||||
}
|
||||
if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 {
|
||||
Logger.Warn("[config] empty config loaded")
|
||||
}
|
||||
return &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv}
|
||||
}
|
||||
|
||||
func loadConfig() (Config, bool, error) {
|
||||
rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON"))
|
||||
if rawCfg == "" {
|
||||
rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON"))
|
||||
}
|
||||
if rawCfg != "" {
|
||||
cfg, err := parseConfigString(rawCfg)
|
||||
return cfg, true, err
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(ConfigPath())
|
||||
if err != nil {
|
||||
return Config{}, false, err
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(content, &cfg); err != nil {
|
||||
return Config{}, false, err
|
||||
}
|
||||
return cfg, false, nil
|
||||
}
|
||||
|
||||
func parseConfigString(raw string) (Config, error) {
|
||||
var cfg Config
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err == nil {
|
||||
return cfg, nil
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
if err := json.Unmarshal(decoded, &cfg); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *Store) Snapshot() Config {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.cfg.Clone()
|
||||
}
|
||||
|
||||
func (s *Store) HasAPIKey(k string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, key := range s.cfg.Keys {
|
||||
if key == k {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Store) Keys() []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return slices.Clone(s.cfg.Keys)
|
||||
}
|
||||
|
||||
func (s *Store) Accounts() []Account {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return slices.Clone(s.cfg.Accounts)
|
||||
}
|
||||
|
||||
func (s *Store) FindAccount(identifier string) (Account, bool) {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, acc := range s.cfg.Accounts {
|
||||
if acc.Identifier() == identifier {
|
||||
return acc, true
|
||||
}
|
||||
}
|
||||
return Account{}, false
|
||||
}
|
||||
|
||||
func (s *Store) UpdateAccountToken(identifier, token string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for i := range s.cfg.Accounts {
|
||||
if s.cfg.Accounts[i].Identifier() == identifier {
|
||||
s.cfg.Accounts[i].Token = token
|
||||
return s.saveLocked()
|
||||
}
|
||||
}
|
||||
return errors.New("account not found")
|
||||
}
|
||||
|
||||
func (s *Store) Replace(cfg Config) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cfg = cfg.Clone()
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *Store) Update(mutator func(*Config) error) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cfg := s.cfg.Clone()
|
||||
if err := mutator(&cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
s.cfg = cfg
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *Store) Save() error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.fromEnv {
|
||||
Logger.Info("[save_config] source from env, skip write")
|
||||
return nil
|
||||
}
|
||||
b, err := json.MarshalIndent(s.cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(s.path, b, 0o644)
|
||||
}
|
||||
|
||||
func (s *Store) saveLocked() error {
|
||||
if s.fromEnv {
|
||||
Logger.Info("[save_config] source from env, skip write")
|
||||
return nil
|
||||
}
|
||||
b, err := json.MarshalIndent(s.cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(s.path, b, 0o644)
|
||||
}
|
||||
|
||||
func (s *Store) IsEnvBacked() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.fromEnv
|
||||
}
|
||||
|
||||
func (s *Store) SetVercelSync(hash string, ts int64) error {
|
||||
return s.Update(func(c *Config) error {
|
||||
c.VercelSyncHash = hash
|
||||
c.VercelSyncTime = ts
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Store) ExportJSONAndBase64() (string, string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
b, err := json.Marshal(s.cfg)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return string(b), base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func (s *Store) ClaudeMapping() map[string]string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if len(s.cfg.ClaudeModelMap) > 0 {
|
||||
return cloneStringMap(s.cfg.ClaudeModelMap)
|
||||
}
|
||||
if len(s.cfg.ClaudeMapping) > 0 {
|
||||
return cloneStringMap(s.cfg.ClaudeMapping)
|
||||
}
|
||||
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-chat"}
|
||||
}
|
||||
55
internal/config/models.go
Normal file
55
internal/config/models.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package config
|
||||
|
||||
type ModelInfo struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Permission []any `json:"permission,omitempty"`
|
||||
}
|
||||
|
||||
var DeepSeekModels = []ModelInfo{
|
||||
{ID: "deepseek-chat", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}},
|
||||
{ID: "deepseek-reasoner", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}},
|
||||
{ID: "deepseek-chat-search", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}},
|
||||
{ID: "deepseek-reasoner-search", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}},
|
||||
}
|
||||
|
||||
var ClaudeModels = []ModelInfo{
|
||||
{ID: "claude-sonnet-4-20250514", Object: "model", Created: 1715635200, OwnedBy: "anthropic"},
|
||||
{ID: "claude-sonnet-4-20250514-fast", Object: "model", Created: 1715635200, OwnedBy: "anthropic"},
|
||||
{ID: "claude-sonnet-4-20250514-slow", Object: "model", Created: 1715635200, OwnedBy: "anthropic"},
|
||||
}
|
||||
|
||||
func GetModelConfig(model string) (thinking bool, search bool, ok bool) {
|
||||
switch lower(model) {
|
||||
case "deepseek-chat":
|
||||
return false, false, true
|
||||
case "deepseek-reasoner":
|
||||
return true, false, true
|
||||
case "deepseek-chat-search":
|
||||
return false, true, true
|
||||
case "deepseek-reasoner-search":
|
||||
return true, true, true
|
||||
default:
|
||||
return false, false, false
|
||||
}
|
||||
}
|
||||
|
||||
func lower(s string) string {
|
||||
b := []byte(s)
|
||||
for i, c := range b {
|
||||
if c >= 'A' && c <= 'Z' {
|
||||
b[i] = c + 32
|
||||
}
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func OpenAIModelsResponse() map[string]any {
|
||||
return map[string]any{"object": "list", "data": DeepSeekModels}
|
||||
}
|
||||
|
||||
func ClaudeModelsResponse() map[string]any {
|
||||
return map[string]any{"object": "list", "data": ClaudeModels}
|
||||
}
|
||||
342
internal/deepseek/client.go
Normal file
342
internal/deepseek/client.go
Normal file
@@ -0,0 +1,342 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
trans "ds2api/internal/deepseek/transport"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
Store *config.Store
|
||||
Auth *auth.Resolver
|
||||
regular trans.Doer
|
||||
stream trans.Doer
|
||||
fallback *http.Client
|
||||
fallbackS *http.Client
|
||||
powSolver *PowSolver
|
||||
maxRetries int
|
||||
}
|
||||
|
||||
func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
|
||||
return &Client{
|
||||
Store: store,
|
||||
Auth: resolver,
|
||||
regular: trans.New(60 * time.Second),
|
||||
stream: trans.New(0),
|
||||
fallback: &http.Client{Timeout: 60 * time.Second},
|
||||
fallbackS: &http.Client{Timeout: 0},
|
||||
powSolver: NewPowSolver(config.WASMPath()),
|
||||
maxRetries: 3,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) {
|
||||
payload := map[string]any{
|
||||
"password": strings.TrimSpace(acc.Password),
|
||||
"device_id": "deepseek_to_api",
|
||||
"os": "android",
|
||||
}
|
||||
if email := strings.TrimSpace(acc.Email); email != "" {
|
||||
payload["email"] = email
|
||||
} else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" {
|
||||
payload["mobile"] = mobile
|
||||
payload["area_code"] = nil
|
||||
} else {
|
||||
return "", errors.New("missing email/mobile")
|
||||
}
|
||||
resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code := intFrom(resp["code"])
|
||||
if code != 0 {
|
||||
return "", fmt.Errorf("login failed: %v", resp["msg"])
|
||||
}
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
if intFrom(data["biz_code"]) != 0 {
|
||||
return "", fmt.Errorf("login failed: %v", data["biz_msg"])
|
||||
}
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
user, _ := bizData["user"].(map[string]any)
|
||||
token, _ := user["token"].(string)
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return "", errors.New("missing login token")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
attempts := 0
|
||||
refreshed := false
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"})
|
||||
if err != nil {
|
||||
config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID)
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
code := intFrom(resp["code"])
|
||||
if status == http.StatusOK && code == 0 {
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
sessionID, _ := bizData["id"].(string)
|
||||
if sessionID != "" {
|
||||
return sessionID, nil
|
||||
}
|
||||
}
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, msg) && !refreshed {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
refreshed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c.Auth.SwitchAccount(ctx, a) {
|
||||
refreshed = false
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
}
|
||||
attempts++
|
||||
}
|
||||
return "", errors.New("create session failed")
|
||||
}
|
||||
|
||||
func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
attempts := 0
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
|
||||
if err != nil {
|
||||
config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID)
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
code := intFrom(resp["code"])
|
||||
if status == http.StatusOK && code == 0 {
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
challenge, _ := bizData["challenge"].(map[string]any)
|
||||
answer, err := c.powSolver.Compute(ctx, challenge)
|
||||
if err != nil {
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
return BuildPowHeader(challenge, answer)
|
||||
}
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, msg) {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c.Auth.SwitchAccount(ctx, a) {
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
}
|
||||
attempts++
|
||||
}
|
||||
return "", errors.New("get pow failed")
|
||||
}
|
||||
|
||||
func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
headers["x-ds-pow-response"] = powResp
|
||||
attempts := 0
|
||||
for attempts < maxAttempts {
|
||||
resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload)
|
||||
if err != nil {
|
||||
attempts++
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return resp, nil
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
attempts++
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
return nil, errors.New("completion failed")
|
||||
}
|
||||
|
||||
func (c *Client) postJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, error) {
|
||||
body, status, err := c.postJSONWithStatus(ctx, doer, url, headers, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status == 0 {
|
||||
return nil, errors.New("request failed")
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, int, error) {
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := doer.Do(req)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err)
|
||||
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if reqErr != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
}
|
||||
resp, err = c.fallback.Do(req2)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
payloadBytes, err := readResponseBody(resp)
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
out := map[string]any{}
|
||||
if len(payloadBytes) > 0 {
|
||||
if err := json.Unmarshal(payloadBytes, &out); err != nil {
|
||||
config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes))
|
||||
}
|
||||
}
|
||||
return out, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) {
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := c.stream.Do(req)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err)
|
||||
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if reqErr != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
}
|
||||
return c.fallbackS.Do(req2)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *Client) authHeaders(token string) map[string]string {
|
||||
headers := make(map[string]string, len(BaseHeaders)+1)
|
||||
for k, v := range BaseHeaders {
|
||||
headers[k] = v
|
||||
}
|
||||
headers["authorization"] = "Bearer " + token
|
||||
return headers
|
||||
}
|
||||
|
||||
func isTokenInvalid(status int, code int, msg string) bool {
|
||||
msg = strings.ToLower(msg)
|
||||
if status == http.StatusUnauthorized || status == http.StatusForbidden {
|
||||
return true
|
||||
}
|
||||
if code == 40001 || code == 40002 || code == 40003 {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
|
||||
}
|
||||
|
||||
func intFrom(v any) int {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int(n)
|
||||
case int:
|
||||
return n
|
||||
case int64:
|
||||
return int(n)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func readResponseBody(resp *http.Response) ([]byte, error) {
|
||||
encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding")))
|
||||
var reader io.Reader = resp.Body
|
||||
switch encoding {
|
||||
case "gzip":
|
||||
gz, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer gz.Close()
|
||||
reader = gz
|
||||
case "br":
|
||||
reader = brotli.NewReader(resp.Body)
|
||||
}
|
||||
return io.ReadAll(reader)
|
||||
}
|
||||
|
||||
func preview(b []byte) string {
|
||||
s := strings.TrimSpace(string(b))
|
||||
if len(s) > 160 {
|
||||
return s[:160]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 2*1024*1024)
|
||||
for scanner.Scan() {
|
||||
if !onLine(scanner.Bytes()) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
26
internal/deepseek/constants.go
Normal file
26
internal/deepseek/constants.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package deepseek
|
||||
|
||||
const (
|
||||
DeepSeekHost = "chat.deepseek.com"
|
||||
DeepSeekLoginURL = "https://chat.deepseek.com/api/v0/users/login"
|
||||
DeepSeekCreateSessionURL = "https://chat.deepseek.com/api/v0/chat_session/create"
|
||||
DeepSeekCreatePowURL = "https://chat.deepseek.com/api/v0/chat/create_pow_challenge"
|
||||
DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion"
|
||||
)
|
||||
|
||||
var BaseHeaders = map[string]string{
|
||||
"Host": "chat.deepseek.com",
|
||||
"User-Agent": "DeepSeek/1.6.11 Android/35",
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"x-client-platform": "android",
|
||||
"x-client-version": "1.6.11",
|
||||
"x-client-locale": "zh_CN",
|
||||
"accept-charset": "UTF-8",
|
||||
}
|
||||
|
||||
const (
|
||||
KeepAliveTimeout = 5
|
||||
StreamIdleTimeout = 30
|
||||
MaxKeepaliveCount = 10
|
||||
)
|
||||
189
internal/deepseek/pow.go
Normal file
189
internal/deepseek/pow.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"ds2api/internal/config"
|
||||
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
type PowSolver struct {
|
||||
wasmPath string
|
||||
once sync.Once
|
||||
err error
|
||||
|
||||
runtime wazero.Runtime
|
||||
compiled wazero.CompiledModule
|
||||
}
|
||||
|
||||
func NewPowSolver(wasmPath string) *PowSolver {
|
||||
return &PowSolver{wasmPath: wasmPath}
|
||||
}
|
||||
|
||||
func (p *PowSolver) init(ctx context.Context) error {
|
||||
p.once.Do(func() {
|
||||
wasmBytes, err := os.ReadFile(p.wasmPath)
|
||||
if err != nil {
|
||||
p.err = err
|
||||
return
|
||||
}
|
||||
p.runtime = wazero.NewRuntime(ctx)
|
||||
p.compiled, p.err = p.runtime.CompileModule(ctx, wasmBytes)
|
||||
})
|
||||
return p.err
|
||||
}
|
||||
|
||||
func (p *PowSolver) Compute(ctx context.Context, challenge map[string]any) (int64, error) {
|
||||
if err := p.init(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
algo, _ := challenge["algorithm"].(string)
|
||||
if algo != "DeepSeekHashV1" {
|
||||
return 0, errors.New("unsupported algorithm")
|
||||
}
|
||||
challengeStr, _ := challenge["challenge"].(string)
|
||||
salt, _ := challenge["salt"].(string)
|
||||
signature, _ := challenge["signature"].(string)
|
||||
targetPath, _ := challenge["target_path"].(string)
|
||||
_ = signature
|
||||
_ = targetPath
|
||||
|
||||
difficulty := toFloat64(challenge["difficulty"], 144000)
|
||||
expireAt := toInt64(challenge["expire_at"], 1680000000)
|
||||
prefix := salt + "_" + itoa(expireAt) + "_"
|
||||
|
||||
mod, err := p.runtime.InstantiateModule(ctx, p.compiled, wazero.NewModuleConfig())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer mod.Close(ctx)
|
||||
|
||||
mem := mod.Memory()
|
||||
if mem == nil {
|
||||
return 0, errors.New("wasm memory missing")
|
||||
}
|
||||
stackFn := mod.ExportedFunction("__wbindgen_add_to_stack_pointer")
|
||||
allocFn := mod.ExportedFunction("__wbindgen_export_0")
|
||||
solveFn := mod.ExportedFunction("wasm_solve")
|
||||
if stackFn == nil || allocFn == nil || solveFn == nil {
|
||||
return 0, errors.New("required wasm exports missing")
|
||||
}
|
||||
|
||||
retPtrs, err := stackFn.Call(ctx, uint64(uint32(^uint32(15)))) // -16 i32
|
||||
if err != nil || len(retPtrs) == 0 {
|
||||
return 0, errors.New("stack alloc failed")
|
||||
}
|
||||
retptr := uint32(retPtrs[0])
|
||||
defer stackFn.Call(ctx, 16)
|
||||
|
||||
chPtr, chLen, err := writeUTF8(ctx, allocFn, mem, challengeStr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
prefixPtr, prefixLen, err := writeUTF8(ctx, allocFn, mem, prefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if _, err := solveFn.Call(ctx,
|
||||
uint64(retptr),
|
||||
uint64(chPtr), uint64(chLen),
|
||||
uint64(prefixPtr), uint64(prefixLen),
|
||||
math.Float64bits(difficulty),
|
||||
); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
statusBytes, ok := mem.Read(retptr, 4)
|
||||
if !ok {
|
||||
return 0, errors.New("read status failed")
|
||||
}
|
||||
status := int32(binary.LittleEndian.Uint32(statusBytes))
|
||||
valueBytes, ok := mem.Read(retptr+8, 8)
|
||||
if !ok {
|
||||
return 0, errors.New("read value failed")
|
||||
}
|
||||
value := math.Float64frombits(binary.LittleEndian.Uint64(valueBytes))
|
||||
if status == 0 {
|
||||
return 0, errors.New("pow solve failed")
|
||||
}
|
||||
return int64(value), nil
|
||||
}
|
||||
|
||||
func writeUTF8(ctx context.Context, allocFn api.Function, mem api.Memory, text string) (uint32, uint32, error) {
|
||||
data := []byte(text)
|
||||
res, err := allocFn.Call(ctx, uint64(len(data)), 1)
|
||||
if err != nil || len(res) == 0 {
|
||||
return 0, 0, errors.New("alloc failed")
|
||||
}
|
||||
ptr := uint32(res[0])
|
||||
if !mem.Write(ptr, data) {
|
||||
return 0, 0, errors.New("mem write failed")
|
||||
}
|
||||
return ptr, uint32(len(data)), nil
|
||||
}
|
||||
|
||||
func BuildPowHeader(challenge map[string]any, answer int64) (string, error) {
|
||||
payload := map[string]any{
|
||||
"algorithm": challenge["algorithm"],
|
||||
"challenge": challenge["challenge"],
|
||||
"salt": challenge["salt"],
|
||||
"answer": answer,
|
||||
"signature": challenge["signature"],
|
||||
"target_path": challenge["target_path"],
|
||||
}
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func toFloat64(v any, d float64) float64 {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return n
|
||||
case int:
|
||||
return float64(n)
|
||||
case int64:
|
||||
return float64(n)
|
||||
default:
|
||||
return d
|
||||
}
|
||||
}
|
||||
|
||||
func toInt64(v any, d int64) int64 {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int64(n)
|
||||
case int:
|
||||
return int64(n)
|
||||
case int64:
|
||||
return n
|
||||
default:
|
||||
return d
|
||||
}
|
||||
}
|
||||
|
||||
func itoa(n int64) string {
|
||||
b, _ := json.Marshal(n)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func PreloadWASM(wasmPath string) {
|
||||
solver := NewPowSolver(wasmPath)
|
||||
if err := solver.init(context.Background()); err != nil {
|
||||
config.Logger.Warn("[WASM] preload failed", "error", err)
|
||||
return
|
||||
}
|
||||
config.Logger.Info("[WASM] module preloaded", "path", wasmPath)
|
||||
}
|
||||
59
internal/deepseek/transport/transport.go
Normal file
59
internal/deepseek/transport/transport.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
utls "github.com/refraction-networking/utls"
|
||||
)
|
||||
|
||||
type Doer interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
func New(timeout time.Duration) *Client {
|
||||
base := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
ForceAttemptHTTP2: false,
|
||||
MaxIdleConns: 200,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DialContext: (&net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}).DialContext,
|
||||
DialTLSContext: safariTLSDialer(),
|
||||
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
|
||||
}
|
||||
return &Client{http: &http.Client{Timeout: timeout, Transport: base}}
|
||||
}
|
||||
|
||||
func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
return c.http.Do(req)
|
||||
}
|
||||
|
||||
func safariTLSDialer() func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var dialer net.Dialer
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
plainConn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
host, _, _ := net.SplitHostPort(addr)
|
||||
uCfg := &utls.Config{
|
||||
ServerName: host,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
}
|
||||
uConn := utls.UClient(plainConn, uCfg, utls.HelloSafari_Auto)
|
||||
err = uConn.HandshakeContext(ctx)
|
||||
if err != nil {
|
||||
_ = plainConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return uConn, nil
|
||||
}
|
||||
}
|
||||
105
internal/server/router.go
Normal file
105
internal/server/router.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/adapter/claude"
|
||||
"ds2api/internal/adapter/openai"
|
||||
"ds2api/internal/admin"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/webui"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
Store *config.Store
|
||||
Pool *account.Pool
|
||||
Resolver *auth.Resolver
|
||||
DS *deepseek.Client
|
||||
Router http.Handler
|
||||
}
|
||||
|
||||
func NewApp() *App {
|
||||
store := config.LoadStore()
|
||||
pool := account.NewPool(store)
|
||||
var dsClient *deepseek.Client
|
||||
resolver := auth.NewResolver(store, pool, func(ctx context.Context, acc config.Account) (string, error) {
|
||||
return dsClient.Login(ctx, acc)
|
||||
})
|
||||
dsClient = deepseek.NewClient(store, resolver)
|
||||
deepseek.PreloadWASM(config.WASMPath())
|
||||
|
||||
openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient}
|
||||
claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient}
|
||||
adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient}
|
||||
webuiHandler := webui.NewHandler()
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(cors)
|
||||
r.Use(timeout(0))
|
||||
|
||||
r.Get("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"status":"ok"}`))
|
||||
})
|
||||
r.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"status":"ready"}`))
|
||||
})
|
||||
openai.RegisterRoutes(r, openaiHandler)
|
||||
claude.RegisterRoutes(r, claudeHandler)
|
||||
r.Route("/admin", func(ar chi.Router) {
|
||||
admin.RegisterRoutes(ar, adminHandler)
|
||||
})
|
||||
webui.RegisterRoutes(r, webuiHandler)
|
||||
r.NotFound(func(w http.ResponseWriter, req *http.Request) {
|
||||
if strings.HasPrefix(req.URL.Path, "/admin/") && webuiHandler.HandleAdminFallback(w, req) {
|
||||
return
|
||||
}
|
||||
http.NotFound(w, req)
|
||||
})
|
||||
|
||||
return &App{Store: store, Pool: pool, Resolver: resolver, DS: dsClient, Router: r}
|
||||
}
|
||||
|
||||
func timeout(d time.Duration) func(http.Handler) http.Handler {
|
||||
if d <= 0 {
|
||||
return func(next http.Handler) http.Handler { return next }
|
||||
}
|
||||
return middleware.Timeout(d)
|
||||
}
|
||||
|
||||
func cors(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func WriteUnhandledError(w http.ResponseWriter, err error) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"type": "api_error", "message": "Internal Server Error", "detail": err.Error()}})
|
||||
}
|
||||
224
internal/sse/parser.go
Normal file
224
internal/sse/parser.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package sse
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ContentPart struct {
|
||||
Text string
|
||||
Type string
|
||||
}
|
||||
|
||||
var skipPatterns = []string{
|
||||
"quasi_status", "elapsed_secs", "token_usage", "pending_fragment", "conversation_mode",
|
||||
"fragments/-1/status", "fragments/-2/status", "fragments/-3/status",
|
||||
}
|
||||
|
||||
func ParseDeepSeekSSELine(raw []byte) (map[string]any, bool, bool) {
|
||||
line := strings.TrimSpace(string(raw))
|
||||
if line == "" || !strings.HasPrefix(line, "data:") {
|
||||
return nil, false, false
|
||||
}
|
||||
dataStr := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if dataStr == "[DONE]" {
|
||||
return nil, true, true
|
||||
}
|
||||
chunk := map[string]any{}
|
||||
if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil {
|
||||
return nil, false, false
|
||||
}
|
||||
return chunk, false, true
|
||||
}
|
||||
|
||||
func shouldSkipPath(path string) bool {
|
||||
if path == "response/search_status" {
|
||||
return true
|
||||
}
|
||||
for _, p := range skipPatterns {
|
||||
if strings.Contains(path, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ParseSSEChunkForContent(chunk map[string]any, thinkingEnabled bool, currentFragmentType string) ([]ContentPart, bool, string) {
|
||||
v, ok := chunk["v"]
|
||||
if !ok {
|
||||
return nil, false, currentFragmentType
|
||||
}
|
||||
path, _ := chunk["p"].(string)
|
||||
if shouldSkipPath(path) {
|
||||
return nil, false, currentFragmentType
|
||||
}
|
||||
if path == "response/status" {
|
||||
if s, ok := v.(string); ok && s == "FINISHED" {
|
||||
return nil, true, currentFragmentType
|
||||
}
|
||||
}
|
||||
newType := currentFragmentType
|
||||
if path == "response" {
|
||||
if arr, ok := v.([]any); ok {
|
||||
for _, it := range arr {
|
||||
m, ok := it.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if m["p"] == "fragments" && m["o"] == "APPEND" {
|
||||
if frags, ok := m["v"].([]any); ok {
|
||||
for _, frag := range frags {
|
||||
fm, ok := frag.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
t, _ := fm["type"].(string)
|
||||
t = strings.ToUpper(t)
|
||||
if t == "THINK" || t == "THINKING" {
|
||||
newType = "thinking"
|
||||
} else if t == "RESPONSE" {
|
||||
newType = "text"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
partType := "text"
|
||||
switch {
|
||||
case path == "response/thinking_content":
|
||||
partType = "thinking"
|
||||
case path == "response/content":
|
||||
partType = "text"
|
||||
case strings.Contains(path, "response/fragments") && strings.Contains(path, "/content"):
|
||||
partType = newType
|
||||
case path == "":
|
||||
if thinkingEnabled {
|
||||
partType = newType
|
||||
}
|
||||
}
|
||||
parts := make([]ContentPart, 0, 8)
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
if val == "FINISHED" && (path == "" || path == "status") {
|
||||
return nil, true, newType
|
||||
}
|
||||
if val != "" {
|
||||
parts = append(parts, ContentPart{Text: val, Type: partType})
|
||||
}
|
||||
case []any:
|
||||
pp, finished := extractContentRecursive(val, partType)
|
||||
if finished {
|
||||
return nil, true, newType
|
||||
}
|
||||
parts = append(parts, pp...)
|
||||
case map[string]any:
|
||||
resp := val
|
||||
if wrapped, ok := val["response"].(map[string]any); ok {
|
||||
resp = wrapped
|
||||
}
|
||||
if frags, ok := resp["fragments"].([]any); ok {
|
||||
for _, item := range frags {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
t, _ := m["type"].(string)
|
||||
content, _ := m["content"].(string)
|
||||
t = strings.ToUpper(t)
|
||||
if t == "THINK" || t == "THINKING" {
|
||||
newType = "thinking"
|
||||
if content != "" {
|
||||
parts = append(parts, ContentPart{Text: content, Type: "thinking"})
|
||||
}
|
||||
} else if t == "RESPONSE" {
|
||||
newType = "text"
|
||||
if content != "" {
|
||||
parts = append(parts, ContentPart{Text: content, Type: "text"})
|
||||
}
|
||||
} else if content != "" {
|
||||
parts = append(parts, ContentPart{Text: content, Type: partType})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return parts, false, newType
|
||||
}
|
||||
|
||||
func extractContentRecursive(items []any, defaultType string) ([]ContentPart, bool) {
|
||||
parts := make([]ContentPart, 0, len(items))
|
||||
for _, it := range items {
|
||||
m, ok := it.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemPath, _ := m["p"].(string)
|
||||
itemV, hasV := m["v"]
|
||||
if !hasV {
|
||||
continue
|
||||
}
|
||||
if itemPath == "status" {
|
||||
if s, ok := itemV.(string); ok && s == "FINISHED" {
|
||||
return nil, true
|
||||
}
|
||||
}
|
||||
if shouldSkipPath(itemPath) {
|
||||
continue
|
||||
}
|
||||
if content, ok := m["content"].(string); ok && content != "" {
|
||||
typeName, _ := m["type"].(string)
|
||||
typeName = strings.ToUpper(typeName)
|
||||
switch typeName {
|
||||
case "THINK", "THINKING":
|
||||
parts = append(parts, ContentPart{Text: content, Type: "thinking"})
|
||||
case "RESPONSE":
|
||||
parts = append(parts, ContentPart{Text: content, Type: "text"})
|
||||
default:
|
||||
parts = append(parts, ContentPart{Text: content, Type: defaultType})
|
||||
}
|
||||
continue
|
||||
}
|
||||
partType := defaultType
|
||||
if strings.Contains(itemPath, "thinking") {
|
||||
partType = "thinking"
|
||||
} else if strings.Contains(itemPath, "content") || itemPath == "response" || itemPath == "fragments" {
|
||||
partType = "text"
|
||||
}
|
||||
switch v := itemV.(type) {
|
||||
case string:
|
||||
if v != "" && v != "FINISHED" {
|
||||
parts = append(parts, ContentPart{Text: v, Type: partType})
|
||||
}
|
||||
case []any:
|
||||
for _, inner := range v {
|
||||
switch x := inner.(type) {
|
||||
case map[string]any:
|
||||
ct, _ := x["content"].(string)
|
||||
if ct == "" {
|
||||
continue
|
||||
}
|
||||
typeName, _ := x["type"].(string)
|
||||
typeName = strings.ToUpper(typeName)
|
||||
if typeName == "THINK" || typeName == "THINKING" {
|
||||
parts = append(parts, ContentPart{Text: ct, Type: "thinking"})
|
||||
} else if typeName == "RESPONSE" {
|
||||
parts = append(parts, ContentPart{Text: ct, Type: "text"})
|
||||
} else {
|
||||
parts = append(parts, ContentPart{Text: ct, Type: partType})
|
||||
}
|
||||
case string:
|
||||
if x != "" {
|
||||
parts = append(parts, ContentPart{Text: x, Type: partType})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return parts, false
|
||||
}
|
||||
|
||||
func IsCitation(text string) bool {
|
||||
return bytes.HasPrefix([]byte(strings.TrimSpace(text)), []byte("[citation:"))
|
||||
}
|
||||
49
internal/sse/parser_test.go
Normal file
49
internal/sse/parser_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package sse
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseDeepSeekSSELine(t *testing.T) {
|
||||
chunk, done, ok := ParseDeepSeekSSELine([]byte(`data: {"v":"你好"}`))
|
||||
if !ok || done {
|
||||
t.Fatalf("expected parsed chunk")
|
||||
}
|
||||
if chunk["v"] != "你好" {
|
||||
t.Fatalf("unexpected chunk: %#v", chunk)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekSSELineDone(t *testing.T) {
|
||||
_, done, ok := ParseDeepSeekSSELine([]byte(`data: [DONE]`))
|
||||
if !ok || !done {
|
||||
t.Fatalf("expected done signal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentSimple(t *testing.T) {
|
||||
parts, finished, _ := ParseSSEChunkForContent(map[string]any{"v": "hello"}, false, "text")
|
||||
if finished {
|
||||
t.Fatal("expected unfinished")
|
||||
}
|
||||
if len(parts) != 1 || parts[0].Text != "hello" || parts[0].Type != "text" {
|
||||
t.Fatalf("unexpected parts: %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentThinking(t *testing.T) {
|
||||
parts, finished, _ := ParseSSEChunkForContent(map[string]any{"p": "response/thinking_content", "v": "think"}, true, "thinking")
|
||||
if finished {
|
||||
t.Fatal("expected unfinished")
|
||||
}
|
||||
if len(parts) != 1 || parts[0].Type != "thinking" {
|
||||
t.Fatalf("unexpected parts: %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCitation(t *testing.T) {
|
||||
if !IsCitation("[citation:1] abc") {
|
||||
t.Fatal("expected citation true")
|
||||
}
|
||||
if IsCitation("normal text") {
|
||||
t.Fatal("expected citation false")
|
||||
}
|
||||
}
|
||||
127
internal/util/messages.go
Normal file
127
internal/util/messages.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`)
|
||||
|
||||
const ClaudeDefaultModel = "claude-sonnet-4-20250514"
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
func MessagesPrepare(messages []map[string]any) string {
|
||||
type block struct {
|
||||
Role string
|
||||
Text string
|
||||
}
|
||||
processed := make([]block, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
role, _ := m["role"].(string)
|
||||
text := normalizeContent(m["content"])
|
||||
processed = append(processed, block{Role: role, Text: text})
|
||||
}
|
||||
if len(processed) == 0 {
|
||||
return ""
|
||||
}
|
||||
merged := make([]block, 0, len(processed))
|
||||
for _, msg := range processed {
|
||||
if len(merged) > 0 && merged[len(merged)-1].Role == msg.Role {
|
||||
merged[len(merged)-1].Text += "\n\n" + msg.Text
|
||||
continue
|
||||
}
|
||||
merged = append(merged, msg)
|
||||
}
|
||||
parts := make([]string, 0, len(merged))
|
||||
for i, m := range merged {
|
||||
switch m.Role {
|
||||
case "assistant":
|
||||
parts = append(parts, "<|Assistant|>"+m.Text+"<|end▁of▁sentence|>")
|
||||
case "user", "system":
|
||||
if i > 0 {
|
||||
parts = append(parts, "<|User|>"+m.Text)
|
||||
} else {
|
||||
parts = append(parts, m.Text)
|
||||
}
|
||||
default:
|
||||
parts = append(parts, m.Text)
|
||||
}
|
||||
}
|
||||
out := strings.Join(parts, "")
|
||||
return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`)
|
||||
}
|
||||
|
||||
func normalizeContent(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return x
|
||||
case []any:
|
||||
parts := make([]string, 0, len(x))
|
||||
for _, item := range x {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if m["type"] == "text" {
|
||||
if txt, ok := m["text"].(string); ok {
|
||||
parts = append(parts, txt)
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertClaudeToDeepSeek(claudeReq map[string]any, store *config.Store) map[string]any {
|
||||
messages, _ := claudeReq["messages"].([]any)
|
||||
model, _ := claudeReq["model"].(string)
|
||||
if model == "" {
|
||||
model = ClaudeDefaultModel
|
||||
}
|
||||
mapping := store.ClaudeMapping()
|
||||
dsModel := mapping["fast"]
|
||||
if dsModel == "" {
|
||||
dsModel = "deepseek-chat"
|
||||
}
|
||||
modelLower := strings.ToLower(model)
|
||||
if strings.Contains(modelLower, "opus") || strings.Contains(modelLower, "reasoner") || strings.Contains(modelLower, "slow") {
|
||||
if slow := mapping["slow"]; slow != "" {
|
||||
dsModel = slow
|
||||
}
|
||||
}
|
||||
convertedMessages := make([]any, 0, len(messages)+1)
|
||||
if system, ok := claudeReq["system"].(string); ok && system != "" {
|
||||
convertedMessages = append(convertedMessages, map[string]any{"role": "system", "content": system})
|
||||
}
|
||||
convertedMessages = append(convertedMessages, messages...)
|
||||
|
||||
out := map[string]any{"model": dsModel, "messages": convertedMessages}
|
||||
for _, k := range []string{"temperature", "top_p", "stream"} {
|
||||
if v, ok := claudeReq[k]; ok {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
if stopSeq, ok := claudeReq["stop_sequences"]; ok {
|
||||
out["stop"] = stopSeq
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func EstimateTokens(text string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
n := len([]rune(text)) / 4
|
||||
if n < 1 {
|
||||
return 1
|
||||
}
|
||||
return n
|
||||
}
|
||||
69
internal/util/messages_test.go
Normal file
69
internal/util/messages_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func TestMessagesPrepareBasic(t *testing.T) {
|
||||
messages := []map[string]any{{"role": "user", "content": "Hello"}}
|
||||
got := MessagesPrepare(messages)
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty prompt")
|
||||
}
|
||||
if got != "Hello" {
|
||||
t.Fatalf("unexpected prompt: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPrepareRoles(t *testing.T) {
|
||||
messages := []map[string]any{
|
||||
{"role": "system", "content": "You are helper"},
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello"},
|
||||
{"role": "user", "content": "How are you"},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if !contains(got, "<|Assistant|>") {
|
||||
t.Fatalf("expected assistant marker in %q", got)
|
||||
}
|
||||
if !contains(got, "<|User|>") {
|
||||
t.Fatalf("expected user marker in %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeToDeepSeek(t *testing.T) {
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-sonnet-4-20250514-slow",
|
||||
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
|
||||
"system": "You are helpful",
|
||||
"stream": true,
|
||||
}
|
||||
out := ConvertClaudeToDeepSeek(req, store)
|
||||
if out["model"] == "" {
|
||||
t.Fatal("expected mapped model")
|
||||
}
|
||||
msgs, ok := out["messages"].([]any)
|
||||
if !ok || len(msgs) == 0 {
|
||||
t.Fatal("expected messages")
|
||||
}
|
||||
first, _ := msgs[0].(map[string]any)
|
||||
if first["role"] != "system" {
|
||||
t.Fatalf("expected first message system, got %#v", first)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(sub) == 0 || (len(s) > 0 && (indexOf(s, sub) >= 0)))
|
||||
}
|
||||
|
||||
func indexOf(s, sub string) int {
|
||||
for i := 0; i+len(sub) <= len(s); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
69
internal/util/toolcalls.go
Normal file
69
internal/util/toolcalls.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
|
||||
|
||||
type ParsedToolCall struct {
|
||||
Name string `json:"name"`
|
||||
Input map[string]any `json:"input"`
|
||||
}
|
||||
|
||||
func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return nil
|
||||
}
|
||||
m := toolCallPattern.FindStringSubmatch(text)
|
||||
if len(m) < 2 {
|
||||
return nil
|
||||
}
|
||||
payload := "{" + `"tool_calls":[` + m[1] + "]}"
|
||||
var obj struct {
|
||||
ToolCalls []ParsedToolCall `json:"tool_calls"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(payload), &obj); err != nil {
|
||||
return nil
|
||||
}
|
||||
allowed := map[string]struct{}{}
|
||||
for _, name := range availableToolNames {
|
||||
allowed[name] = struct{}{}
|
||||
}
|
||||
out := make([]ParsedToolCall, 0, len(obj.ToolCalls))
|
||||
for _, tc := range obj.ToolCalls {
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
if len(allowed) > 0 {
|
||||
if _, ok := allowed[tc.Name]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if tc.Input == nil {
|
||||
tc.Input = map[string]any{}
|
||||
}
|
||||
out = append(out, tc)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for _, c := range calls {
|
||||
args, _ := json.Marshal(c.Input)
|
||||
out = append(out, map[string]any{
|
||||
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": c.Name,
|
||||
"arguments": string(args),
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
33
internal/util/toolcalls_test.go
Normal file
33
internal/util/toolcalls_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseToolCalls(t *testing.T) {
|
||||
text := `prefix {"tool_calls":[{"name":"search","input":{"q":"golang"}}]} suffix`
|
||||
calls := ParseToolCalls(text, []string{"search"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Name != "search" {
|
||||
t.Fatalf("unexpected tool name: %s", calls[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsRejectUnknown(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
|
||||
calls := ParseToolCalls(text, []string{"search"})
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected 0 calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatOpenAIToolCalls(t *testing.T) {
|
||||
formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}})
|
||||
if len(formatted) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(formatted))
|
||||
}
|
||||
fn, _ := formatted[0]["function"].(map[string]any)
|
||||
if fn["name"] != "search" {
|
||||
t.Fatalf("unexpected function name: %#v", fn)
|
||||
}
|
||||
}
|
||||
81
internal/webui/handler.go
Normal file
81
internal/webui/handler.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package webui
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
const welcomeHTML = `<!DOCTYPE html>
|
||||
<html lang="zh-CN"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><title>DS2API</title>
|
||||
<style>body{font-family:Inter,system-ui,sans-serif;background:#030712;color:#f9fafb;display:flex;min-height:100vh;align-items:center;justify-content:center;margin:0}a{color:#f59e0b;text-decoration:none}main{max-width:700px;padding:24px;text-align:center}h1{font-size:48px;margin:0 0 12px}.links{display:flex;gap:16px;justify-content:center;margin-top:20px;flex-wrap:wrap}</style>
|
||||
</head><body><main><h1>DS2API</h1><p>DeepSeek to OpenAI & Claude Compatible API</p><div class="links"><a href="/admin">管理面板</a><a href="/v1/models">API 状态</a><a href="https://github.com/CJackHwang/ds2api" target="_blank">GitHub</a></div></main></body></html>`
|
||||
|
||||
type Handler struct {
|
||||
StaticDir string
|
||||
}
|
||||
|
||||
func NewHandler() *Handler {
|
||||
return &Handler{StaticDir: config.StaticAdminDir()}
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Get("/", h.index)
|
||||
r.Get("/admin", h.admin)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleAdminFallback(w http.ResponseWriter, r *http.Request) bool {
|
||||
if r.Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(r.URL.Path, "/admin/") {
|
||||
return false
|
||||
}
|
||||
h.admin(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *Handler) index(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(welcomeHTML))
|
||||
}
|
||||
|
||||
func (h *Handler) admin(w http.ResponseWriter, r *http.Request) {
|
||||
if fi, err := os.Stat(h.StaticDir); err != nil || !fi.IsDir() {
|
||||
http.Error(w, "WebUI not built. Run `cd webui && npm run build` first.", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
path := strings.TrimPrefix(r.URL.Path, "/admin")
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
if path != "" && strings.Contains(path, ".") {
|
||||
full := filepath.Join(h.StaticDir, filepath.Clean(path))
|
||||
if !strings.HasPrefix(full, h.StaticDir) {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if _, err := os.Stat(full); err == nil {
|
||||
if strings.HasPrefix(path, "assets/") {
|
||||
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
||||
} else {
|
||||
w.Header().Set("Cache-Control", "no-store, must-revalidate")
|
||||
}
|
||||
http.ServeFile(w, r, full)
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
index := filepath.Join(h.StaticDir, "index.html")
|
||||
if _, err := os.Stat(index); err != nil {
|
||||
http.Error(w, "index.html not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Cache-Control", "no-store, must-revalidate")
|
||||
http.ServeFile(w, r, index)
|
||||
}
|
||||
@@ -1,10 +1,15 @@
|
||||
{
|
||||
"version": 2,
|
||||
"buildCommand": "bash scripts/build-webui.sh",
|
||||
"builds": [
|
||||
{
|
||||
"src": "api/index.go",
|
||||
"use": "@vercel/go"
|
||||
}
|
||||
],
|
||||
"rewrites": [
|
||||
{
|
||||
"source": "/(.*)",
|
||||
"destination": "/app.py"
|
||||
"destination": "/api/index.go"
|
||||
}
|
||||
],
|
||||
"headers": [
|
||||
|
||||
Reference in New Issue
Block a user