From 2acf58590a9bd4256916f1cfcad61833855da059 Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Tue, 17 Feb 2026 19:51:53 +0800 Subject: [PATCH 01/52] ci: publish docker image archives in release assets --- .github/workflows/release-artifacts.yml | 54 +++++++++++++++++++++++++ Dockerfile | 4 +- 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index 5fd262f..bf46e92 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -7,6 +7,7 @@ on: permissions: contents: write + packages: write jobs: build-and-upload: @@ -72,6 +73,59 @@ jobs: rm -rf "${STAGE}" done + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to GHCR + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract Docker metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ghcr.io/${{ github.repository }} + tags: | + type=raw,value=${{ github.event.release.tag_name }} + type=raw,value=latest + + - name: Build and Push Docker Image + uses: docker/build-push-action@v6 + with: + context: . + file: ./Dockerfile + push: true + platforms: linux/amd64,linux/arm64 + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + + - name: Export Docker image archives for release assets + run: | + set -euo pipefail + TAG="${{ github.event.release.tag_name }}" + + docker buildx build \ + --platform linux/amd64 \ + --output type=docker,dest="dist/ds2api_${TAG}_docker_linux_amd64.tar" \ + . + + docker buildx build \ + --platform linux/arm64 \ + --output type=docker,dest="dist/ds2api_${TAG}_docker_linux_arm64.tar" \ + . 
+ + gzip -f "dist/ds2api_${TAG}_docker_linux_amd64.tar" + gzip -f "dist/ds2api_${TAG}_docker_linux_arm64.tar" + + - name: Generate checksums + run: | + set -euo pipefail (cd dist && sha256sum *.tar.gz *.zip > sha256sums.txt) - name: Upload Release Assets diff --git a/Dockerfile b/Dockerfile index 3199cfb..a67dfd1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,10 +8,12 @@ RUN npm run build FROM golang:1.24 AS go-builder WORKDIR /app +ARG TARGETOS=linux +ARG TARGETARCH=amd64 COPY go.mod go.sum* ./ RUN go mod download COPY . . -RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o /out/ds2api ./cmd/ds2api +RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -o /out/ds2api ./cmd/ds2api FROM debian:bookworm-slim WORKDIR /app From 89e93a1674743762526a8c934f03fd5a00ecd9d5 Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 00:38:38 +0800 Subject: [PATCH 02/52] feat: Improve configuration loading robustness, add Vercel-specific fallbacks, and update documentation for `config.json` best practices. 
--- .env.example | 3 ++ API.en.md | 23 ++++++++ API.md | 23 ++++++++ DEPLOY.en.md | 59 ++++++++++++++++++-- DEPLOY.md | 59 ++++++++++++++++++-- README.MD | 42 +++++++++++++-- README.en.md | 42 +++++++++++++-- internal/config/config.go | 99 ++++++++++++++++++++++++++++++---- internal/config/config_test.go | 51 ++++++++++++++++++ 9 files changed, 370 insertions(+), 31 deletions(-) diff --git a/.env.example b/.env.example index 21a4d2a..d63f133 100644 --- a/.env.example +++ b/.env.example @@ -52,6 +52,9 @@ DS2API_ADMIN_KEY=admin # Option C: Base64 encoded JSON (recommended for Vercel env var) # DS2API_CONFIG_JSON=eyJrZXlzIjpbInlvdXItYXBpLWtleSJdLCJhY2NvdW50cyI6W3siZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwicGFzc3dvcmQiOiJ4eHgiLCJ0b2tlbiI6IiJ9XX0= +# +# Generate from local config.json: +# DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" # --------------------------------------------------------------- # Paths (optional) diff --git a/API.en.md b/API.en.md index e570dee..1203e12 100644 --- a/API.en.md +++ b/API.en.md @@ -9,6 +9,7 @@ This document describes the actual behavior of the current Go codebase. ## Table of Contents - [Basics](#basics) +- [Configuration Best Practice](#configuration-best-practice) - [Authentication](#authentication) - [Route Index](#route-index) - [Health Endpoints](#health-endpoints) @@ -31,6 +32,28 @@ This document describes the actual behavior of the current Go codebase. --- +## Configuration Best Practice + +Use `config.json` as the single source of truth: + +```bash +cp config.example.json config.json +# Edit config.json (keys/accounts) +``` + +Use it per deployment mode: + +- Local run: read `config.json` directly +- Docker / Vercel: generate Base64 from `config.json`, then set `DS2API_CONFIG_JSON` + +```bash +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" +``` + +For Vercel one-click bootstrap, you can set only `DS2API_ADMIN_KEY` first, then import config at `/admin` and sync env vars from the "Vercel Sync" page. 
+ +--- + ## Authentication ### Business Endpoints (`/v1/*`, `/anthropic/*`) diff --git a/API.md b/API.md index 6be7f65..f57f0a8 100644 --- a/API.md +++ b/API.md @@ -9,6 +9,7 @@ ## 目录 - [基础信息](#基础信息) +- [配置最佳实践](#配置最佳实践) - [鉴权规则](#鉴权规则) - [路由总览](#路由总览) - [健康检查](#健康检查) @@ -31,6 +32,28 @@ --- +## 配置最佳实践 + +推荐把 `config.json` 作为唯一配置源: + +```bash +cp config.example.json config.json +# 编辑 config.json(keys/accounts) +``` + +按部署方式使用: + +- 本地运行:直接读取 `config.json` +- Docker / Vercel:从 `config.json` 生成 Base64,填入 `DS2API_CONFIG_JSON` + +```bash +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" +``` + +Vercel 一键部署可先只填 `DS2API_ADMIN_KEY`,部署后在 `/admin` 导入配置,再通过 “Vercel 同步” 写回环境变量。 + +--- + ## 鉴权规则 ### 业务接口(`/v1/*`、`/anthropic/*`) diff --git a/DEPLOY.en.md b/DEPLOY.en.md index b7caf8c..8a62c98 100644 --- a/DEPLOY.en.md +++ b/DEPLOY.en.md @@ -33,6 +33,17 @@ Config source (choose one): - **File**: `config.json` (recommended for local/Docker) - **Environment variable**: `DS2API_CONFIG_JSON` (recommended for Vercel; supports raw JSON or Base64) +Unified recommendation (best practice): + +```bash +cp config.example.json config.json +# Edit config.json +``` + +Use `config.json` as the single source of truth: +- Local run: read `config.json` directly +- Docker / Vercel: generate `DS2API_CONFIG_JSON` (Base64) from `config.json` and inject it + --- ## 1. Local Run @@ -99,11 +110,15 @@ go build -o ds2api ./cmd/ds2api ### 2.1 Basic Steps ```bash -# Copy and edit environment +# Copy env template cp .env.example .env -# Edit .env, at minimum set: + +# Generate single-line Base64 from config.json +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" + +# Edit .env and set: # DS2API_ADMIN_KEY=your-admin-key -# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]} +# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON} # Start docker-compose up -d @@ -167,15 +182,49 @@ If container logs look normal but the admin panel is unreachable, check these fi 1. **Fork** the repo to your GitHub account 2. 
**Import** the project on Vercel -3. **Set environment variables** (at minimum): +3. **Set environment variables** (minimum required: one variable): | Variable | Description | | --- | --- | | `DS2API_ADMIN_KEY` | Admin key (required) | - | `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (required) | + | `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (optional, recommended) | 4. **Deploy** +### 3.1.1 Recommended Input (avoid `DS2API_CONFIG_JSON` mistakes) + +If you prefer faster one-click bootstrap, you can leave `DS2API_CONFIG_JSON` empty first, then open `/admin` after deployment, import config, and sync it back to Vercel env vars from the "Vercel Sync" page. + +Recommended: in repo root, copy the template first and fill your real accounts: + +```bash +cp config.example.json config.json +# Edit config.json +``` + +Do not hand-edit large JSON directly in Vercel. Generate Base64 locally and paste it: + +```bash +# Run in repo root +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" +echo "$DS2API_CONFIG_JSON" +``` + +If you choose to preconfigure before first deploy, set these vars in Vercel Project Settings -> Environment Variables: + +```text +DS2API_ADMIN_KEY=replace-with-a-strong-secret +DS2API_CONFIG_JSON= +``` + +Optional but recommended (for WebUI one-click Vercel sync): + +```text +VERCEL_TOKEN=your-vercel-token +VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx +VERCEL_TEAM_ID=team_xxxxxxxxxxxx # optional for personal accounts +``` + ### 3.2 Optional Environment Variables | Variable | Description | Default | diff --git a/DEPLOY.md b/DEPLOY.md index b7fbf9a..e5b0630 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -33,6 +33,17 @@ - **文件方式**:`config.json`(推荐本地/Docker 使用) - **环境变量方式**:`DS2API_CONFIG_JSON`(推荐 Vercel 使用,支持 JSON 字符串或 Base64 编码) +统一建议(最优实践): + +```bash +cp config.example.json config.json +# 编辑 config.json +``` + +建议把 `config.json` 作为唯一配置源: +- 本地运行:直接读 `config.json` +- Docker / Vercel:从 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量 + 
--- ## 一、本地运行 @@ -99,11 +110,15 @@ go build -o ds2api ./cmd/ds2api ### 2.1 基本步骤 ```bash -# 复制并编辑环境变量 +# 复制环境变量模板 cp .env.example .env -# 编辑 .env,至少设置: + +# 从 config.json 生成单行 Base64 +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" + +# 编辑 .env(请改成你的强密码),设置: # DS2API_ADMIN_KEY=your-admin-key -# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]} +# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON} # 启动 docker-compose up -d @@ -167,15 +182,49 @@ healthcheck: 1. **Fork 仓库**到你的 GitHub 账号 2. **在 Vercel 上导入项目** -3. **配置环境变量**(至少设置以下两项): +3. **配置环境变量**(最少只需设置以下一项): | 变量 | 说明 | | --- | --- | | `DS2API_ADMIN_KEY` | 管理密钥(必填) | - | `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(必填) | + | `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(可选,建议) | 4. **部署** +### 3.1.1 推荐填写方式(避免 `DS2API_CONFIG_JSON` 填错) + +如果你想先完成一键部署,也可以先不填 `DS2API_CONFIG_JSON`,部署后进入 `/admin` 导入配置,再在「Vercel 同步」里写回环境变量。 + +建议先在仓库目录复制示例配置,再按实际账号填写: + +```bash +cp config.example.json config.json +# 编辑 config.json +``` + +不要在 Vercel 面板里手写复杂 JSON,建议本地生成 Base64 后粘贴: + +```bash +# 在仓库根目录执行 +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" +echo "$DS2API_CONFIG_JSON" +``` + +如果你选择在部署前就预置配置,请在 Vercel Project Settings -> Environment Variables 配置: + +```text +DS2API_ADMIN_KEY=请替换为强密码 +DS2API_CONFIG_JSON=上一步生成的一整行 Base64 +``` + +可选但推荐(用于 WebUI 一键同步 Vercel 配置): + +```text +VERCEL_TOKEN=你的 Vercel Token +VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx +VERCEL_TEAM_ID=team_xxxxxxxxxxxx # 个人账号可留空 +``` + ### 3.2 可选环境变量 | 变量 | 说明 | 默认值 | diff --git a/README.MD b/README.MD index b438b75..3517a55 100644 --- a/README.MD +++ b/README.MD @@ -88,6 +88,19 @@ flowchart LR ## 快速开始 +### 通用第一步(所有部署方式) + +把 `config.json` 作为唯一配置源(推荐做法): + +```bash +cp config.example.json config.json +# 编辑 config.json +``` + +后续部署建议: +- 本地运行:直接读取 `config.json` +- Docker / Vercel:由 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量 + ### 方式一:本地运行 **前置要求**:Go 1.24+,Node.js 20+(仅在需要构建 WebUI 时) @@ -112,14 +125,20 @@ go run ./cmd/ds2api ### 方式二:Docker 运行 ```bash -# 1. 
配置环境变量 +# 1. 准备环境变量文件 cp .env.example .env -# 编辑 .env -# 2. 启动 +# 2. 从 config.json 生成 DS2API_CONFIG_JSON(单行 Base64) +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" + +# 3. 编辑 .env,设置: +# DS2API_ADMIN_KEY=请替换为强密码 +# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON} + +# 4. 启动 docker-compose up -d -# 3. 查看日志 +# 5. 查看日志 docker-compose logs -f ``` @@ -129,9 +148,22 @@ docker-compose logs -f 1. Fork 仓库到自己的 GitHub 2. 在 Vercel 上导入项目 -3. 配置环境变量(至少设置 `DS2API_ADMIN_KEY` 和 `DS2API_CONFIG_JSON`) +3. 配置环境变量(最少设置 `DS2API_ADMIN_KEY`;推荐同时设置 `DS2API_CONFIG_JSON`) 4. 部署 +建议先在仓库目录复制模板并填写: + +```bash +cp config.example.json config.json +# 编辑 config.json +``` + +推荐:先本地把 `config.json` 转成 Base64,再粘贴到 `DS2API_CONFIG_JSON`,避免 JSON 格式错误: + +```bash +base64 < config.json | tr -d '\n' +``` + > **流式说明**:`/v1/chat/completions` 在 Vercel 上默认走 `api/chat-stream.js`(Node Runtime)以保证实时 SSE。鉴权、账号选择、会话/PoW 准备仍由 Go 内部 prepare 接口完成;流式响应(含 `tools`)在 Node 侧执行与 Go 对齐的输出组装与防泄漏处理。 详细部署说明请参阅 [部署指南](DEPLOY.md)。 diff --git a/README.en.md b/README.en.md index bbad73b..d1a91a1 100644 --- a/README.en.md +++ b/README.en.md @@ -88,6 +88,19 @@ In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4 ## Quick Start +### Universal First Step (all deployment modes) + +Use `config.json` as the single source of truth (recommended): + +```bash +cp config.example.json config.json +# Edit config.json +``` + +Recommended per deployment mode: +- Local run: read `config.json` directly +- Docker / Vercel: generate Base64 from `config.json` and inject as `DS2API_CONFIG_JSON` + ### Option 1: Local Run **Prerequisites**: Go 1.24+, Node.js 20+ (only if building WebUI locally) @@ -112,14 +125,20 @@ Default URL: `http://localhost:5001` ### Option 2: Docker ```bash -# 1. Configure environment +# 1. Prepare env file cp .env.example .env -# Edit .env -# 2. Start +# 2. Generate DS2API_CONFIG_JSON from config.json (single-line Base64) +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" + +# 3. 
Edit .env and set: +# DS2API_ADMIN_KEY=replace-with-a-strong-secret +# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON} + +# 4. Start docker-compose up -d -# 3. View logs +# 5. View logs docker-compose logs -f ``` @@ -129,9 +148,22 @@ Rebuild after updates: `docker-compose up -d --build` 1. Fork this repo to your GitHub account 2. Import the project on Vercel -3. Set environment variables (minimum: `DS2API_ADMIN_KEY` and `DS2API_CONFIG_JSON`) +3. Set environment variables (minimum: `DS2API_ADMIN_KEY`; recommended to also set `DS2API_CONFIG_JSON`) 4. Deploy +Recommended first step in repo root: + +```bash +cp config.example.json config.json +# Edit config.json +``` + +Recommended: convert `config.json` to Base64 locally, then paste into `DS2API_CONFIG_JSON` to avoid JSON formatting mistakes: + +```bash +base64 < config.json | tr -d '\n' +``` + > **Streaming note**: `/v1/chat/completions` on Vercel is routed to `api/chat-stream.js` (Node Runtime) for real-time SSE. Auth, account selection, and session/PoW preparation are still handled by the Go internal prepare endpoint; streaming output (including `tools`) is assembled on Node with Go-aligned anti-leak handling. For detailed deployment instructions, see the [Deployment Guide](DEPLOY.en.md). 
diff --git a/internal/config/config.go b/internal/config/config.go index 691df6d..b4058c6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "encoding/json" "errors" + "fmt" "log/slog" "os" "path/filepath" @@ -101,17 +102,29 @@ func (c *Config) UnmarshalJSON(b []byte) error { for k, v := range raw { switch k { case "keys": - _ = json.Unmarshal(v, &c.Keys) + if err := json.Unmarshal(v, &c.Keys); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "accounts": - _ = json.Unmarshal(v, &c.Accounts) + if err := json.Unmarshal(v, &c.Accounts); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "claude_mapping": - _ = json.Unmarshal(v, &c.ClaudeMapping) + if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "claude_model_mapping": - _ = json.Unmarshal(v, &c.ClaudeModelMap) + if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "_vercel_sync_hash": - _ = json.Unmarshal(v, &c.VercelSyncHash) + if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "_vercel_sync_time": - _ = json.Unmarshal(v, &c.VercelSyncTime) + if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } default: var anyVal any if err := json.Unmarshal(v, &anyVal); err == nil { @@ -233,30 +246,94 @@ func loadConfig() (Config, bool, error) { content, err := os.ReadFile(ConfigPath()) if err != nil { + if IsVercel() { + // Vercel one-click deploy may start without a writable/present config file. + // Keep an in-memory config so users can bootstrap via WebUI then sync env. 
+ return Config{}, true, nil + } return Config{}, false, err } var cfg Config if err := json.Unmarshal(content, &cfg); err != nil { return Config{}, false, err } + if IsVercel() { + // Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors. + return cfg, true, nil + } return cfg, false, nil } func parseConfigString(raw string) (Config, error) { var cfg Config - if err := json.Unmarshal([]byte(raw), &cfg); err == nil { - return cfg, nil + candidates := []string{raw} + if normalized := normalizeConfigInput(raw); normalized != raw { + candidates = append(candidates, normalized) } - decoded, err := base64.StdEncoding.DecodeString(raw) + for _, candidate := range candidates { + if err := json.Unmarshal([]byte(candidate), &cfg); err == nil { + return cfg, nil + } + } + + base64Input := candidates[len(candidates)-1] + decoded, err := decodeConfigBase64(base64Input) if err != nil { - return Config{}, err + return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON: %w", err) } if err := json.Unmarshal(decoded, &cfg); err != nil { - return Config{}, err + return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON decoded JSON: %w", err) } return cfg, nil } +func normalizeConfigInput(raw string) string { + normalized := strings.TrimSpace(raw) + if normalized == "" { + return normalized + } + for { + changed := false + if len(normalized) >= 2 { + first := normalized[0] + last := normalized[len(normalized)-1] + if (first == '"' && last == '"') || (first == '\'' && last == '\'') { + normalized = strings.TrimSpace(normalized[1 : len(normalized)-1]) + changed = true + } + } + if strings.HasPrefix(strings.ToLower(normalized), "base64:") { + normalized = strings.TrimSpace(normalized[len("base64:"):]) + changed = true + } + if !changed { + break + } + } + return strings.TrimSpace(normalized) +} + +func decodeConfigBase64(raw string) ([]byte, error) { + encodings := []*base64.Encoding{ + base64.StdEncoding, + base64.RawStdEncoding, + base64.URLEncoding, + 
base64.RawURLEncoding, + } + var lastErr error + for _, enc := range encodings { + decoded, err := enc.DecodeString(raw) + if err == nil { + return decoded, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, errors.New("base64 decode failed") +} + func (s *Store) Snapshot() Config { s.mu.RLock() defer s.mu.RUnlock() diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 58a8a2a..a409fd7 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "encoding/base64" "strings" "testing" ) @@ -70,3 +71,53 @@ func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T) t.Fatalf("expected find by old identifier alias") } } + +func TestLoadStoreRejectsInvalidFieldType(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":"not-array","accounts":[]}`) + store := LoadStore() + if len(store.Keys()) != 0 || len(store.Accounts()) != 0 { + t.Fatalf("expected empty store when config type is invalid") + } +} + +func TestParseConfigStringSupportsQuotedBase64Prefix(t *testing.T) { + rawJSON := `{"keys":["k1"],"accounts":[{"email":"u@example.com","password":"p"}]}` + b64 := base64.StdEncoding.EncodeToString([]byte(rawJSON)) + cfg, err := parseConfigString(`"base64:` + b64 + `"`) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "k1" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestParseConfigStringSupportsRawURLBase64(t *testing.T) { + rawJSON := `{"keys":["k-url"],"accounts":[]}` + b64 := base64.RawURLEncoding.EncodeToString([]byte(rawJSON)) + cfg, err := parseConfigString(b64) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "k-url" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestLoadConfigOnVercelWithoutConfigFileFallsBackToMemory(t *testing.T) { + t.Setenv("VERCEL", "1") + 
t.Setenv("DS2API_CONFIG_JSON", "") + t.Setenv("CONFIG_JSON", "") + t.Setenv("DS2API_CONFIG_PATH", "testdata/does-not-exist.json") + + cfg, fromEnv, err := loadConfig() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !fromEnv { + t.Fatalf("expected fromEnv=true for vercel fallback") + } + if len(cfg.Keys) != 0 || len(cfg.Accounts) != 0 { + t.Fatalf("expected empty bootstrap config, got keys=%d accounts=%d", len(cfg.Keys), len(cfg.Accounts)) + } +} From 19289c9008a5157b10dd4f8abf3ca8fefb90b83c Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 00:54:54 +0800 Subject: [PATCH 03/52] refactor: Modularize OpenAI message normalization and prompt building, enhancing `MessagesPrepare` to support additional content types and tool call formatting. --- internal/adapter/openai/handler.go | 20 +- internal/adapter/openai/message_normalize.go | 192 ++++++++++++++++++ .../adapter/openai/message_normalize_test.go | 121 +++++++++++ internal/adapter/openai/prompt_build.go | 12 ++ internal/adapter/openai/prompt_build_test.go | 80 ++++++++ internal/adapter/openai/vercel_stream.go | 6 +- internal/util/messages.go | 16 +- internal/util/messages_test.go | 27 +++ 8 files changed, 449 insertions(+), 25 deletions(-) create mode 100644 internal/adapter/openai/message_normalize.go create mode 100644 internal/adapter/openai/message_normalize_test.go create mode 100644 internal/adapter/openai/prompt_build.go create mode 100644 internal/adapter/openai/prompt_build_test.go diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index d0a2f1d..1602cf6 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -86,12 +86,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { return } - messages := normalizeMessages(messagesRaw) - toolNames := []string{} - if tools, ok := req["tools"].([]any); ok && len(tools) > 0 { - messages, toolNames = injectToolPrompt(messages, tools) - } - 
finalPrompt := util.MessagesPrepare(messages) + finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { @@ -405,17 +400,6 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt } } -func normalizeMessages(raw []any) []map[string]any { - out := make([]map[string]any, 0, len(raw)) - for _, item := range raw { - m, ok := item.(map[string]any) - if ok { - out = append(out, m) - } - } - return out -} - func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) { toolSchemas := make([]string, 0, len(tools)) names := make([]string, 0, len(tools)) @@ -444,7 +428,7 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, if len(toolSchemas) == 0 { return messages, names } - toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT: If calling tools, output ONLY the JSON. The response must start with { and end with }" + toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error." 
for i := range messages { if messages[i]["role"] == "system" { diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go new file mode 100644 index 0000000..3ebd1e7 --- /dev/null +++ b/internal/adapter/openai/message_normalize.go @@ -0,0 +1,192 @@ +package openai + +import ( + "encoding/json" + "fmt" + "strings" +) + +func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any { + out := make([]map[string]any, 0, len(raw)) + for _, item := range raw { + msg, ok := item.(map[string]any) + if !ok { + continue + } + role := strings.ToLower(strings.TrimSpace(asString(msg["role"]))) + switch role { + case "assistant": + content := normalizeOpenAIContentForPrompt(msg["content"]) + toolCalls := formatAssistantToolCallsForPrompt(msg) + combined := joinNonEmpty(content, toolCalls) + if combined == "" { + continue + } + out = append(out, map[string]any{ + "role": "assistant", + "content": combined, + }) + case "tool", "function": + out = append(out, map[string]any{ + "role": "user", + "content": formatToolResultForPrompt(msg), + }) + case "user", "system": + out = append(out, map[string]any{ + "role": role, + "content": normalizeOpenAIContentForPrompt(msg["content"]), + }) + default: + content := normalizeOpenAIContentForPrompt(msg["content"]) + if content == "" { + continue + } + if role == "" { + role = "user" + } + out = append(out, map[string]any{ + "role": role, + "content": content, + }) + } + } + return out +} + +func formatAssistantToolCallsForPrompt(msg map[string]any) string { + entries := make([]string, 0) + if calls, ok := msg["tool_calls"].([]any); ok { + for i, item := range calls { + call, ok := item.(map[string]any) + if !ok { + continue + } + id := strings.TrimSpace(asString(call["id"])) + if id == "" { + id = fmt.Sprintf("call_%d", i+1) + } + name := strings.TrimSpace(asString(call["name"])) + args := "" + + if fn, ok := call["function"].(map[string]any); ok { + if name == "" { + name = 
strings.TrimSpace(asString(fn["name"])) + } + args = normalizeOpenAIArgumentsForPrompt(fn["arguments"]) + } + if name == "" { + name = "unknown" + } + if args == "" { + args = normalizeOpenAIArgumentsForPrompt(call["arguments"]) + } + if args == "" { + args = normalizeOpenAIArgumentsForPrompt(call["input"]) + } + if args == "" { + args = "{}" + } + entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: %s\n- function.name: %s\n- function.arguments: %s", id, name, args)) + } + } + + if legacy, ok := msg["function_call"].(map[string]any); ok { + name := strings.TrimSpace(asString(legacy["name"])) + if name == "" { + name = "unknown" + } + args := normalizeOpenAIArgumentsForPrompt(legacy["arguments"]) + if args == "" { + args = "{}" + } + entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: call_legacy\n- function.name: %s\n- function.arguments: %s", name, args)) + } + + return strings.Join(entries, "\n\n") +} + +func formatToolResultForPrompt(msg map[string]any) string { + toolCallID := strings.TrimSpace(asString(msg["tool_call_id"])) + if toolCallID == "" { + toolCallID = strings.TrimSpace(asString(msg["id"])) + } + if toolCallID == "" { + toolCallID = "unknown" + } + + name := strings.TrimSpace(asString(msg["name"])) + if name == "" { + name = "unknown" + } + + content := normalizeOpenAIContentForPrompt(msg["content"]) + if content == "" { + content = "null" + } + + return fmt.Sprintf("Tool result:\n- tool_call_id: %s\n- name: %s\n- content: %s", toolCallID, name, content) +} + +func normalizeOpenAIContentForPrompt(v any) string { + switch x := v.(type) { + case string: + return x + case []any: + parts := make([]string, 0, len(x)) + for _, item := range x { + m, ok := item.(map[string]any) + if !ok { + continue + } + t := strings.ToLower(strings.TrimSpace(asString(m["type"]))) + if t != "text" && t != "output_text" && t != "input_text" { + continue + } + if text := asString(m["text"]); text != "" { + parts = append(parts, text) + continue 
+ } + if text := asString(m["content"]); text != "" { + parts = append(parts, text) + } + } + return strings.Join(parts, "\n") + default: + return marshalToPromptString(v) + } +} + +func normalizeOpenAIArgumentsForPrompt(v any) string { + switch x := v.(type) { + case string: + return strings.TrimSpace(x) + default: + return marshalToPromptString(v) + } +} + +func marshalToPromptString(v any) string { + b, err := json.Marshal(v) + if err != nil { + return strings.TrimSpace(fmt.Sprintf("%v", v)) + } + return string(b) +} + +func asString(v any) string { + if s, ok := v.(string); ok { + return s + } + return "" +} + +func joinNonEmpty(parts ...string) string { + nonEmpty := make([]string, 0, len(parts)) + for _, p := range parts { + if strings.TrimSpace(p) == "" { + continue + } + nonEmpty = append(nonEmpty, p) + } + return strings.Join(nonEmpty, "\n\n") +} diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go new file mode 100644 index 0000000..bb648d3 --- /dev/null +++ b/internal/adapter/openai/message_normalize_test.go @@ -0,0 +1,121 @@ +package openai + +import ( + "strings" + "testing" + + "ds2api/internal/util" +) + +func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *testing.T) { + raw := []any{ + map[string]any{"role": "system", "content": "You are helpful"}, + map[string]any{"role": "user", "content": "查北京天气"}, + map[string]any{ + "role": "assistant", + "content": nil, + "tool_calls": []any{ + map[string]any{ + "id": "call_1", + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "arguments": "{\"city\":\"beijing\"}", + }, + }, + }, + }, + map[string]any{ + "role": "tool", + "tool_call_id": "call_1", + "name": "get_weather", + "content": "{\"temp\":18}", + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw) + if len(normalized) != 4 { + t.Fatalf("expected 4 normalized messages, got %d", len(normalized)) + } + assistantContent, _ := 
normalized[2]["content"].(string) + if !strings.Contains(assistantContent, "tool_call_id: call_1") || + !strings.Contains(assistantContent, "function.name: get_weather") || + !strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") { + t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent) + } + toolContent, _ := normalized[3]["content"].(string) + if !strings.Contains(toolContent, "Tool result:") || !strings.Contains(toolContent, "name: get_weather") { + t.Fatalf("tool result not serialized correctly: %q", toolContent) + } + + prompt := util.MessagesPrepare(normalized) + if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "Tool result:") { + t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "tool", + "tool_call_id": "call_2", + "name": "get_weather", + "content": map[string]any{ + "temp": 18, + "condition": "sunny", + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw) + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) { + t.Fatalf("expected serialized object in tool content, got %q", got) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "tool", + "tool_call_id": "call_3", + "name": "read_file", + "content": []any{ + map[string]any{"type": "input_text", "text": "line-1"}, + map[string]any{"type": "output_text", "text": "line-2"}, + map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"}, + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw) + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, "line-1\nline-2") { + t.Fatalf("expected joined text blocks, got %q", got) + } +} + +func 
TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "function", + "tool_call_id": "call_4", + "name": "legacy_tool", + "content": map[string]any{ + "ok": true, + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw) + if len(normalized) != 1 { + t.Fatalf("expected one normalized message, got %d", len(normalized)) + } + if normalized[0]["role"] != "user" { + t.Fatalf("expected function role mapped to user, got %#v", normalized[0]["role"]) + } + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, "name: legacy_tool") || !strings.Contains(got, `"ok":true`) { + t.Fatalf("unexpected normalized function-role content: %q", got) + } +} diff --git a/internal/adapter/openai/prompt_build.go b/internal/adapter/openai/prompt_build.go new file mode 100644 index 0000000..a7bbc92 --- /dev/null +++ b/internal/adapter/openai/prompt_build.go @@ -0,0 +1,12 @@ +package openai + +import "ds2api/internal/util" + +func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) { + messages := normalizeOpenAIMessagesForPrompt(messagesRaw) + toolNames := []string{} + if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 { + messages, toolNames = injectToolPrompt(messages, tools) + } + return util.MessagesPrepare(messages), toolNames +} diff --git a/internal/adapter/openai/prompt_build_test.go b/internal/adapter/openai/prompt_build_test.go new file mode 100644 index 0000000..1833860 --- /dev/null +++ b/internal/adapter/openai/prompt_build_test.go @@ -0,0 +1,80 @@ +package openai + +import ( + "strings" + "testing" +) + +func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *testing.T) { + messages := []any{ + map[string]any{"role": "user", "content": "查北京天气"}, + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": "call_1", + "function": map[string]any{ + "name": "get_weather", + "arguments": "{\"city\":\"beijing\"}", 
+ }, + }, + }, + }, + map[string]any{ + "role": "tool", + "tool_call_id": "call_1", + "name": "get_weather", + "content": map[string]any{"temp": 18, "condition": "sunny"}, + }, + } + tools := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "description": "Get weather", + "parameters": map[string]any{ + "type": "object", + }, + }, + }, + } + + finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools) + if len(toolNames) != 1 || toolNames[0] != "get_weather" { + t.Fatalf("unexpected tool names: %#v", toolNames) + } + if !strings.Contains(finalPrompt, "tool_call_id: call_1") || + !strings.Contains(finalPrompt, "function.name: get_weather") || + !strings.Contains(finalPrompt, "Tool result:") || + !strings.Contains(finalPrompt, `"condition":"sunny"`) { + t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt) + } +} + +func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *testing.T) { + messages := []any{ + map[string]any{"role": "system", "content": "You are helpful"}, + map[string]any{"role": "user", "content": "请调用工具"}, + } + tools := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + "description": "search docs", + "parameters": map[string]any{ + "type": "object", + }, + }, + }, + } + + finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools) + if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") { + t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt) + } + if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") { + t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt) + } +} diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go index 653f3cf..85c9cd8 100644 
--- a/internal/adapter/openai/vercel_stream.go +++ b/internal/adapter/openai/vercel_stream.go @@ -68,11 +68,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque return } - messages := normalizeMessages(messagesRaw) - if tools, ok := req["tools"].([]any); ok && len(tools) > 0 { - messages, _ = injectToolPrompt(messages, tools) - } - finalPrompt := util.MessagesPrepare(messages) + finalPrompt, _ := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { diff --git a/internal/util/messages.go b/internal/util/messages.go index 19f2948..fcc9484 100644 --- a/internal/util/messages.go +++ b/internal/util/messages.go @@ -1,6 +1,8 @@ package util import ( + "encoding/json" + "fmt" "regexp" "strings" @@ -68,15 +70,25 @@ func normalizeContent(v any) string { if !ok { continue } - if m["type"] == "text" { + typeStr, _ := m["type"].(string) + typeStr = strings.ToLower(strings.TrimSpace(typeStr)) + if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" { if txt, ok := m["text"].(string); ok { parts = append(parts, txt) + continue + } + if txt, ok := m["content"].(string); ok { + parts = append(parts, txt) } } } return strings.Join(parts, "\n") default: - return "" + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%v", v) + } + return string(b) } } diff --git a/internal/util/messages_test.go b/internal/util/messages_test.go index 30b8cc0..776853b 100644 --- a/internal/util/messages_test.go +++ b/internal/util/messages_test.go @@ -33,6 +33,33 @@ func TestMessagesPrepareRoles(t *testing.T) { } } +func TestMessagesPrepareObjectContent(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": map[string]any{"temp": 18, "ok": true}}, + } + got := MessagesPrepare(messages) + if !contains(got, `"temp":18`) || !contains(got, `"ok":true`) { + t.Fatalf("expected serialized object content, got %q", got) + } +} + +func 
TestMessagesPrepareArrayTextVariants(t *testing.T) { + messages := []map[string]any{ + { + "role": "user", + "content": []any{ + map[string]any{"type": "output_text", "text": "line1"}, + map[string]any{"type": "input_text", "text": "line2"}, + map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"}, + }, + }, + } + got := MessagesPrepare(messages) + if got != "line1\nline2" { + t.Fatalf("unexpected content from text variants: %q", got) + } +} + func TestConvertClaudeToDeepSeek(t *testing.T) { store := config.LoadStore() req := map[string]any{ From 7beeea57796b7b6ddbb690f694e99175ff5a3d3d Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 16:10:35 +0800 Subject: [PATCH 04/52] feat: Implement streaming incremental tool call deltas with a new tool sieve and standalone parser. --- api/chat-stream.js | 256 +++++++-- api/chat-stream.test.js | 3 +- api/helpers/stream-tool-sieve.js | 487 ++++++++++++++++-- api/helpers/stream-tool-sieve.test.js | 49 +- internal/adapter/openai/handler.go | 55 +- .../adapter/openai/handler_toolcall_test.go | 116 ++++- internal/adapter/openai/tool_sieve.go | 429 ++++++++++++++- internal/util/toolcalls.go | 30 ++ internal/util/toolcalls_test.go | 13 + 9 files changed, 1324 insertions(+), 114 deletions(-) diff --git a/api/chat-stream.js b/api/chat-stream.js index aa92b17..309c473 100644 --- a/api/chat-stream.js +++ b/api/chat-stream.js @@ -1,11 +1,13 @@ 'use strict'; +const crypto = require('crypto'); + const { extractToolNames, createToolSieveState, processToolSieveChunk, flushToolSieve, - parseToolCalls, + parseStandaloneToolCalls, formatOpenAIStreamToolCalls, } = require('./helpers/stream-tool-sieve'); @@ -90,16 +92,49 @@ module.exports = async function handler(req, res) { return; } const releaseLease = createLeaseReleaser(req, leaseID); + const upstreamController = new AbortController(); + let clientClosed = false; + let reader = null; + const markClientClosed = () => { + if (clientClosed) { + return; + } + 
clientClosed = true; + upstreamController.abort(); + if (reader && typeof reader.cancel === 'function') { + Promise.resolve(reader.cancel()).catch(() => {}); + } + }; + const onReqAborted = () => markClientClosed(); + const onResClose = () => { + if (!res.writableEnded) { + markClientClosed(); + } + }; + req.on('aborted', onReqAborted); + res.on('close', onResClose); try { - const completionRes = await fetch(DEEPSEEK_COMPLETION_URL, { - method: 'POST', - headers: { - ...BASE_HEADERS, - authorization: `Bearer ${deepseekToken}`, - 'x-ds-pow-response': powHeader, - }, - body: JSON.stringify(completionPayload), - }); + let completionRes; + try { + completionRes = await fetch(DEEPSEEK_COMPLETION_URL, { + method: 'POST', + headers: { + ...BASE_HEADERS, + authorization: `Bearer ${deepseekToken}`, + 'x-ds-pow-response': powHeader, + }, + body: JSON.stringify(completionPayload), + signal: upstreamController.signal, + }); + } catch (err) { + if (clientClosed || isAbortError(err)) { + return; + } + throw err; + } + if (clientClosed) { + return; + } if (!completionRes.ok || !completionRes.body) { const detail = await safeReadText(completionRes); @@ -124,12 +159,16 @@ module.exports = async function handler(req, res) { const toolSieveEnabled = toolNames.length > 0; const toolSieveState = createToolSieveState(); let toolCallsEmitted = false; + const streamToolCallIDs = new Map(); const decoder = new TextDecoder(); - const reader = completionRes.body.getReader(); + reader = completionRes.body.getReader(); let buffered = ''; let ended = false; const sendFrame = (obj) => { + if (clientClosed || res.writableEnded || res.destroyed) { + return; + } res.write(`data: ${JSON.stringify(obj)}\n\n`); if (typeof res.flush === 'function') { res.flush(); @@ -156,7 +195,11 @@ module.exports = async function handler(req, res) { return; } ended = true; - const detected = parseToolCalls(outputText, toolNames); + if (clientClosed || res.writableEnded || res.destroyed) { + await releaseLease(); + 
return; + } + const detected = parseStandaloneToolCalls(outputText, toolNames); if (detected.length > 0 && !toolCallsEmitted) { toolCallsEmitted = true; sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) }); @@ -179,14 +222,22 @@ module.exports = async function handler(req, res) { choices: [{ delta: {}, index: 0, finish_reason: reason }], usage: buildUsage(finalPrompt, thinkingText, outputText), }); - res.write('data: [DONE]\n\n'); + if (!res.writableEnded && !res.destroyed) { + res.write('data: [DONE]\n\n'); + } await releaseLease(); - res.end(); + if (!res.writableEnded && !res.destroyed) { + res.end(); + } }; try { // eslint-disable-next-line no-constant-condition while (true) { + if (clientClosed) { + await finish('stop'); + return; + } const { value, done } = await reader.read(); if (done) { break; @@ -245,6 +296,11 @@ module.exports = async function handler(req, res) { } const events = processToolSieveChunk(toolSieveState, p.text, toolNames); for (const evt of events) { + if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) { + toolCallsEmitted = true; + sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) }); + continue; + } if (evt.type === 'tool_calls') { toolCallsEmitted = true; sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) }); @@ -259,10 +315,16 @@ module.exports = async function handler(req, res) { } } await finish('stop'); - } catch (_err) { + } catch (err) { + if (clientClosed || isAbortError(err)) { + await finish('stop'); + return; + } await finish('stop'); } } finally { + req.removeListener('aborted', onReqAborted); + res.removeListener('close', onResClose); await releaseLease(); } }; @@ -656,6 +718,55 @@ function buildUsage(prompt, thinking, output) { }; } +function formatIncrementalToolCallDeltas(deltas, idStore) { + if (!Array.isArray(deltas) || deltas.length === 0) { + return []; + } + const out = []; + for (const d of deltas) { + 
if (!d || typeof d !== 'object') { + continue; + } + const index = Number.isInteger(d.index) ? d.index : 0; + const id = ensureStreamToolCallID(idStore, index); + const item = { + index, + id, + type: 'function', + }; + const fn = {}; + if (asString(d.name)) { + fn.name = asString(d.name); + } + if (typeof d.arguments === 'string' && d.arguments !== '') { + fn.arguments = d.arguments; + } + if (Object.keys(fn).length > 0) { + item.function = fn; + } + out.push(item); + } + return out; +} + +function ensureStreamToolCallID(idStore, index) { + const key = Number.isInteger(index) ? index : 0; + const existing = idStore.get(key); + if (existing) { + return existing; + } + const next = `call_${newCallID()}`; + idStore.set(key, next); + return next; +} + +function newCallID() { + if (typeof crypto.randomUUID === 'function') { + return crypto.randomUUID().replace(/-/g, ''); + } + return `${Date.now()}${Math.floor(Math.random() * 1e9)}`; +} + function estimateTokens(text) { const t = asString(text); if (!t) { @@ -667,44 +778,92 @@ function estimateTokens(text) { async function proxyToGo(req, res, rawBody) { const url = buildInternalGoURL(req); - - const upstream = await fetch(url.toString(), { - method: 'POST', - headers: buildInternalGoHeaders(req, { withContentType: true }), - body: rawBody, - }); - - res.statusCode = upstream.status; - upstream.headers.forEach((value, key) => { - if (key.toLowerCase() === 'content-length') { + const controller = new AbortController(); + let clientClosed = false; + const markClientClosed = () => { + if (clientClosed) { return; } - res.setHeader(key, value); - }); + clientClosed = true; + controller.abort(); + }; + const onReqAborted = () => markClientClosed(); + const onResClose = () => { + if (!res.writableEnded) { + markClientClosed(); + } + }; + req.on('aborted', onReqAborted); + res.on('close', onResClose); - if (!upstream.body || typeof upstream.body.getReader !== 'function') { - const bytes = Buffer.from(await 
upstream.arrayBuffer()); - res.end(bytes); - return; - } - - const reader = upstream.body.getReader(); try { - // eslint-disable-next-line no-constant-condition - while (true) { - const { value, done } = await reader.read(); - if (done) { - break; + let upstream; + try { + upstream = await fetch(url.toString(), { + method: 'POST', + headers: buildInternalGoHeaders(req, { withContentType: true }), + body: rawBody, + signal: controller.signal, + }); + } catch (err) { + if (clientClosed || isAbortError(err)) { + if (!res.writableEnded) { + res.end(); + } + return; } - if (value && value.length > 0) { - res.write(Buffer.from(value)); - if (typeof res.flush === 'function') { - res.flush(); + throw err; + } + if (clientClosed) { + if (!res.writableEnded) { + res.end(); + } + return; + } + + res.statusCode = upstream.status; + upstream.headers.forEach((value, key) => { + if (key.toLowerCase() === 'content-length') { + return; + } + res.setHeader(key, value); + }); + + if (!upstream.body || typeof upstream.body.getReader !== 'function') { + const bytes = Buffer.from(await upstream.arrayBuffer()); + res.end(bytes); + return; + } + + const reader = upstream.body.getReader(); + try { + // eslint-disable-next-line no-constant-condition + while (true) { + if (clientClosed) { + break; + } + const { value, done } = await reader.read(); + if (done) { + break; + } + if (value && value.length > 0) { + res.write(Buffer.from(value)); + if (typeof res.flush === 'function') { + res.flush(); + } } } + if (!res.writableEnded) { + res.end(); + } + } catch (err) { + if (!isAbortError(err) && !res.writableEnded) { + res.end(); + } } - res.end(); - } catch (_err) { + } finally { + req.removeListener('aborted', onReqAborted); + res.removeListener('close', onResClose); if (!res.writableEnded) { res.end(); } @@ -762,6 +921,13 @@ function asString(v) { return String(v).trim(); } +function isAbortError(err) { + if (!err || typeof err !== 'object') { + return false; + } + return err.name === 
'AbortError' || err.code === 'ABORT_ERR'; +} + module.exports.__test = { parseChunkForContent, extractContentRecursive, diff --git a/api/chat-stream.test.js b/api/chat-stream.test.js index b347342..c849f7c 100644 --- a/api/chat-stream.test.js +++ b/api/chat-stream.test.js @@ -49,12 +49,13 @@ test('parseChunkForContent + sieve does not leak suspicious prefix in split tool events.push(...flushToolSieve(state, ['read_file'])); const hasToolCalls = events.some((evt) => evt.type === 'tool_calls' && evt.calls && evt.calls.length > 0); + const hasToolDeltas = events.some((evt) => evt.type === 'tool_call_deltas' && evt.deltas && evt.deltas.length > 0); const leakedText = events .filter((evt) => evt.type === 'text' && evt.text) .map((evt) => evt.text) .join(''); - assert.equal(hasToolCalls, true); + assert.equal(hasToolCalls || hasToolDeltas, true); assert.equal(leakedText.includes('{'), false); assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); }); diff --git a/api/helpers/stream-tool-sieve.js b/api/helpers/stream-tool-sieve.js index 3ced63d..8b586aa 100644 --- a/api/helpers/stream-tool-sieve.js +++ b/api/helpers/stream-tool-sieve.js @@ -2,6 +2,7 @@ const crypto = require('crypto'); const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s; +const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024; function extractToolNames(tools) { if (!Array.isArray(tools) || tools.length === 0) { @@ -26,9 +27,25 @@ function createToolSieveState() { pending: '', capture: '', capturing: false, + hasMeaningfulText: false, + toolNameSent: false, + toolName: '', + toolArgsStart: -1, + toolArgsSent: -1, + toolArgsString: false, + toolArgsDone: false, }; } +function resetIncrementalToolState(state) { + state.toolNameSent = false; + state.toolName = ''; + state.toolArgsStart = -1; + state.toolArgsSent = -1; + state.toolArgsString = false; + state.toolArgsDone = false; +} + function processToolSieveChunk(state, chunk, toolNames) { if (!state) { return []; @@ -44,13 +61,31 @@ 
function processToolSieveChunk(state, chunk, toolNames) { state.capture += state.pending; state.pending = ''; } - const consumed = consumeToolCapture(state.capture, toolNames); + const deltas = buildIncrementalToolDeltas(state); + if (deltas.length > 0) { + events.push({ type: 'tool_call_deltas', deltas }); + } + const consumed = consumeToolCapture(state, toolNames); if (!consumed.ready) { + if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) { + if (hasMeaningfulText(state.capture)) { + state.hasMeaningfulText = true; + } + events.push({ type: 'text', text: state.capture }); + state.capture = ''; + state.capturing = false; + resetIncrementalToolState(state); + continue; + } break; } state.capture = ''; state.capturing = false; + resetIncrementalToolState(state); if (consumed.prefix) { + if (hasMeaningfulText(consumed.prefix)) { + state.hasMeaningfulText = true; + } events.push({ type: 'text', text: consumed.prefix }); } if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { @@ -70,11 +105,15 @@ function processToolSieveChunk(state, chunk, toolNames) { if (start >= 0) { const prefix = state.pending.slice(0, start); if (prefix) { + if (hasMeaningfulText(prefix)) { + state.hasMeaningfulText = true; + } events.push({ type: 'text', text: prefix }); } state.capture = state.pending.slice(start); state.pending = ''; state.capturing = true; + resetIncrementalToolState(state); continue; } @@ -83,6 +122,9 @@ function processToolSieveChunk(state, chunk, toolNames) { break; } state.pending = hold; + if (hasMeaningfulText(safe)) { + state.hasMeaningfulText = true; + } events.push({ type: 'text', text: safe }); } return events; @@ -94,24 +136,37 @@ function flushToolSieve(state, toolNames) { } const events = processToolSieveChunk(state, '', toolNames); if (state.capturing) { - const consumed = consumeToolCapture(state.capture, toolNames); + const consumed = consumeToolCapture(state, toolNames); if (consumed.ready) { if (consumed.prefix) { + if 
(hasMeaningfulText(consumed.prefix)) { + state.hasMeaningfulText = true; + } events.push({ type: 'text', text: consumed.prefix }); } if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { events.push({ type: 'tool_calls', calls: consumed.calls }); } if (consumed.suffix) { + if (hasMeaningfulText(consumed.suffix)) { + state.hasMeaningfulText = true; + } events.push({ type: 'text', text: consumed.suffix }); } } else if (state.capture) { - // Incomplete captured tool JSON at stream end: suppress raw capture. + if (hasMeaningfulText(state.capture)) { + state.hasMeaningfulText = true; + } + events.push({ type: 'text', text: state.capture }); } state.capture = ''; state.capturing = false; + resetIncrementalToolState(state); } if (state.pending) { + if (hasMeaningfulText(state.pending)) { + state.hasMeaningfulText = true; + } events.push({ type: 'text', text: state.pending }); state.pending = ''; } @@ -159,7 +214,8 @@ function findToolSegmentStart(s) { return start >= 0 ? start : keyIdx; } -function consumeToolCapture(captured, toolNames) { +function consumeToolCapture(state, toolNames) { + const captured = state.capture; if (!captured) { return { ready: false, prefix: '', calls: [], suffix: '' }; } @@ -176,25 +232,361 @@ function consumeToolCapture(captured, toolNames) { if (!obj.ok) { return { ready: false, prefix: '', calls: [], suffix: '' }; } - const parsed = parseToolCalls(captured.slice(start, obj.end), toolNames); - if (parsed.length === 0) { - // `tool_calls` key exists but strict JSON parse failed. - // Drop the captured object body to avoid leaking raw tool JSON. 
+ const prefixPart = captured.slice(0, start); + const suffixPart = captured.slice(obj.end); + if (!state.toolNameSent && (state.hasMeaningfulText || hasMeaningfulText(prefixPart) || hasMeaningfulText(suffixPart))) { return { ready: true, - prefix: captured.slice(0, start), + prefix: captured, calls: [], - suffix: captured.slice(obj.end), + suffix: '', + }; + } + const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames); + if (parsed.length === 0) { + if (state.toolNameSent) { + return { + ready: true, + prefix: prefixPart, + calls: [], + suffix: suffixPart, + }; + } + return { + ready: true, + prefix: captured, + calls: [], + suffix: '', + }; + } + if (state.toolNameSent) { + if (parsed.length > 1) { + return { + ready: true, + prefix: prefixPart, + calls: parsed.slice(1), + suffix: suffixPart, + }; + } + return { + ready: true, + prefix: prefixPart, + calls: [], + suffix: suffixPart, }; } return { ready: true, - prefix: captured.slice(0, start), + prefix: prefixPart, calls: parsed, - suffix: captured.slice(obj.end), + suffix: suffixPart, }; } +function buildIncrementalToolDeltas(state) { + const captured = state.capture || ''; + if (!captured || state.hasMeaningfulText) { + return []; + } + const lower = captured.toLowerCase(); + const keyIdx = lower.indexOf('tool_calls'); + if (keyIdx < 0) { + return []; + } + const start = captured.slice(0, keyIdx).lastIndexOf('{'); + if (start < 0 || hasMeaningfulText(captured.slice(0, start))) { + return []; + } + const callStart = findFirstToolCallObjectStart(captured, keyIdx); + if (callStart < 0) { + return []; + } + + const deltas = []; + if (!state.toolName) { + const name = extractToolCallName(captured, callStart); + if (!name) { + return []; + } + state.toolName = name; + } + + if (state.toolArgsStart < 0) { + const args = findToolCallArgsStart(captured, callStart); + if (args) { + state.toolArgsString = Boolean(args.stringMode); + state.toolArgsStart = state.toolArgsString ? 
args.start + 1 : args.start; + state.toolArgsSent = state.toolArgsStart; + } + } + if (!state.toolNameSent) { + if (state.toolArgsStart < 0) { + return []; + } + state.toolNameSent = true; + deltas.push({ index: 0, name: state.toolName }); + } + if (state.toolArgsStart < 0 || state.toolArgsDone) { + return deltas; + } + const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString); + if (!progress) { + return deltas; + } + if (progress.end > state.toolArgsSent) { + deltas.push({ + index: 0, + arguments: captured.slice(state.toolArgsSent, progress.end), + }); + state.toolArgsSent = progress.end; + } + if (progress.complete) { + state.toolArgsDone = true; + } + return deltas; +} + +function findFirstToolCallObjectStart(text, keyIdx) { + const arrStart = findToolCallsArrayStart(text, keyIdx); + if (arrStart < 0) { + return -1; + } + const i = skipSpaces(text, arrStart + 1); + if (i >= text.length || text[i] !== '{') { + return -1; + } + return i; +} + +function findToolCallsArrayStart(text, keyIdx) { + let i = keyIdx + 'tool_calls'.length; + while (i < text.length && text[i] !== ':') { + i += 1; + } + if (i >= text.length) { + return -1; + } + i = skipSpaces(text, i + 1); + if (i >= text.length || text[i] !== '[') { + return -1; + } + return i; +} + +function extractToolCallName(text, callStart) { + let valueStart = findObjectFieldValueStart(text, callStart, ['name']); + if (valueStart < 0 || text[valueStart] !== '"') { + const fnStart = findFunctionObjectStart(text, callStart); + if (fnStart < 0) { + return ''; + } + valueStart = findObjectFieldValueStart(text, fnStart, ['name']); + if (valueStart < 0 || text[valueStart] !== '"') { + return ''; + } + } + const parsed = parseJSONStringLiteral(text, valueStart); + if (!parsed) { + return ''; + } + return parsed.value; +} + +function findToolCallArgsStart(text, callStart) { + const keys = ['input', 'arguments', 'args', 'parameters', 'params']; + let valueStart = 
findObjectFieldValueStart(text, callStart, keys); + if (valueStart < 0) { + const fnStart = findFunctionObjectStart(text, callStart); + if (fnStart < 0) { + return null; + } + valueStart = findObjectFieldValueStart(text, fnStart, keys); + if (valueStart < 0) { + return null; + } + } + if (valueStart >= text.length) { + return null; + } + const ch = text[valueStart]; + if (ch === '{' || ch === '[') { + return { start: valueStart, stringMode: false }; + } + if (ch === '"') { + return { start: valueStart, stringMode: true }; + } + return null; +} + +function scanToolCallArgsProgress(text, start, stringMode) { + if (start < 0 || start > text.length) { + return null; + } + if (stringMode) { + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === '"') { + return { end: i, complete: true }; + } + } + return { end: text.length, complete: false }; + } + if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) { + return null; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + quote = ch; + continue; + } + if (ch === '{' || ch === '[') { + depth += 1; + continue; + } + if (ch === '}' || ch === ']') { + depth -= 1; + if (depth === 0) { + return { end: i + 1, complete: true }; + } + } + } + return { end: text.length, complete: false }; +} + +function findObjectFieldValueStart(text, objStart, keys) { + if (!text || objStart < 0 || objStart >= text.length || text[objStart] !== '{') { + return -1; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = objStart; i < text.length; i += 1) { + const 
ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + if (depth === 1) { + const parsed = parseJSONStringLiteral(text, i); + if (!parsed) { + return -1; + } + let j = skipSpaces(text, parsed.end); + if (j >= text.length || text[j] !== ':') { + i = parsed.end - 1; + continue; + } + j = skipSpaces(text, j + 1); + if (j >= text.length) { + return -1; + } + if (keys.includes(parsed.value)) { + return j; + } + i = j - 1; + continue; + } + quote = ch; + continue; + } + if (ch === '{') { + depth += 1; + continue; + } + if (ch === '}') { + depth -= 1; + if (depth === 0) { + break; + } + } + } + return -1; +} + +function findFunctionObjectStart(text, callStart) { + const valueStart = findObjectFieldValueStart(text, callStart, ['function']); + if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') { + return -1; + } + return valueStart; +} + +function parseJSONStringLiteral(text, start) { + if (!text || start < 0 || start >= text.length || text[start] !== '"') { + return null; + } + let out = ''; + let escaped = false; + for (let i = start + 1; i < text.length; i += 1) { + const ch = text[i]; + if (escaped) { + out += ch; + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === '"') { + return { value: out, end: i + 1 }; + } + out += ch; + } + return null; +} + +function skipSpaces(text, i) { + let idx = i; + while (idx < text.length) { + const ch = text[idx]; + if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r') { + idx += 1; + continue; + } + break; + } + return idx; +} + function extractJSONObjectFrom(text, start) { if (!text || start < 0 || start >= text.length || text[start] !== '{') { return { ok: false, end: 0 }; @@ -251,26 +643,35 @@ function parseToolCalls(text, toolNames) { if (parsed.length === 0) { return []; } - 
const allowed = new Set((toolNames || []).filter(Boolean)); - const out = []; - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - if (allowed.size > 0 && !allowed.has(tc.name)) { - continue; - } - out.push({ name: tc.name, input: tc.input || {} }); + return filterToolCalls(parsed, toolNames); +} + +function parseStandaloneToolCalls(text, toolNames) { + const trimmed = toStringSafe(text); + if (!trimmed) { + return []; } - if (out.length === 0 && parsed.length > 0) { - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - out.push({ name: tc.name, input: tc.input || {} }); + const candidates = [trimmed]; + if (trimmed.startsWith('```') && trimmed.endsWith('```')) { + const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); + if (m && m[1]) { + candidates.push(toStringSafe(m[1])); } } - return out; + for (const candidate of candidates) { + const c = toStringSafe(candidate); + if (!c) { + continue; + } + if (!c.startsWith('{') && !c.startsWith('[')) { + continue; + } + const parsed = parseToolCallsPayload(c); + if (parsed.length > 0) { + return filterToolCalls(parsed, toolNames); + } + } + return []; } function buildToolCallCandidates(text) { @@ -432,6 +833,33 @@ function parseToolCallInput(v) { return {}; } +function filterToolCalls(parsed, toolNames) { + const allowed = new Set((toolNames || []).filter(Boolean)); + const out = []; + for (const tc of parsed) { + if (!tc || !tc.name) { + continue; + } + if (allowed.size > 0 && !allowed.has(tc.name)) { + continue; + } + out.push({ name: tc.name, input: tc.input || {} }); + } + if (out.length === 0 && parsed.length > 0) { + for (const tc of parsed) { + if (!tc || !tc.name) { + continue; + } + out.push({ name: tc.name, input: tc.input || {} }); + } + } + return out; +} + +function hasMeaningfulText(text) { + return toStringSafe(text) !== ''; +} + function formatOpenAIStreamToolCalls(calls) { if (!Array.isArray(calls) || calls.length === 0) { return []; @@ -473,5 +901,6 @@ 
module.exports = { processToolSieveChunk, flushToolSieve, parseToolCalls, + parseStandaloneToolCalls, formatOpenAIStreamToolCalls, }; diff --git a/api/helpers/stream-tool-sieve.test.js b/api/helpers/stream-tool-sieve.test.js index 47b3100..ad1dc0b 100644 --- a/api/helpers/stream-tool-sieve.test.js +++ b/api/helpers/stream-tool-sieve.test.js @@ -9,6 +9,7 @@ const { processToolSieveChunk, flushToolSieve, parseToolCalls, + parseStandaloneToolCalls, } = require('./stream-tool-sieve'); function runSieve(chunks, toolNames) { @@ -73,6 +74,15 @@ test('parseToolCalls supports fenced json and function.arguments string payload' assert.deepEqual(calls[0].input, { path: 'README.md' }); }); +test('parseStandaloneToolCalls only matches standalone payload and ignores mixed prose', () => { + const mixed = '这里是示例:{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]},请勿执行。'; + const standalone = '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'; + const mixedCalls = parseStandaloneToolCalls(mixed, ['read_file']); + const standaloneCalls = parseStandaloneToolCalls(standalone, ['read_file']); + assert.equal(mixedCalls.length, 0); + assert.equal(standaloneCalls.length, 1); +}); + test('sieve emits tool_calls and does not leak suspicious prefix on late key convergence', () => { const events = runSieve( [ @@ -84,13 +94,14 @@ test('sieve emits tool_calls and does not leak suspicious prefix on late key con ); const leakedText = collectText(events); const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0); - assert.equal(hasToolCall, true); + const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0); + assert.equal(hasToolCall || hasToolDelta, true); assert.equal(leakedText.includes('{'), false); assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); assert.equal(leakedText.includes('后置正文C。'), true); }); 
-test('sieve drops invalid tool json body while preserving surrounding text', () => { +test('sieve keeps embedded invalid tool-like json as normal text to avoid stream stalls', () => { const events = runSieve( [ '前置正文D。', @@ -104,18 +115,18 @@ test('sieve drops invalid tool json body while preserving surrounding text', () assert.equal(hasToolCall, false); assert.equal(leakedText.includes('前置正文D。'), true); assert.equal(leakedText.includes('后置正文E。'), true); - assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), true); }); -test('sieve suppresses incomplete captured tool json on stream finalize', () => { +test('sieve flushes incomplete captured tool json as text on stream finalize', () => { const events = runSieve( ['前置正文F。', '{"tool_calls":[{"name":"read_file"'], ['read_file'], ); const leakedText = collectText(events); assert.equal(leakedText.includes('前置正文F。'), true); - assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); - assert.equal(leakedText.includes('{'), false); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), true); + assert.equal(leakedText.includes('{'), true); }); test('sieve keeps plain text intact in tool mode when no tool call appears', () => { @@ -128,3 +139,29 @@ test('sieve keeps plain text intact in tool mode when no tool call appears', () assert.equal(hasToolCall, false); assert.equal(leakedText, '你好,这是普通文本回复。请继续。'); }); + +test('sieve emits incremental tool_call_deltas for split arguments payload', () => { + const state = createToolSieveState(); + const first = processToolSieveChunk( + state, + '{"tool_calls":[{"name":"read_file","input":{"path":"READ', + ['read_file'], + ); + const second = processToolSieveChunk( + state, + 'ME.MD","mode":"head"}}]}', + ['read_file'], + ); + const tail = flushToolSieve(state, ['read_file']); + const events = [...first, ...second, ...tail]; + const deltaEvents = events.filter((evt) => evt.type === 
'tool_call_deltas'); + assert.equal(deltaEvents.length > 0, true); + const merged = deltaEvents.flatMap((evt) => evt.deltas || []); + const hasName = merged.some((d) => d.name === 'read_file'); + const argsJoined = merged + .map((d) => d.arguments || '') + .join(''); + assert.equal(hasName, true); + assert.equal(argsJoined.includes('"path":"README.MD"'), true); + assert.equal(argsJoined.includes('"mode":"head"'), true); +}); diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index 1602cf6..962e450 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -11,6 +11,7 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/google/uuid" "ds2api/internal/auth" "ds2api/internal/config" @@ -134,7 +135,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re finalThinking := result.Thinking finalText := result.Text - detected := util.ParseToolCalls(finalText, toolNames) + detected := util.ParseStandaloneToolCalls(finalText, toolNames) finishReason := "stop" messageObj := map[string]any{"role": "assistant", "content": finalText} if thinkingEnabled && finalThinking != "" { @@ -188,6 +189,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt bufferToolContent := len(toolNames) > 0 var toolSieve toolStreamSieveState toolCallsEmitted := false + streamToolCallIDs := map[int]string{} initialType := "text" if thinkingEnabled { initialType = "thinking" @@ -220,7 +222,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt finalize := func(finishReason string) { finalThinking := thinking.String() finalText := text.String() - detected := util.ParseToolCalls(finalText, toolNames) + detected := util.ParseStandaloneToolCalls(finalText, toolNames) if len(detected) > 0 && !toolCallsEmitted { finishReason = "tool_calls" delta := map[string]any{ @@ -352,6 +354,21 @@ func (h *Handler) handleStream(w http.ResponseWriter, r 
*http.Request, resp *htt // Keep thinking delta only frame. } for _, evt := range events { + if len(evt.ToolCallDeltas) > 0 { + toolCallsEmitted = true + tcDelta := map[string]any{ + "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs), + } + if !firstChunkSent { + tcDelta["role"] = "assistant" + firstChunkSent = true + } + newChoices = append(newChoices, map[string]any{ + "delta": tcDelta, + "index": 0, + }) + continue + } if len(evt.ToolCalls) > 0 { toolCallsEmitted = true tcDelta := map[string]any{ @@ -441,6 +458,40 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, return messages, names } +func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any { + if len(deltas) == 0 { + return nil + } + out := make([]map[string]any, 0, len(deltas)) + for _, d := range deltas { + if d.Name == "" && d.Arguments == "" { + continue + } + callID, ok := ids[d.Index] + if !ok || callID == "" { + callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + ids[d.Index] = callID + } + item := map[string]any{ + "index": d.Index, + "id": callID, + "type": "function", + } + fn := map[string]any{} + if d.Name != "" { + fn["name"] = d.Name + } + if d.Arguments != "" { + fn["arguments"] = d.Arguments + } + if len(fn) > 0 { + item["function"] = fn + } + out = append(out, item) + } + return out +} + func writeOpenAIError(w http.ResponseWriter, status int, message string) { writeJSON(w, status, map[string]any{ "error": map[string]any{ diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index f9c44dd..30197d7 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -100,6 +100,26 @@ func streamFinishReason(frames []map[string]any) string { return "" } +func streamToolCallArgumentChunks(frames []map[string]any) []string { + out := make([]string, 0, 
4) + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + toolCalls, _ := delta["tool_calls"].([]any) + for _, tc := range toolCalls { + tcm, _ := tc.(map[string]any) + fn, _ := tcm["function"].(map[string]any) + if args, ok := fn["arguments"].(string); ok && args != "" { + out = append(out, args) + } + } + } + } + return out +} + func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( @@ -190,6 +210,37 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { } } +func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"下面是示例:"}`, + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: {"p":"response/content","v":"请勿执行。"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + + h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, false, []string{"search"}) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rec.Code) + } + + out := decodeJSONBody(t, rec.Body.String()) + choices, _ := out["choices"].([]any) + choice, _ := choices[0].(map[string]any) + if choice["finish_reason"] != "stop" { + t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"]) + } + msg, _ := choice["message"].(map[string]any) + if _, ok := msg["tool_calls"]; ok { + t.Fatalf("did not expect tool_calls field for embedded example: %#v", msg["tool_calls"]) + } + content, _ := msg["content"].(string) + if !strings.Contains(content, "示例") || !strings.Contains(content, `"tool_calls"`) { + t.Fatalf("expected embedded example to pass through as text, got %q", content) + } +} + func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) { h 
:= &Handler{} resp := makeSSEHTTPResponse( @@ -391,11 +442,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { if !done { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } - if !streamHasToolCallsDelta(frames) { - t.Fatalf("expected tool_calls delta in mixed stream, body=%s", rec.Body.String()) - } - if streamHasRawToolJSONContent(frames) { - t.Fatalf("raw tool_calls JSON leaked in mixed stream: %s", rec.Body.String()) + if streamHasToolCallsDelta(frames) { + t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String()) } content := strings.Builder{} for _, frame := range frames { @@ -412,8 +460,11 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { if !strings.Contains(got, "前置正文A。") || !strings.Contains(got, "后置正文B。") { t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got) } - if streamFinishReason(frames) != "tool_calls" { - t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + if !strings.Contains(got, `"tool_calls"`) { + t.Fatalf("expected mixed stream to preserve embedded tool_calls example text, got=%q", got) + } + if streamFinishReason(frames) != "stop" { + t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String()) } } @@ -495,16 +546,16 @@ func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) { } } } - got := strings.ToLower(content.String()) - if strings.Contains(got, "tool_calls") { - t.Fatalf("unexpected raw tool_calls leak in content: %q", content.String()) - } - if !strings.Contains(content.String(), "前置正文D。") || !strings.Contains(content.String(), "后置正文E。") { + got := content.String() + if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") { t.Fatalf("expected pre/post plain text to remain, got=%q", content.String()) } + if !strings.Contains(strings.ToLower(got), "tool_calls") { + t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got) 
+ } } -func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.T) { +func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`, @@ -533,7 +584,42 @@ func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing. } } } - if strings.Contains(strings.ToLower(content.String()), "tool_calls") || strings.Contains(content.String(), "{") { - t.Fatalf("unexpected incomplete tool json leak in content: %q", content.String()) + if !strings.Contains(strings.ToLower(content.String()), "tool_calls") || !strings.Contains(content.String(), "{") { + t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String()) + } +} + +func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`, + `data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + if streamHasRawToolJSONContent(frames) { + t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) + } + argChunks := streamToolCallArgumentChunks(frames) + if len(argChunks) < 2 { + t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String()) + } + joined := strings.Join(argChunks, "") + if !strings.Contains(joined, 
`"q":"golang"`) || !strings.Contains(joined, `"page":1`) { + t.Fatalf("unexpected merged arguments stream: %q", joined) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) } } diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve.go index d1a9014..d890314 100644 --- a/internal/adapter/openai/tool_sieve.go +++ b/internal/adapter/openai/tool_sieve.go @@ -7,14 +7,39 @@ import ( ) type toolStreamSieveState struct { - pending strings.Builder - capture strings.Builder - capturing bool + pending strings.Builder + capture strings.Builder + capturing bool + hasMeaningfulText bool + toolNameSent bool + toolName string + toolArgsStart int + toolArgsSent int + toolArgsString bool + toolArgsDone bool } type toolStreamEvent struct { - Content string - ToolCalls []util.ParsedToolCall + Content string + ToolCalls []util.ParsedToolCall + ToolCallDeltas []toolCallDelta +} + +type toolCallDelta struct { + Index int + Name string + Arguments string +} + +const toolSieveCaptureLimit = 8 * 1024 + +func (s *toolStreamSieveState) resetIncrementalToolState() { + s.toolNameSent = false + s.toolName = "" + s.toolArgsStart = -1 + s.toolArgsSent = -1 + s.toolArgsString = false + s.toolArgsDone = false } func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent { @@ -32,13 +57,31 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames state.capture.WriteString(state.pending.String()) state.pending.Reset() } - prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), toolNames) + if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 { + events = append(events, toolStreamEvent{ToolCallDeltas: deltas}) + } + prefix, calls, suffix, ready := consumeToolCapture(state, toolNames) if !ready { + if state.capture.Len() > toolSieveCaptureLimit { + content := state.capture.String() + 
state.capture.Reset() + state.capturing = false + state.resetIncrementalToolState() + if strings.TrimSpace(content) != "" { + state.hasMeaningfulText = true + } + events = append(events, toolStreamEvent{Content: content}) + continue + } break } state.capture.Reset() state.capturing = false + state.resetIncrementalToolState() if prefix != "" { + if strings.TrimSpace(prefix) != "" { + state.hasMeaningfulText = true + } events = append(events, toolStreamEvent{Content: prefix}) } if len(calls) > 0 { @@ -58,11 +101,15 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames if start >= 0 { prefix := pending[:start] if prefix != "" { + if strings.TrimSpace(prefix) != "" { + state.hasMeaningfulText = true + } events = append(events, toolStreamEvent{Content: prefix}) } state.pending.Reset() state.capture.WriteString(pending[start:]) state.capturing = true + state.resetIncrementalToolState() continue } @@ -72,6 +119,9 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames } state.pending.Reset() state.pending.WriteString(hold) + if strings.TrimSpace(safe) != "" { + state.hasMeaningfulText = true + } events = append(events, toolStreamEvent{Content: safe}) } @@ -84,25 +134,42 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea } events := processToolSieveChunk(state, "", toolNames) if state.capturing { - consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state.capture.String(), toolNames) + consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames) if ready { if consumedPrefix != "" { + if strings.TrimSpace(consumedPrefix) != "" { + state.hasMeaningfulText = true + } events = append(events, toolStreamEvent{Content: consumedPrefix}) } if len(consumedCalls) > 0 { events = append(events, toolStreamEvent{ToolCalls: consumedCalls}) } if consumedSuffix != "" { + if strings.TrimSpace(consumedSuffix) != "" { + state.hasMeaningfulText = true + 
} events = append(events, toolStreamEvent{Content: consumedSuffix}) } } else { - // Incomplete captured tool JSON at stream end: suppress raw capture. + content := state.capture.String() + if content != "" { + if strings.TrimSpace(content) != "" { + state.hasMeaningfulText = true + } + events = append(events, toolStreamEvent{Content: content}) + } } state.capture.Reset() state.capturing = false + state.resetIncrementalToolState() } if state.pending.Len() > 0 { - events = append(events, toolStreamEvent{Content: state.pending.String()}) + content := state.pending.String() + if strings.TrimSpace(content) != "" { + state.hasMeaningfulText = true + } + events = append(events, toolStreamEvent{Content: content}) state.pending.Reset() } return events @@ -154,7 +221,8 @@ func findToolSegmentStart(s string) int { return keyIdx } -func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { +func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { + captured := state.capture.String() if captured == "" { return "", nil, "", false } @@ -171,13 +239,25 @@ func consumeToolCapture(captured string, toolNames []string) (prefix string, cal if !ok { return "", nil, "", false } - parsed := util.ParseToolCalls(obj, toolNames) - if len(parsed) == 0 { - // `tool_calls` key exists but strict JSON parse failed. - // Drop the captured object body to avoid leaking raw tool JSON. 
- return captured[:start], nil, captured[end:], true + prefixPart := captured[:start] + suffixPart := captured[end:] + if !state.toolNameSent && (state.hasMeaningfulText || strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "") { + return captured, nil, "", true } - return captured[:start], parsed, captured[end:], true + parsed := util.ParseStandaloneToolCalls(obj, toolNames) + if len(parsed) == 0 { + if state.toolNameSent { + return prefixPart, nil, suffixPart, true + } + return captured, nil, "", true + } + if state.toolNameSent { + if len(parsed) > 1 { + return prefixPart, parsed[1:], suffixPart, true + } + return prefixPart, nil, suffixPart, true + } + return prefixPart, parsed, suffixPart, true } func extractJSONObjectFrom(text string, start int) (string, int, bool) { @@ -221,3 +301,320 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) { } return "", 0, false } + +func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta { + captured := state.capture.String() + if captured == "" || state.hasMeaningfulText { + return nil + } + lower := strings.ToLower(captured) + keyIdx := strings.Index(lower, "tool_calls") + if keyIdx < 0 { + return nil + } + start := strings.LastIndex(captured[:keyIdx], "{") + if start < 0 || strings.TrimSpace(captured[:start]) != "" { + return nil + } + callStart, ok := findFirstToolCallObjectStart(captured, keyIdx) + if !ok { + return nil + } + deltas := make([]toolCallDelta, 0, 2) + if state.toolName == "" { + name, ok := extractToolCallName(captured, callStart) + if !ok || name == "" { + return nil + } + state.toolName = name + } + if state.toolArgsStart < 0 { + argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart) + if ok { + state.toolArgsString = stringMode + if stringMode { + state.toolArgsStart = argsStart + 1 + } else { + state.toolArgsStart = argsStart + } + state.toolArgsSent = state.toolArgsStart + } + } + if !state.toolNameSent { + if 
state.toolArgsStart < 0 { + return nil + } + state.toolNameSent = true + deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName}) + } + if state.toolArgsStart < 0 || state.toolArgsDone { + return deltas + } + end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString) + if !ok { + return deltas + } + if end > state.toolArgsSent { + deltas = append(deltas, toolCallDelta{ + Index: 0, + Arguments: captured[state.toolArgsSent:end], + }) + state.toolArgsSent = end + } + if complete { + state.toolArgsDone = true + } + return deltas +} + +func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) { + arrStart, ok := findToolCallsArrayStart(text, keyIdx) + if !ok { + return -1, false + } + i := skipSpaces(text, arrStart+1) + if i >= len(text) || text[i] != '{' { + return -1, false + } + return i, true +} + +func findToolCallsArrayStart(text string, keyIdx int) (int, bool) { + i := keyIdx + len("tool_calls") + for i < len(text) && text[i] != ':' { + i++ + } + if i >= len(text) { + return -1, false + } + i = skipSpaces(text, i+1) + if i >= len(text) || text[i] != '[' { + return -1, false + } + return i, true +} + +func extractToolCallName(text string, callStart int) (string, bool) { + valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"}) + if !ok || valueStart >= len(text) || text[valueStart] != '"' { + fnStart, fnOK := findFunctionObjectStart(text, callStart) + if !fnOK { + return "", false + } + valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"}) + if !ok || valueStart >= len(text) || text[valueStart] != '"' { + return "", false + } + } + name, _, ok := parseJSONStringLiteral(text, valueStart) + if !ok { + return "", false + } + return name, true +} + +func findToolCallArgsStart(text string, callStart int) (int, bool, bool) { + keys := []string{"input", "arguments", "args", "parameters", "params"} + valueStart, ok := findObjectFieldValueStart(text, callStart, 
keys) + if !ok { + fnStart, fnOK := findFunctionObjectStart(text, callStart) + if !fnOK { + return -1, false, false + } + valueStart, ok = findObjectFieldValueStart(text, fnStart, keys) + if !ok { + return -1, false, false + } + } + if valueStart >= len(text) { + return -1, false, false + } + ch := text[valueStart] + if ch == '{' || ch == '[' { + return valueStart, false, true + } + if ch == '"' { + return valueStart, true, true + } + return -1, false, false +} + +func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) { + if start < 0 || start > len(text) { + return 0, false, false + } + if stringMode { + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + return i, true, true + } + } + return len(text), false, true + } + if start >= len(text) { + return start, false, false + } + if text[start] != '{' && text[start] != '[' { + return 0, false, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' || ch == '[' { + depth++ + continue + } + if ch == '}' || ch == ']' { + depth-- + if depth == 0 { + return i + 1, true, true + } + } + } + return len(text), false, true +} + +func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) { + if objStart < 0 || objStart >= len(text) || text[objStart] != '{' { + return 0, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := objStart; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + 
continue + } + if ch == '"' || ch == '\'' { + if depth == 1 { + key, end, ok := parseJSONStringLiteral(text, i) + if !ok { + return 0, false + } + j := skipSpaces(text, end) + if j >= len(text) || text[j] != ':' { + i = end - 1 + continue + } + j = skipSpaces(text, j+1) + if j >= len(text) { + return 0, false + } + if containsKey(keys, key) { + return j, true + } + i = j - 1 + continue + } + quote = ch + continue + } + if ch == '{' { + depth++ + continue + } + if ch == '}' { + depth-- + if depth == 0 { + break + } + } + } + return 0, false +} + +func findFunctionObjectStart(text string, callStart int) (int, bool) { + valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"}) + if !ok || valueStart >= len(text) || text[valueStart] != '{' { + return -1, false + } + return valueStart, true +} + +func parseJSONStringLiteral(text string, start int) (string, int, bool) { + if start < 0 || start >= len(text) || text[start] != '"' { + return "", 0, false + } + var b strings.Builder + escaped := false + for i := start + 1; i < len(text); i++ { + ch := text[i] + if escaped { + b.WriteByte(ch) + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + return b.String(), i + 1, true + } + b.WriteByte(ch) + } + return "", 0, false +} + +func containsKey(keys []string, value string) bool { + for _, k := range keys { + if k == value { + return true + } + } + return false +} + +func skipSpaces(text string, i int) int { + for i < len(text) { + switch text[i] { + case ' ', '\t', '\n', '\r': + i++ + default: + return i + } + } + return i +} diff --git a/internal/util/toolcalls.go b/internal/util/toolcalls.go index 9b9d4e6..4760546 100644 --- a/internal/util/toolcalls.go +++ b/internal/util/toolcalls.go @@ -33,6 +33,36 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall { return nil } + return filterToolCalls(parsed, availableToolNames) +} + +func ParseStandaloneToolCalls(text string, 
availableToolNames []string) []ParsedToolCall { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return nil + } + candidates := []string{trimmed} + if strings.HasPrefix(trimmed, "```") && strings.HasSuffix(trimmed, "```") { + if m := fencedJSONPattern.FindStringSubmatch(trimmed); len(m) >= 2 { + candidates = append(candidates, strings.TrimSpace(m[1])) + } + } + for _, candidate := range candidates { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + continue + } + if !strings.HasPrefix(candidate, "{") && !strings.HasPrefix(candidate, "[") { + continue + } + if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 { + return filterToolCalls(parsed, availableToolNames) + } + } + return nil +} + +func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []ParsedToolCall { allowed := map[string]struct{}{} for _, name := range availableToolNames { allowed[name] = struct{}{} diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index 8c44320..8a29a18 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -62,3 +62,16 @@ func TestFormatOpenAIToolCalls(t *testing.T) { t.Fatalf("unexpected function name: %#v", fn) } } + +func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) { + mixed := `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 0 { + t.Fatalf("expected standalone parser to ignore mixed prose, got %#v", calls) + } + + standalone := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + calls := ParseStandaloneToolCalls(standalone, []string{"search"}) + if len(calls) != 1 { + t.Fatalf("expected standalone parser to match, got %#v", calls) + } +} From deec72416e8eb3c28b41ecc8ba6ac4f07cc35907 Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 16:51:30 +0800 Subject: [PATCH 05/52] test: Introduce comprehensive edge case tests across multiple 
modules and refine tool call and OpenAI handler logic. --- api/helpers/stream-tool-sieve.js | 84 +++++++++++++------ api/helpers/stream-tool-sieve.test.js | 18 ++++ .../adapter/openai/handler_toolcall_test.go | 77 ++++++++++++++++- internal/adapter/openai/tool_sieve.go | 84 +++++++++++++------ internal/util/toolcalls.go | 27 ++++++ internal/util/toolcalls_test.go | 7 ++ 6 files changed, 242 insertions(+), 55 deletions(-) diff --git a/api/helpers/stream-tool-sieve.js b/api/helpers/stream-tool-sieve.js index 8b586aa..4a713e5 100644 --- a/api/helpers/stream-tool-sieve.js +++ b/api/helpers/stream-tool-sieve.js @@ -3,6 +3,7 @@ const crypto = require('crypto'); const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s; const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024; +const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256; function extractToolNames(tools) { if (!Array.isArray(tools) || tools.length === 0) { @@ -28,6 +29,7 @@ function createToolSieveState() { capture: '', capturing: false, hasMeaningfulText: false, + recentTextTail: '', toolNameSent: false, toolName: '', toolArgsStart: -1, @@ -68,9 +70,7 @@ function processToolSieveChunk(state, chunk, toolNames) { const consumed = consumeToolCapture(state, toolNames); if (!consumed.ready) { if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) { - if (hasMeaningfulText(state.capture)) { - state.hasMeaningfulText = true; - } + noteText(state, state.capture); events.push({ type: 'text', text: state.capture }); state.capture = ''; state.capturing = false; @@ -83,9 +83,7 @@ function processToolSieveChunk(state, chunk, toolNames) { state.capturing = false; resetIncrementalToolState(state); if (consumed.prefix) { - if (hasMeaningfulText(consumed.prefix)) { - state.hasMeaningfulText = true; - } + noteText(state, consumed.prefix); events.push({ type: 'text', text: consumed.prefix }); } if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { @@ -105,9 +103,7 @@ function processToolSieveChunk(state, chunk, toolNames) { if 
(start >= 0) { const prefix = state.pending.slice(0, start); if (prefix) { - if (hasMeaningfulText(prefix)) { - state.hasMeaningfulText = true; - } + noteText(state, prefix); events.push({ type: 'text', text: prefix }); } state.capture = state.pending.slice(start); @@ -122,9 +118,7 @@ function processToolSieveChunk(state, chunk, toolNames) { break; } state.pending = hold; - if (hasMeaningfulText(safe)) { - state.hasMeaningfulText = true; - } + noteText(state, safe); events.push({ type: 'text', text: safe }); } return events; @@ -139,24 +133,18 @@ function flushToolSieve(state, toolNames) { const consumed = consumeToolCapture(state, toolNames); if (consumed.ready) { if (consumed.prefix) { - if (hasMeaningfulText(consumed.prefix)) { - state.hasMeaningfulText = true; - } + noteText(state, consumed.prefix); events.push({ type: 'text', text: consumed.prefix }); } if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { events.push({ type: 'tool_calls', calls: consumed.calls }); } if (consumed.suffix) { - if (hasMeaningfulText(consumed.suffix)) { - state.hasMeaningfulText = true; - } + noteText(state, consumed.suffix); events.push({ type: 'text', text: consumed.suffix }); } } else if (state.capture) { - if (hasMeaningfulText(state.capture)) { - state.hasMeaningfulText = true; - } + noteText(state, state.capture); events.push({ type: 'text', text: state.capture }); } state.capture = ''; @@ -164,9 +152,7 @@ function flushToolSieve(state, toolNames) { resetIncrementalToolState(state); } if (state.pending) { - if (hasMeaningfulText(state.pending)) { - state.hasMeaningfulText = true; - } + noteText(state, state.pending); events.push({ type: 'text', text: state.pending }); state.pending = ''; } @@ -234,7 +220,7 @@ function consumeToolCapture(state, toolNames) { } const prefixPart = captured.slice(0, start); const suffixPart = captured.slice(obj.end); - if (!state.toolNameSent && (state.hasMeaningfulText || hasMeaningfulText(prefixPart) || 
hasMeaningfulText(suffixPart))) { + if (!state.toolNameSent && (hasMeaningfulText(prefixPart) || hasMeaningfulText(suffixPart) || looksLikeToolExampleContext(state.recentTextTail))) { return { ready: true, prefix: captured, @@ -285,7 +271,10 @@ function consumeToolCapture(state, toolNames) { function buildIncrementalToolDeltas(state) { const captured = state.capture || ''; - if (!captured || state.hasMeaningfulText) { + if (!captured) { + return []; + } + if (looksLikeToolExampleContext(state.recentTextTail)) { return []; } const lower = captured.toLowerCase(); @@ -651,6 +640,9 @@ function parseStandaloneToolCalls(text, toolNames) { if (!trimmed) { return []; } + if (looksLikeToolExampleContext(trimmed)) { + return []; + } const candidates = [trimmed]; if (trimmed.startsWith('```') && trimmed.endsWith('```')) { const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); @@ -856,6 +848,46 @@ function filterToolCalls(parsed, toolNames) { return out; } +function noteText(state, text) { + if (!state || !hasMeaningfulText(text)) { + return; + } + state.hasMeaningfulText = true; + state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT); +} + +function appendTail(prev, next, max) { + const left = typeof prev === 'string' ? prev : ''; + const right = typeof next === 'string' ? 
next : ''; + if (!Number.isFinite(max) || max <= 0) { + return ''; + } + const combined = left + right; + if (combined.length <= max) { + return combined; + } + return combined.slice(combined.length - max); +} + +function looksLikeToolExampleContext(text) { + const t = toStringSafe(text).toLowerCase(); + if (!t) { + return false; + } + const cues = [ + '示例', + '例子', + 'for example', + 'example', + 'demo', + '请勿执行', + '不要执行', + 'do not execute', + '```', + ]; + return cues.some((cue) => t.includes(cue)); +} + function hasMeaningfulText(text) { return toStringSafe(text) !== ''; } diff --git a/api/helpers/stream-tool-sieve.test.js b/api/helpers/stream-tool-sieve.test.js index ad1dc0b..c085436 100644 --- a/api/helpers/stream-tool-sieve.test.js +++ b/api/helpers/stream-tool-sieve.test.js @@ -83,6 +83,12 @@ test('parseStandaloneToolCalls only matches standalone payload and ignores mixed assert.equal(standaloneCalls.length, 1); }); +test('parseStandaloneToolCalls ignores fenced code block tool_call examples', () => { + const fenced = ['```json', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}', '```'].join('\n'); + const calls = parseStandaloneToolCalls(fenced, ['read_file']); + assert.equal(calls.length, 0); +}); + test('sieve emits tool_calls and does not leak suspicious prefix on late key convergence', () => { const events = runSieve( [ @@ -165,3 +171,15 @@ test('sieve emits incremental tool_call_deltas for split arguments payload', () assert.equal(argsJoined.includes('"path":"README.MD"'), true); assert.equal(argsJoined.includes('"mode":"head"'), true); }); + +test('sieve still intercepts tool call after leading plain text without suffix', () => { + const events = runSieve( + ['我将调用工具。', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'], + ['read_file'], + ); + const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0)); + const 
leakedText = collectText(events); + assert.equal(hasTool, true); + assert.equal(leakedText.includes('我将调用工具。'), true); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); +}); diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index 30197d7..8c1435d 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -241,6 +241,35 @@ func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) { } } +func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}", + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + + h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, false, []string{"search"}) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rec.Code) + } + + out := decodeJSONBody(t, rec.Body.String()) + choices, _ := out["choices"].([]any) + choice, _ := choices[0].(map[string]any) + if choice["finish_reason"] != "stop" { + t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"]) + } + msg, _ := choice["message"].(map[string]any) + if _, ok := msg["tool_calls"]; ok { + t.Fatalf("did not expect tool_calls field for fenced example: %#v", msg["tool_calls"]) + } + content, _ := msg["content"].(string) + if !strings.Contains(content, "```json") || !strings.Contains(content, `"tool_calls"`) { + t.Fatalf("expected fenced tool example to pass through as text, got %q", content) + } +} + func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( @@ -428,9 +457,9 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) { func 
TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( - `data: {"p":"response/content","v":"前置正文A。"}`, + `data: {"p":"response/content","v":"下面是示例:"}`, `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, - `data: {"p":"response/content","v":"后置正文B。"}`, + `data: {"p":"response/content","v":"请勿执行。"}`, `data: [DONE]`, ) rec := httptest.NewRecorder() @@ -457,7 +486,7 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { } } got := content.String() - if !strings.Contains(got, "前置正文A。") || !strings.Contains(got, "后置正文B。") { + if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") { t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got) } if !strings.Contains(got, `"tool_calls"`) { @@ -468,6 +497,48 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { } } +func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"我将调用工具。"}`, + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid7b", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + content := strings.Builder{} + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + if c, ok := delta["content"].(string); ok { + 
content.WriteString(c) + } + } + } + got := content.String() + if !strings.Contains(got, "我将调用工具。") { + t.Fatalf("expected leading text to keep streaming, got=%q", got) + } + if strings.Contains(strings.ToLower(got), "tool_calls") { + t.Fatalf("unexpected raw tool json leak, got=%q", got) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + } +} + func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) { h := &Handler{} spaces := strings.Repeat(" ", 200) diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve.go index d890314..e5d6b77 100644 --- a/internal/adapter/openai/tool_sieve.go +++ b/internal/adapter/openai/tool_sieve.go @@ -11,6 +11,7 @@ type toolStreamSieveState struct { capture strings.Builder capturing bool hasMeaningfulText bool + recentTextTail string toolNameSent bool toolName string toolArgsStart int @@ -32,6 +33,7 @@ type toolCallDelta struct { } const toolSieveCaptureLimit = 8 * 1024 +const toolSieveContextTailLimit = 256 func (s *toolStreamSieveState) resetIncrementalToolState() { s.toolNameSent = false @@ -67,9 +69,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames state.capture.Reset() state.capturing = false state.resetIncrementalToolState() - if strings.TrimSpace(content) != "" { - state.hasMeaningfulText = true - } + state.noteText(content) events = append(events, toolStreamEvent{Content: content}) continue } @@ -79,9 +79,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames state.capturing = false state.resetIncrementalToolState() if prefix != "" { - if strings.TrimSpace(prefix) != "" { - state.hasMeaningfulText = true - } + state.noteText(prefix) events = append(events, toolStreamEvent{Content: prefix}) } if len(calls) > 0 { @@ -101,9 +99,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames if start >= 0 { prefix := 
pending[:start] if prefix != "" { - if strings.TrimSpace(prefix) != "" { - state.hasMeaningfulText = true - } + state.noteText(prefix) events = append(events, toolStreamEvent{Content: prefix}) } state.pending.Reset() @@ -119,9 +115,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames } state.pending.Reset() state.pending.WriteString(hold) - if strings.TrimSpace(safe) != "" { - state.hasMeaningfulText = true - } + state.noteText(safe) events = append(events, toolStreamEvent{Content: safe}) } @@ -137,26 +131,20 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames) if ready { if consumedPrefix != "" { - if strings.TrimSpace(consumedPrefix) != "" { - state.hasMeaningfulText = true - } + state.noteText(consumedPrefix) events = append(events, toolStreamEvent{Content: consumedPrefix}) } if len(consumedCalls) > 0 { events = append(events, toolStreamEvent{ToolCalls: consumedCalls}) } if consumedSuffix != "" { - if strings.TrimSpace(consumedSuffix) != "" { - state.hasMeaningfulText = true - } + state.noteText(consumedSuffix) events = append(events, toolStreamEvent{Content: consumedSuffix}) } } else { content := state.capture.String() if content != "" { - if strings.TrimSpace(content) != "" { - state.hasMeaningfulText = true - } + state.noteText(content) events = append(events, toolStreamEvent{Content: content}) } } @@ -166,9 +154,7 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea } if state.pending.Len() > 0 { content := state.pending.String() - if strings.TrimSpace(content) != "" { - state.hasMeaningfulText = true - } + state.noteText(content) events = append(events, toolStreamEvent{Content: content}) state.pending.Reset() } @@ -241,7 +227,7 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix } prefixPart := captured[:start] suffixPart := captured[end:] - 
if !state.toolNameSent && (state.hasMeaningfulText || strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "") { + if !state.toolNameSent && (strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "" || looksLikeToolExampleContext(state.recentTextTail)) { return captured, nil, "", true } parsed := util.ParseStandaloneToolCalls(obj, toolNames) @@ -304,7 +290,10 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) { func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta { captured := state.capture.String() - if captured == "" || state.hasMeaningfulText { + if captured == "" { + return nil + } + if looksLikeToolExampleContext(state.recentTextTail) { return nil } lower := strings.ToLower(captured) @@ -618,3 +607,46 @@ func skipSpaces(text string, i int) int { } return i } + +func (s *toolStreamSieveState) noteText(content string) { + if strings.TrimSpace(content) == "" { + return + } + s.hasMeaningfulText = true + s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit) +} + +func appendTail(prev, next string, max int) string { + if max <= 0 { + return "" + } + combined := prev + next + if len(combined) <= max { + return combined + } + return combined[len(combined)-max:] +} + +func looksLikeToolExampleContext(text string) bool { + t := strings.ToLower(strings.TrimSpace(text)) + if t == "" { + return false + } + cues := []string{ + "示例", + "例子", + "for example", + "example", + "demo", + "请勿执行", + "不要执行", + "do not execute", + "```", + } + for _, cue := range cues { + if strings.Contains(t, cue) { + return true + } + } + return false +} diff --git a/internal/util/toolcalls.go b/internal/util/toolcalls.go index 4760546..decb96e 100644 --- a/internal/util/toolcalls.go +++ b/internal/util/toolcalls.go @@ -41,6 +41,9 @@ func ParseStandaloneToolCalls(text string, availableToolNames []string) []Parsed if trimmed == "" { return nil } + if 
looksLikeToolExampleContext(trimmed) { + return nil + } candidates := []string{trimmed} if strings.HasPrefix(trimmed, "```") && strings.HasSuffix(trimmed, "```") { if m := fencedJSONPattern.FindStringSubmatch(trimmed); len(m) >= 2 { @@ -313,6 +316,30 @@ func extractJSONObject(text string, start int) (string, int, bool) { return "", 0, false } +func looksLikeToolExampleContext(text string) bool { + t := strings.ToLower(strings.TrimSpace(text)) + if t == "" { + return false + } + cues := []string{ + "```", + "示例", + "例子", + "for example", + "example", + "demo", + "请勿执行", + "不要执行", + "do not execute", + } + for _, cue := range cues { + if strings.Contains(t, cue) { + return true + } + } + return false +} + func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any { out := make([]map[string]any, 0, len(calls)) for _, c := range calls { diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index 8a29a18..509299c 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -75,3 +75,10 @@ func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) { t.Fatalf("expected standalone parser to match, got %#v", calls) } } + +func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) { + fenced := "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```" + if calls := ParseStandaloneToolCalls(fenced, []string{"search"}); len(calls) != 0 { + t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls) + } +} From f2b10992cc0ba721e40edcd58b43c57c4822ceff Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 16:52:16 +0800 Subject: [PATCH 06/52] test: Introduce comprehensive edge case tests for various internal packages including SSE, Claude, Auth, Account, Config, Deepseek, Admin, and Util. 
--- .gitignore | 3 + internal/account/pool_edge_test.go | 249 +++++++ internal/adapter/claude/handler_util_test.go | 348 ++++++++++ internal/adapter/openai/handler.go | 4 +- .../adapter/openai/handler_toolcall_test.go | 10 +- internal/admin/helpers_edge_test.go | 240 +++++++ internal/auth/auth_edge_test.go | 375 +++++++++++ internal/config/config_edge_test.go | 445 ++++++++++++ internal/deepseek/deepseek_edge_test.go | 165 +++++ internal/sse/consumer_edge_test.go | 140 ++++ internal/sse/line_edge_test.go | 70 ++ internal/sse/parser_edge_test.go | 631 ++++++++++++++++++ internal/sse/stream_edge_test.go | 177 +++++ internal/util/util_edge_test.go | 441 ++++++++++++ 14 files changed, 3291 insertions(+), 7 deletions(-) create mode 100644 internal/account/pool_edge_test.go create mode 100644 internal/adapter/claude/handler_util_test.go create mode 100644 internal/admin/helpers_edge_test.go create mode 100644 internal/auth/auth_edge_test.go create mode 100644 internal/config/config_edge_test.go create mode 100644 internal/deepseek/deepseek_edge_test.go create mode 100644 internal/sse/consumer_edge_test.go create mode 100644 internal/sse/line_edge_test.go create mode 100644 internal/sse/parser_edge_test.go create mode 100644 internal/sse/stream_edge_test.go create mode 100644 internal/util/util_edge_test.go diff --git a/.gitignore b/.gitignore index 5f776e2..422c203 100644 --- a/.gitignore +++ b/.gitignore @@ -81,6 +81,9 @@ ds2api-tests htmlcov/ .pytest_cache/ .tox/ +*.coverprofile +coverage*.out +cover/ # Misc *.pyc diff --git a/internal/account/pool_edge_test.go b/internal/account/pool_edge_test.go new file mode 100644 index 0000000..6e90823 --- /dev/null +++ b/internal/account/pool_edge_test.go @@ -0,0 +1,249 @@ +package account + +import ( + "context" + "sync" + "testing" + "time" + + "ds2api/internal/config" +) + +// ─── Pool edge cases ───────────────────────────────────────────────── + +func TestPoolEmptyNoAccounts(t *testing.T) { + 
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "2") + t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "") + t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "") + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + pool := NewPool(config.LoadStore()) + if _, ok := pool.Acquire("", nil); ok { + t.Fatal("expected acquire to fail with no accounts") + } + status := pool.Status() + if total, ok := status["total"].(int); !ok || total != 0 { + t.Fatalf("unexpected total: %#v", status["total"]) + } +} + +func TestPoolReleaseNonExistentAccount(t *testing.T) { + pool := newPoolForTest(t, "2") + pool.Release("nonexistent@example.com") // should not panic +} + +func TestPoolReleaseAlreadyReleased(t *testing.T) { + pool := newPoolForTest(t, "2") + acc, ok := pool.Acquire("", nil) + if !ok { + t.Fatal("expected acquire success") + } + pool.Release(acc.Identifier()) + pool.Release(acc.Identifier()) // double release should not panic +} + +func TestPoolAcquireTargetNotFound(t *testing.T) { + pool := newPoolForTest(t, "2") + if _, ok := pool.Acquire("nonexistent@example.com", nil); ok { + t.Fatal("expected acquire to fail for non-existent target") + } +} + +func TestPoolAcquireWithExclusionList(t *testing.T) { + pool := newPoolForTest(t, "2") + acc, ok := pool.Acquire("", map[string]bool{"acc1@example.com": true}) + if !ok { + t.Fatal("expected acquire success with exclusion") + } + if acc.Identifier() != "acc2@example.com" { + t.Fatalf("expected acc2 when acc1 excluded, got %q", acc.Identifier()) + } + pool.Release(acc.Identifier()) +} + +func TestPoolAcquireAllExcluded(t *testing.T) { + pool := newPoolForTest(t, "2") + if _, ok := pool.Acquire("", map[string]bool{ + "acc1@example.com": true, + "acc2@example.com": true, + }); ok { + t.Fatal("expected acquire to fail when all accounts excluded") + } +} + +func TestPoolStatusFields(t *testing.T) { + pool := newPoolForTest(t, "2") + status := pool.Status() + + // Check all expected fields are present + 
for _, key := range []string{"total", "available", "max_inflight_per_account", "recommended_concurrency", "available_accounts", "in_use_accounts", "waiting", "max_queue_size"} { + if _, ok := status[key]; !ok { + t.Fatalf("missing status field: %s", key) + } + } +} + +func TestPoolStatusAccountDetails(t *testing.T) { + pool := newPoolForTest(t, "2") + acc, _ := pool.Acquire("acc1@example.com", nil) + + status := pool.Status() + inUseAccounts, ok := status["in_use_accounts"].([]string) + if !ok { + t.Fatalf("unexpected in_use_accounts type: %T", status["in_use_accounts"]) + } + found := false + for _, id := range inUseAccounts { + if id == "acc1@example.com" { + found = true + break + } + } + if !found { + t.Fatalf("expected acc1 in in_use_accounts, got %v", inUseAccounts) + } + if status["in_use"] != 1 { + t.Fatalf("expected 1 in_use, got %v", status["in_use"]) + } + + pool.Release(acc.Identifier()) +} + +func TestPoolAcquireWaitContextCancelled(t *testing.T) { + pool := newSingleAccountPoolForTest(t, "1") + // Exhaust the pool + first, ok := pool.Acquire("", nil) + if !ok { + t.Fatal("expected first acquire to succeed") + } + + ctx, cancel := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + wg.Add(1) + var waitOK bool + go func() { + defer wg.Done() + _, waitOK = pool.AcquireWait(ctx, "", nil) + }() + + // Wait until queued + waitForWaitingCount(t, pool, 1) + + // Cancel context + cancel() + + wg.Wait() + if waitOK { + t.Fatal("expected acquire to fail after context cancellation") + } + + pool.Release(first.Identifier()) +} + +func TestPoolAcquireWaitTargetAccount(t *testing.T) { + pool := newPoolForTest(t, "1") + // Exhaust acc1 + acc1, ok := pool.Acquire("acc1@example.com", nil) + if !ok { + t.Fatal("expected acquire acc1 success") + } + + // Acquire acc2 directly (should succeed since acc2 is free) + ctx := context.Background() + acc2, ok := pool.AcquireWait(ctx, "acc2@example.com", nil) + if !ok { + t.Fatal("expected acquire acc2 success 
via AcquireWait") + } + if acc2.Identifier() != "acc2@example.com" { + t.Fatalf("expected acc2, got %q", acc2.Identifier()) + } + + pool.Release(acc1.Identifier()) + pool.Release(acc2.Identifier()) +} + +func TestPoolMaxQueueSizeOverride(t *testing.T) { + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1") + t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "5") + t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "") + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`) + pool := NewPool(config.LoadStore()) + status := pool.Status() + if got, ok := status["max_queue_size"].(int); !ok || got != 5 { + t.Fatalf("expected max_queue_size=5, got %#v", status["max_queue_size"]) + } +} + +func TestPoolQueueSizeAliasEnv(t *testing.T) { + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1") + t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "") + t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "7") + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`) + pool := NewPool(config.LoadStore()) + status := pool.Status() + if got, ok := status["max_queue_size"].(int); !ok || got != 7 { + t.Fatalf("expected max_queue_size=7, got %#v", status["max_queue_size"]) + } +} + +func TestPoolMultipleAcquireReleaseCycles(t *testing.T) { + pool := newSingleAccountPoolForTest(t, "1") + for i := 0; i < 10; i++ { + acc, ok := pool.Acquire("", nil) + if !ok { + t.Fatalf("acquire failed at cycle %d", i) + } + pool.Release(acc.Identifier()) + } +} + +func TestPoolConcurrentAcquireWait(t *testing.T) { + pool := newSingleAccountPoolForTest(t, "1") + first, ok := pool.Acquire("", nil) + if !ok { + t.Fatal("expected first acquire success") + } + + const waiters = 3 + results := make(chan bool, waiters) + + for i := 0; i < waiters; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, ok := pool.AcquireWait(ctx, 
"", nil) + results <- ok + }() + } + + // Wait for all to be queued (only 1 can queue) + time.Sleep(50 * time.Millisecond) + + // Release and allow queued requests to proceed + pool.Release(first.Identifier()) + + successCount := 0 + timeoutCount := 0 + for i := 0; i < waiters; i++ { + select { + case ok := <-results: + if ok { + successCount++ + // Release for next waiter + pool.Release("acc1@example.com") + } else { + timeoutCount++ + } + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for results") + } + } + + // At least 1 should succeed; 2 may fail due to queue limit + if successCount < 1 { + t.Fatalf("expected at least 1 success, got success=%d timeout=%d", successCount, timeoutCount) + } +} diff --git a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go new file mode 100644 index 0000000..73d2fab --- /dev/null +++ b/internal/adapter/claude/handler_util_test.go @@ -0,0 +1,348 @@ +package claude + +import ( + "testing" +) + +// ─── normalizeClaudeMessages ───────────────────────────────────────── + +func TestNormalizeClaudeMessagesSimpleString(t *testing.T) { + msgs := []any{ + map[string]any{"role": "user", "content": "Hello"}, + } + got := normalizeClaudeMessages(msgs) + if len(got) != 1 { + t.Fatalf("expected 1 message, got %d", len(got)) + } + m := got[0].(map[string]any) + if m["content"] != "Hello" { + t.Fatalf("expected 'Hello', got %v", m["content"]) + } +} + +func TestNormalizeClaudeMessagesArrayContent(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "line1"}, + map[string]any{"type": "text", "text": "line2"}, + }, + }, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["content"] != "line1\nline2" { + t.Fatalf("expected joined text, got %q", m["content"]) + } +} + +func TestNormalizeClaudeMessagesToolResult(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + 
"content": []any{ + map[string]any{"type": "tool_result", "content": "tool output"}, + }, + }, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["content"] != "tool output" { + t.Fatalf("expected 'tool output', got %q", m["content"]) + } +} + +func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) { + msgs := []any{"not a map", 42} + got := normalizeClaudeMessages(msgs) + if len(got) != 0 { + t.Fatalf("expected 0 messages for non-map items, got %d", len(got)) + } +} + +func TestNormalizeClaudeMessagesEmpty(t *testing.T) { + got := normalizeClaudeMessages(nil) + if len(got) != 0 { + t.Fatalf("expected 0, got %d", len(got)) + } +} + +func TestNormalizeClaudeMessagesPreservesRole(t *testing.T) { + msgs := []any{ + map[string]any{"role": "assistant", "content": "response"}, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["role"] != "assistant" { + t.Fatalf("expected 'assistant', got %q", m["role"]) + } +} + +func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "Hello"}, + map[string]any{"type": "image", "source": "data:..."}, + map[string]any{"type": "text", "text": "World"}, + }, + }, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["content"] != "Hello\nWorld" { + t.Fatalf("expected only text parts joined, got %q", m["content"]) + } +} + +// ─── buildClaudeToolPrompt ─────────────────────────────────────────── + +func TestBuildClaudeToolPromptSingleTool(t *testing.T) { + tools := []any{ + map[string]any{ + "name": "search", + "description": "Search the web", + "input_schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + }, + }, + }, + } + prompt := buildClaudeToolPrompt(tools) + if prompt == "" { + t.Fatal("expected non-empty prompt") + } + // Should contain tool name and 
description + if !containsStr(prompt, "search") { + t.Fatalf("expected 'search' in prompt") + } + if !containsStr(prompt, "Search the web") { + t.Fatalf("expected description in prompt") + } + if !containsStr(prompt, "tool_calls") { + t.Fatalf("expected tool_calls instruction in prompt") + } +} + +func TestBuildClaudeToolPromptMultipleTools(t *testing.T) { + tools := []any{ + map[string]any{"name": "tool1", "description": "desc1"}, + map[string]any{"name": "tool2", "description": "desc2"}, + } + prompt := buildClaudeToolPrompt(tools) + if !containsStr(prompt, "tool1") || !containsStr(prompt, "tool2") { + t.Fatalf("expected both tools in prompt") + } +} + +func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) { + tools := []any{"not a map"} + prompt := buildClaudeToolPrompt(tools) + if prompt == "" { + t.Fatal("expected non-empty prompt even with invalid tools") + } + // Should still contain the intro and instruction + if !containsStr(prompt, "You are Claude") { + t.Fatalf("expected intro in prompt") + } +} + +// ─── hasSystemMessage ──────────────────────────────────────────────── + +func TestHasSystemMessageTrue(t *testing.T) { + msgs := []any{ + map[string]any{"role": "system", "content": "You are a helper"}, + map[string]any{"role": "user", "content": "Hi"}, + } + if !hasSystemMessage(msgs) { + t.Fatal("expected true") + } +} + +func TestHasSystemMessageFalse(t *testing.T) { + msgs := []any{ + map[string]any{"role": "user", "content": "Hi"}, + map[string]any{"role": "assistant", "content": "Hello"}, + } + if hasSystemMessage(msgs) { + t.Fatal("expected false") + } +} + +func TestHasSystemMessageEmpty(t *testing.T) { + if hasSystemMessage(nil) { + t.Fatal("expected false for nil") + } +} + +func TestHasSystemMessageNonMap(t *testing.T) { + msgs := []any{"not a map"} + if hasSystemMessage(msgs) { + t.Fatal("expected false for non-map") + } +} + +// ─── extractClaudeToolNames ────────────────────────────────────────── + +func TestExtractClaudeToolNamesSingle(t 
*testing.T) { + tools := []any{ + map[string]any{"name": "search"}, + } + names := extractClaudeToolNames(tools) + if len(names) != 1 || names[0] != "search" { + t.Fatalf("expected [search], got %v", names) + } +} + +func TestExtractClaudeToolNamesMultiple(t *testing.T) { + tools := []any{ + map[string]any{"name": "search"}, + map[string]any{"name": "calculate"}, + } + names := extractClaudeToolNames(tools) + if len(names) != 2 { + t.Fatalf("expected 2 names, got %v", names) + } +} + +func TestExtractClaudeToolNamesSkipsEmptyName(t *testing.T) { + tools := []any{ + map[string]any{"name": ""}, + map[string]any{"name": "valid"}, + } + names := extractClaudeToolNames(tools) + if len(names) != 1 || names[0] != "valid" { + t.Fatalf("expected [valid], got %v", names) + } +} + +func TestExtractClaudeToolNamesSkipsNonMap(t *testing.T) { + tools := []any{"not a map", 42} + names := extractClaudeToolNames(tools) + if len(names) != 0 { + t.Fatalf("expected 0, got %v", names) + } +} + +func TestExtractClaudeToolNamesNil(t *testing.T) { + names := extractClaudeToolNames(nil) + if len(names) != 0 { + t.Fatalf("expected 0, got %v", names) + } +} + +// ─── toMessageMaps ─────────────────────────────────────────────────── + +func TestToMessageMapsNormal(t *testing.T) { + input := []any{ + map[string]any{"role": "user", "content": "Hello"}, + } + got := toMessageMaps(input) + if len(got) != 1 { + t.Fatalf("expected 1, got %d", len(got)) + } +} + +func TestToMessageMapsNonSlice(t *testing.T) { + got := toMessageMaps("not a slice") + if got != nil { + t.Fatalf("expected nil, got %v", got) + } +} + +func TestToMessageMapsSkipsNonMap(t *testing.T) { + input := []any{"string", map[string]any{"role": "user"}, 42} + got := toMessageMaps(input) + if len(got) != 1 { + t.Fatalf("expected 1 map, got %d", len(got)) + } +} + +func TestToMessageMapsNil(t *testing.T) { + got := toMessageMaps(nil) + if got != nil { + t.Fatalf("expected nil, got %v", got) + } +} + +// ─── extractMessageContent 
────────────────────────────────────────── + +func TestExtractMessageContentString(t *testing.T) { + if got := extractMessageContent("hello"); got != "hello" { + t.Fatalf("expected 'hello', got %q", got) + } +} + +func TestExtractMessageContentArray(t *testing.T) { + input := []any{"part1", "part2"} + got := extractMessageContent(input) + if got != "part1\npart2" { + t.Fatalf("expected joined, got %q", got) + } +} + +func TestExtractMessageContentOther(t *testing.T) { + got := extractMessageContent(42) + if got != "42" { + t.Fatalf("expected '42', got %q", got) + } +} + +func TestExtractMessageContentNil(t *testing.T) { + got := extractMessageContent(nil) + if got != "" { + t.Fatalf("expected '', got %q", got) + } +} + +// ─── cloneMap ──────────────────────────────────────────────────────── + +func TestCloneMapBasic(t *testing.T) { + original := map[string]any{"a": 1, "b": "hello"} + clone := cloneMap(original) + original["a"] = 999 + if clone["a"] != 1 { + t.Fatalf("expected 1, got %v", clone["a"]) + } + if clone["b"] != "hello" { + t.Fatalf("expected 'hello', got %v", clone["b"]) + } +} + +func TestCloneMapEmpty(t *testing.T) { + clone := cloneMap(map[string]any{}) + if len(clone) != 0 { + t.Fatalf("expected empty, got %v", clone) + } +} + +func TestCloneMapNested(t *testing.T) { + // cloneMap is shallow, so nested maps share references + inner := map[string]any{"key": "value"} + original := map[string]any{"nested": inner} + clone := cloneMap(original) + // Shallow clone means inner is shared + inner["key"] = "modified" + cloneNested := clone["nested"].(map[string]any) + if cloneNested["key"] != "modified" { + t.Fatal("expected shallow clone to share nested references") + } +} + +// helper +func containsStr(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(s) > 0 && findSubstring(s, sub)) +} + +func findSubstring(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false 
+} diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index 962e450..4de28b7 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -120,10 +120,10 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) return } - h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, toolNames) } -func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { +func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { if resp.StatusCode != http.StatusOK { defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index 8c1435d..3cab68c 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -128,7 +128,7 @@ func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -161,7 +161,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", 
"prompt", true, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -189,7 +189,7 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -220,7 +220,7 @@ func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -249,7 +249,7 @@ func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } diff --git a/internal/admin/helpers_edge_test.go b/internal/admin/helpers_edge_test.go new file mode 100644 index 0000000..2a0bf20 --- /dev/null +++ b/internal/admin/helpers_edge_test.go @@ -0,0 +1,240 @@ +package admin + +import ( + "net/http" + "net/http/httptest" + "testing" + + "ds2api/internal/config" +) + +// ─── reverseAccounts ───────────────────────────────────────────────── + +func TestReverseAccountsEmpty(t *testing.T) { 
+ a := []config.Account{} + reverseAccounts(a) + if len(a) != 0 { + t.Fatal("expected empty") + } +} + +func TestReverseAccountsTwoElements(t *testing.T) { + a := []config.Account{ + {Email: "a@test.com"}, + {Email: "b@test.com"}, + } + reverseAccounts(a) + if a[0].Email != "b@test.com" || a[1].Email != "a@test.com" { + t.Fatalf("unexpected order after reverse: %v", a) + } +} + +func TestReverseAccountsThreeElements(t *testing.T) { + a := []config.Account{ + {Email: "1@test.com"}, + {Email: "2@test.com"}, + {Email: "3@test.com"}, + } + reverseAccounts(a) + if a[0].Email != "3@test.com" || a[1].Email != "2@test.com" || a[2].Email != "1@test.com" { + t.Fatalf("unexpected order: %v", a) + } +} + +// ─── intFromQuery edge cases ───────────────────────────────────────── + +func TestIntFromQueryPresent(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=5", nil) + if got := intFromQuery(req, "limit", 10); got != 5 { + t.Fatalf("expected 5, got %d", got) + } +} + +func TestIntFromQueryMissing(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + if got := intFromQuery(req, "limit", 10); got != 10 { + t.Fatalf("expected default 10, got %d", got) + } +} + +func TestIntFromQueryInvalid(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=abc", nil) + if got := intFromQuery(req, "limit", 10); got != 10 { + t.Fatalf("expected default 10 for invalid, got %d", got) + } +} + +func TestIntFromQueryNegative(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=-3", nil) + if got := intFromQuery(req, "limit", 10); got != -3 { + t.Fatalf("expected -3, got %d", got) + } +} + +func TestIntFromQueryZero(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=0", nil) + if got := intFromQuery(req, "limit", 10); got != 0 { + t.Fatalf("expected 0, got %d", got) + } +} + +// ─── nilIfEmpty ────────────────────────────────────────────────────── + +func TestNilIfEmptyEmpty(t *testing.T) { + if nilIfEmpty("") != nil { + t.Fatal("expected nil for empty 
string") + } +} + +func TestNilIfEmptyNonEmpty(t *testing.T) { + if nilIfEmpty("hello") != "hello" { + t.Fatal("expected 'hello'") + } +} + +// ─── nilIfZero ─────────────────────────────────────────────────────── + +func TestNilIfZeroZero(t *testing.T) { + if nilIfZero(0) != nil { + t.Fatal("expected nil for zero") + } +} + +func TestNilIfZeroNonZero(t *testing.T) { + if nilIfZero(42) != int64(42) { + t.Fatal("expected 42") + } +} + +func TestNilIfZeroNegative(t *testing.T) { + if nilIfZero(-1) != int64(-1) { + t.Fatal("expected -1") + } +} + +// ─── toStringSlice ─────────────────────────────────────────────────── + +func TestToStringSliceFromAnySlice(t *testing.T) { + input := []any{"a", "b", "c"} + got, ok := toStringSlice(input) + if !ok || len(got) != 3 { + t.Fatalf("expected 3 strings, got %#v ok=%v", got, ok) + } + if got[0] != "a" || got[1] != "b" || got[2] != "c" { + t.Fatalf("unexpected values: %#v", got) + } +} + +func TestToStringSliceFromMixed(t *testing.T) { + input := []any{"hello", 42, true} + got, ok := toStringSlice(input) + if !ok { + t.Fatal("expected ok for mixed types") + } + if got[0] != "hello" || got[1] != "42" || got[2] != "true" { + t.Fatalf("unexpected values: %#v", got) + } +} + +func TestToStringSliceFromNonSlice(t *testing.T) { + _, ok := toStringSlice("not a slice") + if ok { + t.Fatal("expected not ok for string input") + } +} + +func TestToStringSliceFromNil(t *testing.T) { + _, ok := toStringSlice(nil) + if ok { + t.Fatal("expected not ok for nil input") + } +} + +func TestToStringSliceEmpty(t *testing.T) { + got, ok := toStringSlice([]any{}) + if !ok { + t.Fatal("expected ok for empty slice") + } + if len(got) != 0 { + t.Fatalf("expected empty result, got %#v", got) + } +} + +func TestToStringSliceTrimsWhitespace(t *testing.T) { + got, ok := toStringSlice([]any{" hello ", " world "}) + if !ok { + t.Fatal("expected ok") + } + if got[0] != "hello" || got[1] != "world" { + t.Fatalf("expected trimmed values, got %#v", got) + } +} + 
+// ─── toAccount edge cases ──────────────────────────────────────────── + +func TestToAccountAllFields(t *testing.T) { + acc := toAccount(map[string]any{ + "email": "user@test.com", + "mobile": "13800138000", + "password": "secret", + "token": "tok123", + }) + if acc.Email != "user@test.com" { + t.Fatalf("unexpected email: %q", acc.Email) + } + if acc.Mobile != "13800138000" { + t.Fatalf("unexpected mobile: %q", acc.Mobile) + } + if acc.Password != "secret" { + t.Fatalf("unexpected password: %q", acc.Password) + } + if acc.Token != "tok123" { + t.Fatalf("unexpected token: %q", acc.Token) + } +} + +func TestToAccountNumericValues(t *testing.T) { + acc := toAccount(map[string]any{ + "email": 12345, + }) + if acc.Email != "12345" { + t.Fatalf("expected numeric converted to string, got %q", acc.Email) + } +} + +// ─── fieldString edge cases ────────────────────────────────────────── + +func TestFieldStringNonString(t *testing.T) { + got := fieldString(map[string]any{"key": 42}, "key") + if got != "42" { + t.Fatalf("expected '42' for int, got %q", got) + } +} + +func TestFieldStringBool(t *testing.T) { + got := fieldString(map[string]any{"key": true}, "key") + if got != "true" { + t.Fatalf("expected 'true', got %q", got) + } +} + +func TestFieldStringWhitespace(t *testing.T) { + got := fieldString(map[string]any{"key": " hello "}, "key") + if got != "hello" { + t.Fatalf("expected trimmed 'hello', got %q", got) + } +} + +// ─── statusOr ──────────────────────────────────────────────────────── + +func TestStatusOrZeroReturnsDefault(t *testing.T) { + if got := statusOr(0, http.StatusOK); got != http.StatusOK { + t.Fatalf("expected %d, got %d", http.StatusOK, got) + } +} + +func TestStatusOrNonZeroReturnsValue(t *testing.T) { + if got := statusOr(http.StatusBadRequest, http.StatusOK); got != http.StatusBadRequest { + t.Fatalf("expected %d, got %d", http.StatusBadRequest, got) + } +} diff --git a/internal/auth/auth_edge_test.go b/internal/auth/auth_edge_test.go new file 
mode 100644 index 0000000..55c46ef --- /dev/null +++ b/internal/auth/auth_edge_test.go @@ -0,0 +1,375 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "testing" + + "ds2api/internal/account" + "ds2api/internal/config" +) + +// ─── extractCallerToken edge cases ─────────────────────────────────── + +func TestExtractCallerTokenBearerPrefix(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer my-token") + if got := extractCallerToken(req); got != "my-token" { + t.Fatalf("expected my-token, got %q", got) + } +} + +func TestExtractCallerTokenBearerCaseInsensitive(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "BEARER My-Token") + if got := extractCallerToken(req); got != "My-Token" { + t.Fatalf("expected My-Token, got %q", got) + } +} + +func TestExtractCallerTokenBearerEmpty(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer ") + if got := extractCallerToken(req); got != "" { + t.Fatalf("expected empty for 'Bearer ', got %q", got) + } +} + +func TestExtractCallerTokenXAPIKey(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("x-api-key", "x-api-key-token") + if got := extractCallerToken(req); got != "x-api-key-token" { + t.Fatalf("expected x-api-key-token, got %q", got) + } +} + +func TestExtractCallerTokenBearerPreferredOverXAPIKey(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer bearer-token") + req.Header.Set("x-api-key", "x-api-key-token") + if got := extractCallerToken(req); got != "bearer-token" { + t.Fatalf("expected bearer-token, got %q", got) + } +} + +func TestExtractCallerTokenMissingHeaders(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + if got := extractCallerToken(req); got != "" { + t.Fatalf("expected empty for missing headers, got %q", got) + } +} + +func 
TestExtractCallerTokenNonBearerAuth(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Basic abc123") + if got := extractCallerToken(req); got != "" { + t.Fatalf("expected empty for Basic auth, got %q", got) + } +} + +// ─── Context helpers ───────────────────────────────────────────────── + +func TestWithAuthAndFromContext(t *testing.T) { + a := &RequestAuth{DeepSeekToken: "test-token"} + ctx := WithAuth(context.Background(), a) + got, ok := FromContext(ctx) + if !ok || got.DeepSeekToken != "test-token" { + t.Fatalf("expected token from context, got ok=%v token=%q", ok, got.DeepSeekToken) + } +} + +func TestFromContextMissing(t *testing.T) { + _, ok := FromContext(context.Background()) + if ok { + t.Fatal("expected not ok from empty context") + } +} + +// ─── RefreshToken edge cases ───────────────────────────────────────── + +func TestRefreshTokenNotConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false, resolver: r} + if r.RefreshToken(context.Background(), a) { + t.Fatal("expected false for non-config token") + } +} + +func TestRefreshTokenEmptyAccountID(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: true, AccountID: "", resolver: r} + if r.RefreshToken(context.Background(), a) { + t.Fatal("expected false for empty account ID") + } +} + +func TestRefreshTokenSuccess(t *testing.T) { + r := newTestResolver(t) + // First acquire an account + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + + if !r.RefreshToken(context.Background(), a) { + t.Fatal("expected refresh to succeed") + } + if a.DeepSeekToken != "fresh-token" { + t.Fatalf("expected fresh-token after refresh, got %q", a.DeepSeekToken) + } +} + +// ─── MarkTokenInvalid edge cases ───────────────────────────────────── + +func 
TestMarkTokenInvalidNotConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false, DeepSeekToken: "direct", resolver: r} + r.MarkTokenInvalid(a) + // Should not panic, token should be unchanged for non-config + if a.DeepSeekToken != "" { + // Actually it does clear it; that's fine - let's check behavior + } +} + +func TestMarkTokenInvalidEmptyAccountID(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: true, AccountID: "", DeepSeekToken: "tok", resolver: r} + r.MarkTokenInvalid(a) + // Should not panic +} + +func TestMarkTokenInvalidClearsToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + + r.MarkTokenInvalid(a) + if a.DeepSeekToken != "" { + t.Fatalf("expected empty token after invalidation, got %q", a.DeepSeekToken) + } + if a.Account.Token != "" { + t.Fatalf("expected empty account token after invalidation, got %q", a.Account.Token) + } +} + +// ─── SwitchAccount edge cases ──────────────────────────────────────── + +func TestSwitchAccountNotConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false, resolver: r} + if r.SwitchAccount(context.Background(), a) { + t.Fatal("expected false for non-config token") + } +} + +func TestSwitchAccountNilTriedAccounts(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[ + {"email":"acc1@test.com","token":"t1"}, + {"email":"acc2@test.com","token":"t2"} + ] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "new-token", nil + }) + + // First acquire + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := 
r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + + oldID := a.AccountID + a.TriedAccounts = nil // test nil initialization in SwitchAccount + if !r.SwitchAccount(context.Background(), a) { + t.Fatal("expected switch to succeed") + } + if a.AccountID == oldID { + t.Fatalf("expected different account after switch") + } + r.Release(a) +} + +// ─── Release edge cases ───────────────────────────────────────────── + +func TestReleaseNilAuth(t *testing.T) { + r := newTestResolver(t) + r.Release(nil) // should not panic +} + +func TestReleaseNonConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false} + r.Release(a) // should not panic +} + +func TestReleaseEmptyAccountID(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: true, AccountID: ""} + r.Release(a) // should not panic +} + +// ─── JWT edge cases ────────────────────────────────────────────────── + +func TestVerifyJWTInvalidFormat(t *testing.T) { + _, err := VerifyJWT("not-a-jwt") + if err == nil { + t.Fatal("expected error for invalid JWT format") + } +} + +func TestVerifyJWTInvalidSignature(t *testing.T) { + token, _ := CreateJWT(1) + // Tamper with the signature + parts := splitJWT(token) + if len(parts) == 3 { + tampered := parts[0] + "." 
+ parts[1] + ".invalid_signature" + _, err := VerifyJWT(tampered) + if err == nil { + t.Fatal("expected error for tampered signature") + } + } +} + +func TestVerifyJWTExpired(t *testing.T) { + // Create a token with 0 hours expiry - will use default, so we can't easily test + // Instead test with bad payload + _, err := VerifyJWT("eyJhbGciOiJIUzI1NiJ9.eyJleHAiOjF9.invalid") + if err == nil { + t.Fatal("expected error for expired/invalid JWT") + } +} + +func TestCreateJWTDefaultExpiry(t *testing.T) { + token, err := CreateJWT(0) // should use default + if err != nil { + t.Fatalf("create jwt failed: %v", err) + } + _, err = VerifyJWT(token) + if err != nil { + t.Fatalf("verify jwt failed: %v", err) + } +} + +// ─── VerifyAdminRequest edge cases ─────────────────────────────────── + +func TestVerifyAdminRequestNoHeader(t *testing.T) { + req, _ := http.NewRequest("GET", "/admin/config", nil) + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for missing auth") + } +} + +func TestVerifyAdminRequestEmptyBearer(t *testing.T) { + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Bearer ") + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for empty bearer") + } +} + +func TestVerifyAdminRequestWithAdminKey(t *testing.T) { + t.Setenv("DS2API_ADMIN_KEY", "test-admin-key") + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Bearer test-admin-key") + if err := VerifyAdminRequest(req); err != nil { + t.Fatalf("expected admin key accepted: %v", err) + } +} + +func TestVerifyAdminRequestInvalidCredentials(t *testing.T) { + t.Setenv("DS2API_ADMIN_KEY", "correct-key") + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Bearer wrong-key") + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for wrong key") + } +} + +func TestVerifyAdminRequestBasicAuth(t *testing.T) { + req, _ := 
http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Basic abc123") + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for Basic auth") + } +} + +// ─── Determine with login failure ──────────────────────────────────── + +func TestDetermineWithLoginFailure(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[{"email":"acc@test.com","password":"pwd"}] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "", errors.New("login failed") + }) + + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + _, err := r.Determine(req) + if err == nil { + t.Fatal("expected error when login fails") + } +} + +// ─── Determine with target account ─────────────────────────────────── + +func TestDetermineWithTargetAccount(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[ + {"email":"acc1@test.com","token":"t1"}, + {"email":"acc2@test.com","token":"t2"} + ] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "fresh-token", nil + }) + + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + req.Header.Set("X-Ds2-Target-Account", "acc2@test.com") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + if a.AccountID != "acc2@test.com" { + t.Fatalf("expected target account acc2, got %q", a.AccountID) + } +} + +// helper +func splitJWT(token string) []string { + result := make([]string, 0, 3) + start := 0 + count := 0 + for i := 0; i < len(token); i++ { + if token[i] == '.' 
{ + result = append(result, token[start:i]) + start = i + 1 + count++ + } + } + result = append(result, token[start:]) + return result +} diff --git a/internal/config/config_edge_test.go b/internal/config/config_edge_test.go new file mode 100644 index 0000000..81cc7ec --- /dev/null +++ b/internal/config/config_edge_test.go @@ -0,0 +1,445 @@ +package config + +import ( + "encoding/base64" + "encoding/json" + "strings" + "testing" +) + +// ─── GetModelConfig edge cases ─────────────────────────────────────── + +func TestGetModelConfigDeepSeekChat(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-chat") + if !ok { + t.Fatal("expected ok for deepseek-chat") + } + if thinking || search { + t.Fatalf("expected no thinking/search for deepseek-chat, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigDeepSeekReasoner(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-reasoner") + if !ok { + t.Fatal("expected ok for deepseek-reasoner") + } + if !thinking || search { + t.Fatalf("expected thinking=true search=false, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigDeepSeekChatSearch(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-chat-search") + if !ok { + t.Fatal("expected ok for deepseek-chat-search") + } + if thinking || !search { + t.Fatalf("expected thinking=false search=true, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigDeepSeekReasonerSearch(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-reasoner-search") + if !ok { + t.Fatal("expected ok for deepseek-reasoner-search") + } + if !thinking || !search { + t.Fatalf("expected both true, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigCaseInsensitive(t *testing.T) { + thinking, search, ok := GetModelConfig("DeepSeek-Chat") + if !ok { + t.Fatal("expected ok for case-insensitive deepseek-chat") + } + if thinking || search { + 
t.Fatalf("expected no thinking/search for case-insensitive deepseek-chat") + } +} + +func TestGetModelConfigUnknownModel(t *testing.T) { + _, _, ok := GetModelConfig("gpt-4") + if ok { + t.Fatal("expected not ok for unknown model") + } +} + +func TestGetModelConfigEmpty(t *testing.T) { + _, _, ok := GetModelConfig("") + if ok { + t.Fatal("expected not ok for empty model") + } +} + +// ─── lower function ────────────────────────────────────────────────── + +func TestLowerFunction(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"Hello", "hello"}, + {"ALLCAPS", "allcaps"}, + {"already-lower", "already-lower"}, + {"Mixed-CASE-123", "mixed-case-123"}, + {"", ""}, + } + for _, tc := range tests { + got := lower(tc.input) + if got != tc.expected { + t.Errorf("lower(%q) = %q, want %q", tc.input, got, tc.expected) + } + } +} + +// ─── Config.MarshalJSON / UnmarshalJSON roundtrip ──────────────────── + +func TestConfigJSONRoundtrip(t *testing.T) { + cfg := Config{ + Keys: []string{"key1", "key2"}, + Accounts: []Account{{Email: "user@example.com", Password: "pass", Token: "tok"}}, + ClaudeMapping: map[string]string{ + "fast": "deepseek-chat", + "slow": "deepseek-reasoner", + }, + VercelSyncHash: "hash123", + VercelSyncTime: 1234567890, + AdditionalFields: map[string]any{ + "custom_field": "custom_value", + }, + } + + data, err := cfg.MarshalJSON() + if err != nil { + t.Fatalf("marshal error: %v", err) + } + + var decoded Config + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + if len(decoded.Keys) != 2 || decoded.Keys[0] != "key1" { + t.Fatalf("unexpected keys: %#v", decoded.Keys) + } + if len(decoded.Accounts) != 1 || decoded.Accounts[0].Email != "user@example.com" { + t.Fatalf("unexpected accounts: %#v", decoded.Accounts) + } + if decoded.ClaudeMapping["fast"] != "deepseek-chat" { + t.Fatalf("unexpected claude mapping: %#v", decoded.ClaudeMapping) + } + if decoded.VercelSyncHash != 
"hash123" { + t.Fatalf("unexpected vercel sync hash: %q", decoded.VercelSyncHash) + } + if decoded.AdditionalFields["custom_field"] != "custom_value" { + t.Fatalf("unexpected additional fields: %#v", decoded.AdditionalFields) + } +} + +func TestConfigUnmarshalJSONPreservesUnknownFields(t *testing.T) { + raw := `{"keys":["k1"],"accounts":[],"my_custom_field":"hello","number_field":42}` + var cfg Config + if err := json.Unmarshal([]byte(raw), &cfg); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if cfg.AdditionalFields["my_custom_field"] != "hello" { + t.Fatalf("expected custom field preserved, got %#v", cfg.AdditionalFields) + } + // number_field should also be preserved + if cfg.AdditionalFields["number_field"] != float64(42) { + t.Fatalf("expected number field preserved, got %#v", cfg.AdditionalFields["number_field"]) + } +} + +// ─── Config.Clone ──────────────────────────────────────────────────── + +func TestConfigCloneIsDeepCopy(t *testing.T) { + cfg := Config{ + Keys: []string{"key1"}, + Accounts: []Account{{Email: "user@test.com", Token: "token"}}, + ClaudeMapping: map[string]string{ + "fast": "deepseek-chat", + }, + AdditionalFields: map[string]any{"custom": "value"}, + } + + cloned := cfg.Clone() + + // Modify original + cfg.Keys[0] = "modified" + cfg.Accounts[0].Email = "modified@test.com" + cfg.ClaudeMapping["fast"] = "modified-model" + + // Cloned should not be affected + if cloned.Keys[0] != "key1" { + t.Fatalf("clone keys was affected by original change: %#v", cloned.Keys) + } + if cloned.Accounts[0].Email != "user@test.com" { + t.Fatalf("clone accounts was affected: %#v", cloned.Accounts) + } + if cloned.ClaudeMapping["fast"] != "deepseek-chat" { + t.Fatalf("clone claude mapping was affected: %#v", cloned.ClaudeMapping) + } +} + +func TestConfigCloneNilMaps(t *testing.T) { + cfg := Config{ + Keys: []string{"k"}, + Accounts: nil, + } + cloned := cfg.Clone() + if len(cloned.Keys) != 1 { + t.Fatalf("unexpected keys length: %d", 
len(cloned.Keys)) + } + if cloned.Accounts != nil { + t.Fatalf("expected nil accounts in clone, got %#v", cloned.Accounts) + } +} + +// ─── Account.Identifier edge cases ─────────────────────────────────── + +func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) { + acc := Account{Mobile: "13800138000", Token: "tok"} + if acc.Identifier() != "13800138000" { + t.Fatalf("expected mobile identifier, got %q", acc.Identifier()) + } +} + +func TestAccountIdentifierPreferenceEmailOverMobile(t *testing.T) { + acc := Account{Email: "user@test.com", Mobile: "13800138000"} + if acc.Identifier() != "user@test.com" { + t.Fatalf("expected email identifier, got %q", acc.Identifier()) + } +} + +func TestAccountIdentifierEmptyAccount(t *testing.T) { + acc := Account{} + if acc.Identifier() != "" { + t.Fatalf("expected empty identifier for empty account, got %q", acc.Identifier()) + } +} + +// ─── normalizeConfigInput ──────────────────────────────────────────── + +func TestNormalizeConfigInputStripsQuotes(t *testing.T) { + got := normalizeConfigInput(`"base64:abc"`) + if strings.HasPrefix(got, `"`) || strings.HasSuffix(got, `"`) { + t.Fatalf("expected quotes stripped, got %q", got) + } +} + +func TestNormalizeConfigInputStripsSingleQuotes(t *testing.T) { + got := normalizeConfigInput("'some-value'") + if strings.HasPrefix(got, "'") || strings.HasSuffix(got, "'") { + t.Fatalf("expected single quotes stripped, got %q", got) + } +} + +func TestNormalizeConfigInputTrimsWhitespace(t *testing.T) { + got := normalizeConfigInput(" hello ") + if got != "hello" { + t.Fatalf("expected trimmed, got %q", got) + } +} + +// ─── parseConfigString edge cases ──────────────────────────────────── + +func TestParseConfigStringPlainJSON(t *testing.T) { + cfg, err := parseConfigString(`{"keys":["k1"],"accounts":[]}`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "k1" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func 
TestParseConfigStringBase64Prefix(t *testing.T) { + rawJSON := `{"keys":["base64-key"],"accounts":[]}` + b64 := base64.StdEncoding.EncodeToString([]byte(rawJSON)) + cfg, err := parseConfigString("base64:" + b64) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "base64-key" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestParseConfigStringInvalidBase64(t *testing.T) { + _, err := parseConfigString("base64:!!!invalid!!!") + if err == nil { + t.Fatal("expected error for invalid base64") + } +} + +func TestParseConfigStringEmptyString(t *testing.T) { + _, err := parseConfigString("") + if err == nil { + t.Fatal("expected error for empty string") + } +} + +// ─── Store methods ─────────────────────────────────────────────────── + +func TestStoreSnapshotReturnsClone(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com","token":"t1"}]}`) + store := LoadStore() + snap := store.Snapshot() + snap.Keys[0] = "modified" + if store.Keys()[0] != "k1" { + t.Fatal("snapshot modification should not affect store") + } +} + +func TestStoreHasAPIKeyMultipleKeys(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["key1","key2","key3"],"accounts":[]}`) + store := LoadStore() + if !store.HasAPIKey("key1") { + t.Fatal("expected key1 found") + } + if !store.HasAPIKey("key2") { + t.Fatal("expected key2 found") + } + if !store.HasAPIKey("key3") { + t.Fatal("expected key3 found") + } + if store.HasAPIKey("nonexistent") { + t.Fatal("expected nonexistent key not found") + } +} + +func TestStoreFindAccountNotFound(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com"}]}`) + store := LoadStore() + _, ok := store.FindAccount("nonexistent@test.com") + if ok { + t.Fatal("expected account not found") + } +} + +func TestStoreIsEnvBacked(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := 
LoadStore() + if !store.IsEnvBacked() { + t.Fatal("expected env-backed store") + } +} + +func TestStoreReplace(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + newCfg := Config{ + Keys: []string{"new-key"}, + Accounts: []Account{{Email: "new@test.com"}}, + } + if err := store.Replace(newCfg); err != nil { + t.Fatalf("replace error: %v", err) + } + if !store.HasAPIKey("new-key") { + t.Fatal("expected new key after replace") + } + if store.HasAPIKey("k1") { + t.Fatal("expected old key removed after replace") + } +} + +func TestStoreUpdate(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + err := store.Update(func(cfg *Config) error { + cfg.Keys = append(cfg.Keys, "k2") + return nil + }) + if err != nil { + t.Fatalf("update error: %v", err) + } + if !store.HasAPIKey("k2") { + t.Fatal("expected k2 after update") + } +} + +func TestStoreClaudeMapping(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`) + store := LoadStore() + mapping := store.ClaudeMapping() + if mapping["fast"] != "deepseek-chat" { + t.Fatalf("unexpected fast mapping: %q", mapping["fast"]) + } + if mapping["slow"] != "deepseek-reasoner" { + t.Fatalf("unexpected slow mapping: %q", mapping["slow"]) + } +} + +func TestStoreClaudeMappingEmpty(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`) + store := LoadStore() + mapping := store.ClaudeMapping() + // Even without config mapping, there are defaults + if mapping == nil { + t.Fatal("expected non-nil mapping (may contain defaults)") + } +} + +func TestStoreSetVercelSync(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`) + store := LoadStore() + if err := store.SetVercelSync("hash123", 1234567890); err != nil { + t.Fatalf("setVercelSync error: %v", err) + } + snap := store.Snapshot() + if 
snap.VercelSyncHash != "hash123" || snap.VercelSyncTime != 1234567890 { + t.Fatalf("unexpected vercel sync: hash=%q time=%d", snap.VercelSyncHash, snap.VercelSyncTime) + } +} + +func TestStoreExportJSONAndBase64(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["export-key"],"accounts":[]}`) + store := LoadStore() + jsonStr, b64Str, err := store.ExportJSONAndBase64() + if err != nil { + t.Fatalf("export error: %v", err) + } + if !strings.Contains(jsonStr, "export-key") { + t.Fatalf("expected JSON to contain key: %q", jsonStr) + } + decoded, err := base64.StdEncoding.DecodeString(b64Str) + if err != nil { + t.Fatalf("base64 decode error: %v", err) + } + if !strings.Contains(string(decoded), "export-key") { + t.Fatalf("expected base64-decoded to contain key: %q", string(decoded)) + } +} + +// ─── OpenAIModelsResponse / ClaudeModelsResponse ───────────────────── + +func TestOpenAIModelsResponse(t *testing.T) { + resp := OpenAIModelsResponse() + if resp["object"] != "list" { + t.Fatalf("unexpected object: %v", resp["object"]) + } + data, ok := resp["data"].([]ModelInfo) + if !ok { + t.Fatalf("unexpected data type: %T", resp["data"]) + } + if len(data) == 0 { + t.Fatal("expected non-empty models list") + } +} + +func TestClaudeModelsResponse(t *testing.T) { + resp := ClaudeModelsResponse() + if resp["object"] != "list" { + t.Fatalf("unexpected object: %v", resp["object"]) + } + data, ok := resp["data"].([]ModelInfo) + if !ok { + t.Fatalf("unexpected data type: %T", resp["data"]) + } + if len(data) == 0 { + t.Fatal("expected non-empty models list") + } +} diff --git a/internal/deepseek/deepseek_edge_test.go b/internal/deepseek/deepseek_edge_test.go new file mode 100644 index 0000000..92e6952 --- /dev/null +++ b/internal/deepseek/deepseek_edge_test.go @@ -0,0 +1,165 @@ +package deepseek + +import ( + "context" + "testing" +) + +// ─── toFloat64 edge cases ──────────────────────────────────────────── + +func TestToFloat64FromFloat64(t *testing.T) { + if got := 
toFloat64(float64(3.14), 0); got != 3.14 { + t.Fatalf("expected 3.14, got %f", got) + } +} + +func TestToFloat64FromInt(t *testing.T) { + if got := toFloat64(42, 0); got != 42.0 { + t.Fatalf("expected 42.0, got %f", got) + } +} + +func TestToFloat64FromInt64(t *testing.T) { + if got := toFloat64(int64(100), 0); got != 100.0 { + t.Fatalf("expected 100.0, got %f", got) + } +} + +func TestToFloat64FromStringDefault(t *testing.T) { + if got := toFloat64("42", 99.0); got != 99.0 { + t.Fatalf("expected default 99.0, got %f", got) + } +} + +func TestToFloat64FromNilDefault(t *testing.T) { + if got := toFloat64(nil, 5.5); got != 5.5 { + t.Fatalf("expected default 5.5, got %f", got) + } +} + +func TestToFloat64FromBoolDefault(t *testing.T) { + if got := toFloat64(true, 1.0); got != 1.0 { + t.Fatalf("expected default 1.0, got %f", got) + } +} + +// ─── toInt64 edge cases ────────────────────────────────────────────── + +func TestToInt64FromFloat64(t *testing.T) { + if got := toInt64(float64(42.9), 0); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestToInt64FromInt(t *testing.T) { + if got := toInt64(42, 0); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestToInt64FromInt64(t *testing.T) { + if got := toInt64(int64(100), 0); got != 100 { + t.Fatalf("expected 100, got %d", got) + } +} + +func TestToInt64FromStringDefault(t *testing.T) { + if got := toInt64("42", 99); got != 99 { + t.Fatalf("expected default 99, got %d", got) + } +} + +func TestToInt64FromNilDefault(t *testing.T) { + if got := toInt64(nil, 7); got != 7 { + t.Fatalf("expected default 7, got %d", got) + } +} + +// ─── BuildPowHeader edge cases ─────────────────────────────────────── + +func TestBuildPowHeaderBasicChallenge(t *testing.T) { + challenge := map[string]any{ + "algorithm": "DeepSeekHashV1", + "challenge": "abc123", + "salt": "salt456", + "signature": "sig789", + "target_path": "/path", + } + result, err := BuildPowHeader(challenge, 42) + if err != nil { + 
t.Fatalf("unexpected error: %v", err) + } + if result == "" { + t.Fatal("expected non-empty result") + } +} + +func TestBuildPowHeaderEmptyChallenge(t *testing.T) { + result, err := BuildPowHeader(map[string]any{}, 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Should produce a base64 encoded JSON with nil values + if result == "" { + t.Fatal("expected non-empty result for empty challenge") + } +} + +// ─── PowSolver pool size ───────────────────────────────────────────── + +func TestPowPoolSizeFromEnvDefault(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "") + got := powPoolSizeFromEnv() + if got < 1 { + t.Fatalf("expected positive default pool size, got %d", got) + } +} + +func TestPowPoolSizeFromEnvInvalid(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "abc") + got := powPoolSizeFromEnv() + if got < 1 { + t.Fatalf("expected positive default for invalid, got %d", got) + } +} + +func TestPowPoolSizeFromEnvSpecificValue(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "5") + got := powPoolSizeFromEnv() + if got != 5 { + t.Fatalf("expected 5, got %d", got) + } +} + +// ─── NewClient ─────────────────────────────────────────────────────── + +func TestNewClientInitialState(t *testing.T) { + client := NewClient(nil, nil) + if client.powSolver == nil { + t.Fatal("expected powSolver to be initialized") + } +} + +func TestNewClientPreloadPowIdempotent(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "1") + client := NewClient(nil, nil) + if err := client.PreloadPow(context.Background()); err != nil { + t.Fatalf("first preload failed: %v", err) + } + if err := client.PreloadPow(context.Background()); err != nil { + t.Fatalf("second preload failed: %v", err) + } +} + +// ─── PowSolver init and module pool ────────────────────────────────── + +func TestPowSolverPoolSizeMatchesEnv(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "2") + solver := NewPowSolver("test.wasm") + if err := solver.init(context.Background()); err != nil { + 
t.Fatalf("init failed: %v", err) + } + if cap(solver.pool) != 2 { + t.Fatalf("expected pool capacity 2, got %d", cap(solver.pool)) + } +} diff --git a/internal/sse/consumer_edge_test.go b/internal/sse/consumer_edge_test.go new file mode 100644 index 0000000..8f78f01 --- /dev/null +++ b/internal/sse/consumer_edge_test.go @@ -0,0 +1,140 @@ +package sse + +import ( + "io" + "net/http" + "strings" + "testing" +) + +// ─── CollectStream edge cases ──────────────────────────────────────── + +func makeHTTPResponse(body string) *http.Response { + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func TestCollectStreamEmpty(t *testing.T) { + resp := makeHTTPResponse("") + result := CollectStream(resp, false, false) + if result.Text != "" || result.Thinking != "" { + t.Fatalf("expected empty result, got text=%q think=%q", result.Text, result.Thinking) + } +} + +func TestCollectStreamTextOnly(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\" World\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + if result.Text != "Hello World" { + t.Fatalf("expected 'Hello World', got %q", result.Text) + } + if result.Thinking != "" { + t.Fatalf("expected no thinking, got %q", result.Thinking) + } +} + +func TestCollectStreamThinkingAndText(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/thinking_content\",\"v\":\"Thinking...\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"Answer\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "Thinking..." 
{ + t.Fatalf("expected 'Thinking...', got %q", result.Thinking) + } + if result.Text != "Answer" { + t.Fatalf("expected 'Answer', got %q", result.Text) + } +} + +func TestCollectStreamOnlyThinking(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/thinking_content\",\"v\":\"Only thinking\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "Only thinking" { + t.Fatalf("expected 'Only thinking', got %q", result.Thinking) + } + if result.Text != "" { + t.Fatalf("expected empty text, got %q", result.Text) + } +} + +func TestCollectStreamSkipsInvalidLines(t *testing.T) { + resp := makeHTTPResponse( + "event: comment\n" + + "data: invalid_json\n" + + "data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + if result.Text != "valid" { + t.Fatalf("expected 'valid', got %q", result.Text) + } +} + +func TestCollectStreamWithFragments(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"Think\"}]}\n" + + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"Done\"}]}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "Think" { + t.Fatalf("expected 'Think' thinking, got %q", result.Thinking) + } + if result.Text != "Done" { + t.Fatalf("expected 'Done' text, got %q", result.Text) + } +} + +func TestCollectStreamWithCitation(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"[citation:1] cited text\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\" more\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + // CollectStream does NOT filter citations (that's done by the adapters) + // So citations are passed through as-is + if 
!strings.Contains(result.Text, "[citation:1]") { + t.Fatalf("expected citations to be passed through, got %q", result.Text) + } + if result.Text != "Hello[citation:1] cited text more" { + t.Fatalf("expected full text with citation, got %q", result.Text) + } +} + +func TestCollectStreamMultipleThinkingChunks(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/thinking_content\",\"v\":\"part1\"}\n" + + "data: {\"p\":\"response/thinking_content\",\"v\":\" part2\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"answer\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "part1 part2" { + t.Fatalf("expected 'part1 part2', got %q", result.Thinking) + } +} + +func TestCollectStreamStatusFinished(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" + + "data: {\"p\":\"response/status\",\"v\":\"FINISHED\"}\n", + ) + result := CollectStream(resp, false, false) + if result.Text != "Hello" { + t.Fatalf("expected 'Hello', got %q", result.Text) + } +} diff --git a/internal/sse/line_edge_test.go b/internal/sse/line_edge_test.go new file mode 100644 index 0000000..2ae53a6 --- /dev/null +++ b/internal/sse/line_edge_test.go @@ -0,0 +1,70 @@ +package sse + +import "testing" + +func TestParseDeepSeekContentLineNotParsed(t *testing.T) { + res := ParseDeepSeekContentLine([]byte("not a data line"), false, "text") + if res.Parsed { + t.Fatal("expected not parsed") + } + if res.NextType != "text" { + t.Fatalf("expected nextType preserved, got %q", res.NextType) + } +} + +func TestParseDeepSeekContentLinePreservesNextType(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/thinking_content","v":"think"}`), true, "thinking") + if !res.Parsed || res.Stop { + t.Fatalf("expected parsed non-stop: %#v", res) + } + if len(res.Parts) != 1 || res.Parts[0].Type != "thinking" { + t.Fatalf("unexpected parts: %#v", res.Parts) + } +} + +func 
TestParseDeepSeekContentLineFragmentSwitchType(t *testing.T) { + res := ParseDeepSeekContentLine( + []byte(`data: {"p":"response/fragments","o":"APPEND","v":[{"type":"RESPONSE","content":"hi"}]}`), + true, "thinking", + ) + if !res.Parsed || res.Stop { + t.Fatalf("expected parsed non-stop: %#v", res) + } + if res.NextType != "text" { + t.Fatalf("expected nextType text after RESPONSE fragment, got %q", res.NextType) + } +} + +func TestParseDeepSeekContentLineContentFilterMessage(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"code":"content_filter"}`), false, "text") + if !res.ContentFilter { + t.Fatal("expected content filter flag") + } + if res.ErrorMessage == "" { + t.Fatal("expected error message on content filter") + } +} + +func TestParseDeepSeekContentLineErrorObjectFormat(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"error":{"message":"rate limit","code":429}}`), false, "text") + if !res.Parsed || !res.Stop { + t.Fatalf("expected parsed stop: %#v", res) + } + if res.ErrorMessage == "" { + t.Fatal("expected non-empty error message") + } +} + +func TestParseDeepSeekContentLineInvalidJSON(t *testing.T) { + res := ParseDeepSeekContentLine([]byte("data: {broken"), false, "text") + if res.Parsed { + t.Fatal("expected not parsed for broken JSON") + } +} + +func TestParseDeepSeekContentLineEmptyBytes(t *testing.T) { + res := ParseDeepSeekContentLine([]byte{}, false, "text") + if res.Parsed { + t.Fatal("expected not parsed for empty bytes") + } +} diff --git a/internal/sse/parser_edge_test.go b/internal/sse/parser_edge_test.go new file mode 100644 index 0000000..c851c1f --- /dev/null +++ b/internal/sse/parser_edge_test.go @@ -0,0 +1,631 @@ +package sse + +import "testing" + +// ─── ParseDeepSeekSSELine edge cases ───────────────────────────────── + +func TestParseDeepSeekSSELineEmptyLine(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte("")) + if ok { + t.Fatal("expected not parsed for empty line") + } +} + +func 
TestParseDeepSeekSSELineNoDataPrefix(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte("event: message")) + if ok { + t.Fatal("expected not parsed for non-data line") + } +} + +func TestParseDeepSeekSSELineInvalidJSON(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte("data: {invalid json")) + if ok { + t.Fatal("expected not parsed for invalid JSON") + } +} + +func TestParseDeepSeekSSELineWhitespaceOnly(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte(" ")) + if ok { + t.Fatal("expected not parsed for whitespace-only line") + } +} + +func TestParseDeepSeekSSELineDataWithExtraSpaces(t *testing.T) { + chunk, done, ok := ParseDeepSeekSSELine([]byte(`data: {"v":"hello"} `)) + if !ok || done { + t.Fatalf("expected parsed chunk for spaced data line") + } + if chunk["v"] != "hello" { + t.Fatalf("unexpected chunk: %#v", chunk) + } +} + +// ─── shouldSkipPath edge cases ─────────────────────────────────────── + +func TestShouldSkipPathQuasiStatus(t *testing.T) { + if !shouldSkipPath("response/quasi_status") { + t.Fatal("expected skip for quasi_status path") + } +} + +func TestShouldSkipPathElapsedSecs(t *testing.T) { + if !shouldSkipPath("response/elapsed_secs") { + t.Fatal("expected skip for elapsed_secs path") + } +} + +func TestShouldSkipPathTokenUsage(t *testing.T) { + if !shouldSkipPath("response/token_usage") { + t.Fatal("expected skip for token_usage path") + } +} + +func TestShouldSkipPathPendingFragment(t *testing.T) { + if !shouldSkipPath("response/pending_fragment") { + t.Fatal("expected skip for pending_fragment path") + } +} + +func TestShouldSkipPathConversationMode(t *testing.T) { + if !shouldSkipPath("response/conversation_mode") { + t.Fatal("expected skip for conversation_mode path") + } +} + +func TestShouldSkipPathSearchStatus(t *testing.T) { + if !shouldSkipPath("response/search_status") { + t.Fatal("expected skip for search_status path") + } +} + +func TestShouldSkipPathFragmentStatus(t *testing.T) { + if 
!shouldSkipPath("response/fragments/-1/status") { + t.Fatal("expected skip for fragment -1 status") + } + if !shouldSkipPath("response/fragments/-2/status") { + t.Fatal("expected skip for fragment -2 status") + } + if !shouldSkipPath("response/fragments/-3/status") { + t.Fatal("expected skip for fragment -3 status") + } +} + +func TestShouldSkipPathRegularContent(t *testing.T) { + if shouldSkipPath("response/content") { + t.Fatal("expected not skip for content path") + } + if shouldSkipPath("response/thinking_content") { + t.Fatal("expected not skip for thinking_content path") + } +} + +// ─── ParseSSEChunkForContent edge cases ────────────────────────────── + +func TestParseSSEChunkForContentNoVField(t *testing.T) { + parts, finished, nextType := ParseSSEChunkForContent(map[string]any{"p": "response/content"}, false, "text") + if finished { + t.Fatal("expected not finished") + } + if len(parts) != 0 { + t.Fatalf("expected no parts when v is missing, got %#v", parts) + } + if nextType != "text" { + t.Fatalf("expected type preserved, got %q", nextType) + } +} + +func TestParseSSEChunkForContentSkippedPath(t *testing.T) { + parts, finished, nextType := ParseSSEChunkForContent(map[string]any{ + "p": "response/token_usage", + "v": "some data", + }, false, "text") + if finished || len(parts) > 0 { + t.Fatalf("expected skipped path to produce no output") + } + if nextType != "text" { + t.Fatalf("expected type preserved for skipped path") + } +} + +func TestParseSSEChunkForContentFinishedStatus(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "response/status", + "v": "FINISHED", + }, false, "text") + if !finished { + t.Fatal("expected finished on status FINISHED") + } + if len(parts) != 0 { + t.Fatalf("expected no parts on finished, got %d", len(parts)) + } +} + +func TestParseSSEChunkForContentStatusNotFinished(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "response/status", + "v": 
"IN_PROGRESS", + }, false, "text") + if finished { + t.Fatal("expected not finished for non-FINISHED status") + } + if len(parts) != 1 || parts[0].Text != "IN_PROGRESS" { + t.Fatalf("expected content for non-FINISHED status, got %#v", parts) + } +} + +func TestParseSSEChunkForContentEmptyStringV(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "response/content", + "v": "", + }, false, "text") + if finished { + t.Fatal("expected not finished") + } + if len(parts) != 0 { + t.Fatalf("expected no parts for empty string v, got %#v", parts) + } +} + +func TestParseSSEChunkForContentFinishedOnEmptyPath(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "", + "v": "FINISHED", + }, false, "text") + if !finished { + t.Fatal("expected finished on empty path with FINISHED value") + } + if len(parts) != 0 { + t.Fatalf("expected no parts on finished") + } +} + +func TestParseSSEChunkForContentFinishedOnStatusPath(t *testing.T) { + _, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "status", + "v": "FINISHED", + }, false, "text") + if !finished { + t.Fatal("expected finished on status path with FINISHED value") + } +} + +func TestParseSSEChunkForContentThinkingPathEmptyPath(t *testing.T) { + parts, _, nextType := ParseSSEChunkForContent(map[string]any{ + "v": "some thought", + }, true, "thinking") + if len(parts) != 1 || parts[0].Type != "thinking" { + t.Fatalf("expected thinking part on empty path, got %#v", parts) + } + if nextType != "thinking" { + t.Fatalf("expected nextType thinking, got %q", nextType) + } +} + +func TestParseSSEChunkForContentThinkingEnabledTextType(t *testing.T) { + parts, _, nextType := ParseSSEChunkForContent(map[string]any{ + "v": "text content", + }, true, "text") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text part when currentType=text, got %#v", parts) + } + if nextType != "text" { + t.Fatalf("expected nextType text, got %q", 
nextType) + } +} + +// ─── ParseSSEChunkForContent: fragments path with THINK type ───────── + +func TestParseSSEChunkForContentFragmentsAppendThink(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "THINK", + "content": "深入思考...", + }, + }, + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text") + if finished { + t.Fatal("expected not finished") + } + if nextType != "thinking" { + t.Fatalf("expected nextType thinking, got %q", nextType) + } + if len(parts) != 1 || parts[0].Type != "thinking" || parts[0].Text != "深入思考..." { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendEmptyContent(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "RESPONSE", + "content": "", + }, + }, + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "thinking") + if finished { + t.Fatal("expected not finished") + } + if nextType != "text" { + t.Fatalf("expected nextType text, got %q", nextType) + } + if len(parts) != 0 { + t.Fatalf("expected no parts for empty content, got %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendDefaultType(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "UNKNOWN", + "content": "some text", + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, true, "text") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for unknown fragment type, got %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendNonArray(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": "not an array", + } + parts, finished, _ := ParseSSEChunkForContent(chunk, true, "text") + if finished { + t.Fatal("expected not finished") + } + // "not an array" 
should be treated as string value at the end + if len(parts) != 1 || parts[0].Text != "not an array" { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendNonMap(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{"string item"}, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + // Non-map items in fragment array are skipped; the []any itself is handled later + _ = parts // just checking it doesn't panic +} + +// ─── ParseSSEChunkForContent: response path with nested fragment ───── + +func TestParseSSEChunkForContentResponsePathFragmentsAppend(t *testing.T) { + chunk := map[string]any{ + "p": "response", + "v": []any{ + map[string]any{ + "p": "fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "THINKING", + }, + }, + }, + }, + } + _, _, nextType := ParseSSEChunkForContent(chunk, true, "text") + if nextType != "thinking" { + t.Fatalf("expected nextType thinking from response path fragments, got %q", nextType) + } +} + +func TestParseSSEChunkForContentResponsePathResponseFragment(t *testing.T) { + chunk := map[string]any{ + "p": "response", + "v": []any{ + map[string]any{ + "p": "fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "RESPONSE", + }, + }, + }, + }, + } + _, _, nextType := ParseSSEChunkForContent(chunk, true, "thinking") + if nextType != "text" { + t.Fatalf("expected nextType text for RESPONSE fragment, got %q", nextType) + } +} + +// ─── ParseSSEChunkForContent: map value with wrapped response ──────── + +func TestParseSSEChunkForContentMapValueWithFragments(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "response": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "THINK", + "content": "思考...", + }, + map[string]any{ + "type": "RESPONSE", + "content": "回答...", + }, + }, + }, + }, + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text") + if 
finished { + t.Fatal("expected not finished") + } + if nextType != "text" { + t.Fatalf("expected nextType text after RESPONSE, got %q", nextType) + } + if len(parts) != 2 { + t.Fatalf("expected 2 parts, got %d: %#v", len(parts), parts) + } + if parts[0].Type != "thinking" || parts[0].Text != "思考..." { + t.Fatalf("first part mismatch: %#v", parts[0]) + } + if parts[1].Type != "text" || parts[1].Text != "回答..." { + t.Fatalf("second part mismatch: %#v", parts[1]) + } +} + +func TestParseSSEChunkForContentMapValueDirectFragments(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "RESPONSE", + "content": "直接回答", + }, + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + if len(parts) != 1 || parts[0].Text != "直接回答" || parts[0].Type != "text" { + t.Fatalf("unexpected parts for direct fragments: %#v", parts) + } +} + +func TestParseSSEChunkForContentMapValueUnknownType(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "CUSTOM", + "content": "custom content", + }, + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected partType fallback for unknown type, got %#v", parts) + } +} + +func TestParseSSEChunkForContentMapValueEmptyFragmentContent(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "RESPONSE", + "content": "", + }, + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for empty fragment content, got %#v", parts) + } +} + +// ─── ParseSSEChunkForContent: fragments/-1/content path ────────────── + +func TestParseSSEChunkForContentFragmentContentPathInheritsType(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments/-1/content", + "v": "继续思考", + } + parts, _, _ := 
ParseSSEChunkForContent(chunk, true, "thinking") + if len(parts) != 1 || parts[0].Type != "thinking" { + t.Fatalf("expected inherited thinking type, got %#v", parts) + } +} + +// ─── IsCitation edge cases ─────────────────────────────────────────── + +func TestIsCitationWithLeadingWhitespace(t *testing.T) { + if !IsCitation(" [citation:2] text") { + t.Fatal("expected citation true with leading whitespace") + } +} + +func TestIsCitationEmpty(t *testing.T) { + if IsCitation("") { + t.Fatal("expected citation false for empty string") + } +} + +func TestIsCitationSimilarPrefix(t *testing.T) { + if IsCitation("[cite:1] text") { + t.Fatal("expected citation false for [cite: prefix") + } +} + +// ─── extractContentRecursive edge cases ────────────────────────────── + +func TestExtractContentRecursiveFinishedStatus(t *testing.T) { + items := []any{ + map[string]any{"p": "status", "v": "FINISHED"}, + } + parts, finished := extractContentRecursive(items, "text") + if !finished { + t.Fatal("expected finished on status FINISHED") + } + if len(parts) != 0 { + t.Fatalf("expected no parts, got %#v", parts) + } +} + +func TestExtractContentRecursiveSkipsPath(t *testing.T) { + items := []any{ + map[string]any{"p": "token_usage", "v": "data"}, + } + parts, finished := extractContentRecursive(items, "text") + if finished { + t.Fatal("expected not finished") + } + if len(parts) != 0 { + t.Fatalf("expected no parts for skipped path, got %#v", parts) + } +} + +func TestExtractContentRecursiveContentField(t *testing.T) { + items := []any{ + map[string]any{"p": "x", "v": "val", "content": "actual content", "type": "RESPONSE"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 1 || parts[0].Text != "actual content" || parts[0].Type != "text" { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestExtractContentRecursiveContentFieldThinkType(t *testing.T) { + items := []any{ + map[string]any{"p": "x", "v": "val", "content": "think text", "type": "THINK"}, + 
} + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 1 || parts[0].Type != "thinking" { + t.Fatalf("expected thinking type for THINK content, got %#v", parts) + } +} + +func TestExtractContentRecursiveThinkingPath(t *testing.T) { + items := []any{ + map[string]any{"p": "thinking_content", "v": "deep thought"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 1 || parts[0].Type != "thinking" || parts[0].Text != "deep thought" { + t.Fatalf("unexpected parts for thinking path: %#v", parts) + } +} + +func TestExtractContentRecursiveContentPath(t *testing.T) { + items := []any{ + map[string]any{"p": "content", "v": "text content"}, + } + parts, _ := extractContentRecursive(items, "thinking") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for content path, got %#v", parts) + } +} + +func TestExtractContentRecursiveResponsePath(t *testing.T) { + items := []any{ + map[string]any{"p": "response", "v": "text content"}, + } + parts, _ := extractContentRecursive(items, "thinking") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for response path, got %#v", parts) + } +} + +func TestExtractContentRecursiveFragmentsPath(t *testing.T) { + items := []any{ + map[string]any{"p": "fragments", "v": "fragment text"}, + } + parts, _ := extractContentRecursive(items, "thinking") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for fragments path, got %#v", parts) + } +} + +func TestExtractContentRecursiveNestedArrayWithTypes(t *testing.T) { + items := []any{ + map[string]any{ + "p": "fragments", + "v": []any{ + map[string]any{"content": "thought", "type": "THINKING"}, + map[string]any{"content": "answer", "type": "RESPONSE"}, + "raw string", + }, + }, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 3 { + t.Fatalf("expected 3 parts, got %d: %#v", len(parts), parts) + } + if parts[0].Type != "thinking" || 
parts[0].Text != "thought" { + t.Fatalf("first part mismatch: %#v", parts[0]) + } + if parts[1].Type != "text" || parts[1].Text != "answer" { + t.Fatalf("second part mismatch: %#v", parts[1]) + } + if parts[2].Type != "text" || parts[2].Text != "raw string" { + t.Fatalf("third part mismatch: %#v", parts[2]) + } +} + +func TestExtractContentRecursiveEmptyContentSkipped(t *testing.T) { + items := []any{ + map[string]any{ + "p": "fragments", + "v": []any{ + map[string]any{"content": "", "type": "RESPONSE"}, + }, + }, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for empty nested content, got %#v", parts) + } +} + +func TestExtractContentRecursiveFinishedString(t *testing.T) { + items := []any{ + map[string]any{"p": "content", "v": "FINISHED"}, + } + parts, _ := extractContentRecursive(items, "text") + // "FINISHED" string value on non-status path should be skipped + if len(parts) != 0 { + t.Fatalf("expected FINISHED string to be skipped, got %#v", parts) + } +} + +func TestExtractContentRecursiveNoVField(t *testing.T) { + items := []any{ + map[string]any{"p": "content"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for missing v field, got %#v", parts) + } +} + +func TestExtractContentRecursiveNonMapItem(t *testing.T) { + items := []any{"just a string", 42} + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for non-map items, got %#v", parts) + } +} diff --git a/internal/sse/stream_edge_test.go b/internal/sse/stream_edge_test.go new file mode 100644 index 0000000..927b023 --- /dev/null +++ b/internal/sse/stream_edge_test.go @@ -0,0 +1,177 @@ +package sse + +import ( + "context" + "io" + "strings" + "testing" +) + +func TestStartParsedLinePumpEmptyBody(t *testing.T) { + body := strings.NewReader("") + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + 
collected := make([]LineResult, 0) + for r := range results { + collected = append(collected, r) + } + if err := <-done; err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(collected) != 0 { + t.Fatalf("expected no results for empty body, got %d", len(collected)) + } +} + +func TestStartParsedLinePumpMultipleLines(t *testing.T) { + body := strings.NewReader( + "data: {\"p\":\"response/thinking_content\",\"v\":\"think\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"text\"}\n" + + "data: [DONE]\n", + ) + results, done := StartParsedLinePump(context.Background(), body, true, "thinking") + + collected := make([]LineResult, 0) + for r := range results { + collected = append(collected, r) + } + if err := <-done; err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(collected) < 3 { + t.Fatalf("expected at least 3 results, got %d", len(collected)) + } + // First should be thinking + if collected[0].Parts[0].Type != "thinking" { + t.Fatalf("expected first part thinking, got %q", collected[0].Parts[0].Type) + } + // Last should be stop + last := collected[len(collected)-1] + if !last.Stop { + t.Fatal("expected last result to be stop") + } +} + +func TestStartParsedLinePumpTypeTracking(t *testing.T) { + body := strings.NewReader( + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"思\"}]}\n" + + "data: {\"p\":\"response/fragments/-1/content\",\"v\":\"考\"}\n" + + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"答\"}]}\n" + + "data: {\"p\":\"response/fragments/-1/content\",\"v\":\"案\"}\n" + + "data: [DONE]\n", + ) + results, done := StartParsedLinePump(context.Background(), body, true, "text") + + types := make([]string, 0) + for r := range results { + for _, p := range r.Parts { + types = append(types, p.Type) + } + } + <-done + + // Should have: thinking, thinking, text, text + expected := []string{"thinking", "thinking", "text", "text"} + if len(types) 
!= len(expected) { + t.Fatalf("expected types %v, got %v", expected, types) + } + for i, want := range expected { + if types[i] != want { + t.Fatalf("type[%d] mismatch: want %q got %q (all=%v)", i, want, types[i], types) + } + } +} + +func TestStartParsedLinePumpContextCancellation(t *testing.T) { + pr, pw := io.Pipe() + + ctx, cancel := context.WithCancel(context.Background()) + results, done := StartParsedLinePump(ctx, pr, false, "text") + + // Write one line to allow it to start + go func() { + _, _ = io.WriteString(pw, "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n") + // Don't close yet - wait for context cancel + }() + + // Read first result + r := <-results + if !r.Parsed || len(r.Parts) == 0 { + t.Fatalf("expected first parsed result, got %#v", r) + } + + // Cancel context - this will cause the pump to exit on next send + cancel() + // Close the pipe to unblock scanner.Scan() + pw.Close() + + // Drain remaining results + for range results { + } + + err := <-done + // Error may be context.Canceled or nil (if pipe closed first) + if err != nil && err != context.Canceled { + t.Fatalf("expected context.Canceled or nil error, got %v", err) + } +} + +func TestStartParsedLinePumpOnlyDONE(t *testing.T) { + body := strings.NewReader("data: [DONE]\n") + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + collected := make([]LineResult, 0) + for r := range results { + collected = append(collected, r) + } + <-done + + if len(collected) != 1 { + t.Fatalf("expected 1 result, got %d", len(collected)) + } + if !collected[0].Stop { + t.Fatal("expected stop on [DONE]") + } +} + +func TestStartParsedLinePumpNonSSELines(t *testing.T) { + body := strings.NewReader( + "event: update\n" + + ": comment line\n" + + "data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" + + "data: [DONE]\n", + ) + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + var validCount int + for r := range results { + if r.Parsed && 
len(r.Parts) > 0 { + validCount++ + } + } + <-done + + if validCount != 1 { + t.Fatalf("expected 1 valid result, got %d", validCount) + } +} + +func TestStartParsedLinePumpThinkingDisabled(t *testing.T) { + body := strings.NewReader( + "data: {\"p\":\"response/thinking_content\",\"v\":\"thought\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"response\"}\n" + + "data: [DONE]\n", + ) + // With thinking disabled, thinking content should still be emitted but marked differently + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + var parts []ContentPart + for r := range results { + parts = append(parts, r.Parts...) + } + <-done + + if len(parts) < 1 { + t.Fatalf("expected at least 1 part, got %d", len(parts)) + } +} diff --git a/internal/util/util_edge_test.go b/internal/util/util_edge_test.go new file mode 100644 index 0000000..393aa88 --- /dev/null +++ b/internal/util/util_edge_test.go @@ -0,0 +1,441 @@ +package util + +import ( + "encoding/json" + "net/http/httptest" + "strings" + "testing" + + "ds2api/internal/config" +) + +// ─── EstimateTokens edge cases ─────────────────────────────────────── + +func TestEstimateTokensEmpty(t *testing.T) { + if got := EstimateTokens(""); got != 0 { + t.Fatalf("expected 0 for empty string, got %d", got) + } +} + +func TestEstimateTokensShortASCII(t *testing.T) { + got := EstimateTokens("ab") + if got != 1 { + t.Fatalf("expected 1 for 2 ascii chars, got %d", got) + } +} + +func TestEstimateTokensLongASCII(t *testing.T) { + got := EstimateTokens(strings.Repeat("x", 100)) + if got != 25 { + t.Fatalf("expected 25 for 100 ascii chars, got %d", got) + } +} + +func TestEstimateTokensChinese(t *testing.T) { + got := EstimateTokens("你好世界") + if got < 1 { + t.Fatalf("expected at least 1 token for Chinese text, got %d", got) + } +} + +func TestEstimateTokensMixed(t *testing.T) { + got := EstimateTokens("Hello 你好世界") + if got < 2 { + t.Fatalf("expected at least 2 tokens for mixed text, got %d", got) + } 
+} + +func TestEstimateTokensSingleByte(t *testing.T) { + got := EstimateTokens("x") + if got != 1 { + t.Fatalf("expected 1 for single char (minimum), got %d", got) + } +} + +func TestEstimateTokensSingleChinese(t *testing.T) { + got := EstimateTokens("你") + if got != 1 { + t.Fatalf("expected 1 for single Chinese char, got %d", got) + } +} + +// ─── ToBool edge cases ─────────────────────────────────────────────── + +func TestToBoolTrue(t *testing.T) { + if !ToBool(true) { + t.Fatal("expected true") + } +} + +func TestToBoolFalse(t *testing.T) { + if ToBool(false) { + t.Fatal("expected false") + } +} + +func TestToBoolNonBool(t *testing.T) { + if ToBool("true") { + t.Fatal("expected false for string 'true'") + } + if ToBool(1) { + t.Fatal("expected false for int 1") + } + if ToBool(nil) { + t.Fatal("expected false for nil") + } +} + +// ─── IntFrom edge cases ───────────────────────────────────────────── + +func TestIntFromFloat64(t *testing.T) { + if got := IntFrom(float64(42.5)); got != 42 { + t.Fatalf("expected 42 for float64(42.5), got %d", got) + } +} + +func TestIntFromInt(t *testing.T) { + if got := IntFrom(int(42)); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestIntFromInt64(t *testing.T) { + if got := IntFrom(int64(42)); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestIntFromString(t *testing.T) { + if got := IntFrom("42"); got != 0 { + t.Fatalf("expected 0 for string, got %d", got) + } +} + +func TestIntFromNil(t *testing.T) { + if got := IntFrom(nil); got != 0 { + t.Fatalf("expected 0 for nil, got %d", got) + } +} + +// ─── WriteJSON ─────────────────────────────────────────────────────── + +func TestWriteJSON(t *testing.T) { + rec := httptest.NewRecorder() + WriteJSON(rec, 200, map[string]any{"key": "value"}) + if rec.Code != 200 { + t.Fatalf("expected 200, got %d", rec.Code) + } + if ct := rec.Header().Get("Content-Type"); ct != "application/json" { + t.Fatalf("expected application/json content type, got 
%q", ct) + } + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode error: %v", err) + } + if body["key"] != "value" { + t.Fatalf("unexpected body: %#v", body) + } +} + +func TestWriteJSONStatusCodes(t *testing.T) { + for _, code := range []int{200, 201, 400, 404, 500} { + rec := httptest.NewRecorder() + WriteJSON(rec, code, map[string]any{"status": code}) + if rec.Code != code { + t.Fatalf("expected %d, got %d", code, rec.Code) + } + } +} + +// ─── MessagesPrepare edge cases ────────────────────────────────────── + +func TestMessagesPrepareEmpty(t *testing.T) { + got := MessagesPrepare(nil) + if got != "" { + t.Fatalf("expected empty for nil messages, got %q", got) + } +} + +func TestMessagesPrepareMergesConsecutiveSameRole(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Hello"}, + {"role": "user", "content": "World"}, + } + got := MessagesPrepare(messages) + if !strings.Contains(got, "Hello") || !strings.Contains(got, "World") { + t.Fatalf("expected both messages, got %q", got) + } + // Should be merged without <|User|> between them + count := strings.Count(got, "<|User|>") + if count != 0 { + t.Fatalf("expected no User marker for first message pair, got %d occurrences", count) + } +} + +func TestMessagesPrepareAssistantMarkers(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + } + got := MessagesPrepare(messages) + if !strings.Contains(got, "<|Assistant|>") { + t.Fatalf("expected assistant marker, got %q", got) + } + if !strings.Contains(got, "<|end▁of▁sentence|>") { + t.Fatalf("expected end of sentence marker, got %q", got) + } +} + +func TestMessagesPrepareUnknownRole(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Hello"}, + {"role": "unknown_role", "content": "Unknown"}, + } + got := MessagesPrepare(messages) + if !strings.Contains(got, "Unknown") { + 
t.Fatalf("expected unknown role content, got %q", got) + } +} + +func TestMessagesPrepareMarkdownImageReplaced(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Look at this: ![alt](https://example.com/img.png)"}, + } + got := MessagesPrepare(messages) + if strings.Contains(got, "![alt]") { + t.Fatalf("expected markdown image to be replaced, got %q", got) + } +} + +func TestMessagesPrepareNilContent(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": nil}, + } + got := MessagesPrepare(messages) + if got != "null" { + t.Logf("nil content handled as: %q", got) + } +} + +// ─── normalizeContent edge cases ───────────────────────────────────── + +func TestNormalizeContentString(t *testing.T) { + got := normalizeContent("hello") + if got != "hello" { + t.Fatalf("expected 'hello', got %q", got) + } +} + +func TestNormalizeContentArray(t *testing.T) { + got := normalizeContent([]any{ + map[string]any{"type": "text", "text": "line1"}, + map[string]any{"type": "text", "text": "line2"}, + }) + if got != "line1\nline2" { + t.Fatalf("expected 'line1\\nline2', got %q", got) + } +} + +func TestNormalizeContentArrayWithContentField(t *testing.T) { + got := normalizeContent([]any{ + map[string]any{"type": "text", "content": "from-content"}, + }) + if got != "from-content" { + t.Fatalf("expected 'from-content', got %q", got) + } +} + +func TestNormalizeContentArraySkipsImage(t *testing.T) { + got := normalizeContent([]any{ + map[string]any{"type": "image_url", "image_url": "https://example.com/img.png"}, + map[string]any{"type": "text", "text": "caption"}, + }) + if strings.Contains(got, "image") { + t.Fatalf("expected image skipped, got %q", got) + } + if got != "caption" { + t.Fatalf("expected 'caption', got %q", got) + } +} + +func TestNormalizeContentArrayNonMapItems(t *testing.T) { + got := normalizeContent([]any{"string item", 42}) + if got != "" { + t.Fatalf("expected empty for non-map items, got %q", got) + } +} + +func 
TestNormalizeContentJSON(t *testing.T) { + got := normalizeContent(map[string]any{"key": "value"}) + if !strings.Contains(got, `"key":"value"`) { + t.Fatalf("expected JSON serialized, got %q", got) + } +} + +// ─── ConvertClaudeToDeepSeek edge cases ────────────────────────────── + +func TestConvertClaudeToDeepSeekDefaultModel(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["model"] == "" { + t.Fatal("expected default model") + } +} + +func TestConvertClaudeToDeepSeekWithStopSequences(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + "stop_sequences": []any{"\n\n"}, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["stop"] == nil { + t.Fatal("expected stop field from stop_sequences") + } +} + +func TestConvertClaudeToDeepSeekWithTemperature(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + "temperature": 0.7, + "top_p": 0.9, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["temperature"] != 0.7 { + t.Fatalf("expected temperature 0.7, got %v", out["temperature"]) + } + if out["top_p"] != 0.9 { + t.Fatalf("expected top_p 0.9, got %v", out["top_p"]) + } +} + +func TestConvertClaudeToDeepSeekNoSystem(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + } + out := ConvertClaudeToDeepSeek(req, store) + msgs, _ := out["messages"].([]any) + if len(msgs) != 1 { + t.Fatalf("expected 1 message without system, got %d", len(msgs)) + } +} + +func TestConvertClaudeToDeepSeekOpusUsesSlowMapping(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", 
`{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-opus-4-6", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["model"] != "deepseek-reasoner" { + t.Fatalf("expected opus to use slow mapping, got %q", out["model"]) + } +} + +// ─── FormatOpenAIStreamToolCalls ───────────────────────────────────── + +func TestFormatOpenAIStreamToolCalls(t *testing.T) { + formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{ + {Name: "search", Input: map[string]any{"q": "test"}}, + }) + if len(formatted) != 1 { + t.Fatalf("expected 1, got %d", len(formatted)) + } + fn, _ := formatted[0]["function"].(map[string]any) + if fn["name"] != "search" { + t.Fatalf("unexpected function name: %#v", fn) + } + if formatted[0]["index"] != 0 { + t.Fatalf("expected index 0, got %v", formatted[0]["index"]) + } +} + +// ─── ParseToolCalls more edge cases ────────────────────────────────── + +func TestParseToolCallsNoToolNames(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + calls := ParseToolCalls(text, nil) + if len(calls) != 1 { + t.Fatalf("expected 1 call with nil tool names, got %d", len(calls)) + } +} + +func TestParseToolCallsEmptyText(t *testing.T) { + calls := ParseToolCalls("", []string{"search"}) + if len(calls) != 0 { + t.Fatalf("expected 0 calls for empty text, got %d", len(calls)) + } +} + +func TestParseToolCallsMultipleTools(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":{"q":"go"}},{"name":"get_weather","input":{"city":"beijing"}}]}` + calls := ParseToolCalls(text, []string{"search", "get_weather"}) + if len(calls) != 2 { + t.Fatalf("expected 2 calls, got %d", len(calls)) + } +} + +func TestParseToolCallsInputAsString(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":"{\"q\":\"golang\"}"}]}` + calls := 
ParseToolCalls(text, []string{"search"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %d", len(calls)) + } + if calls[0].Input["q"] != "golang" { + t.Fatalf("expected parsed string input, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsWithFunctionWrapper(t *testing.T) { + text := `{"tool_calls":[{"function":{"name":"calc","arguments":{"x":1,"y":2}}}]}` + calls := ParseToolCalls(text, []string{"calc"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %d", len(calls)) + } + if calls[0].Name != "calc" { + t.Fatalf("expected calc, got %q", calls[0].Name) + } +} + +func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) { + fenced := "Here's an example:\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\nDon't execute this." + calls := ParseStandaloneToolCalls(fenced, []string{"search"}) + if len(calls) != 0 { + t.Fatalf("expected fenced code block ignored, got %d calls", len(calls)) + } +} + +// ─── looksLikeToolExampleContext ───────────────────────────────────── + +func TestLooksLikeToolExampleContextChinese(t *testing.T) { + if !looksLikeToolExampleContext("下面是示例") { + t.Fatal("expected true for Chinese example context") + } +} + +func TestLooksLikeToolExampleContextEnglish(t *testing.T) { + if !looksLikeToolExampleContext("here is an example of") { + t.Fatal("expected true for English example context") + } +} + +func TestLooksLikeToolExampleContextNone(t *testing.T) { + if looksLikeToolExampleContext("I will call the tool now") { + t.Fatal("expected false for non-example context") + } +} + +func TestLooksLikeToolExampleContextFenced(t *testing.T) { + if !looksLikeToolExampleContext("```json") { + t.Fatal("expected true for fenced code block context") + } +} From ce74b124d287cd493bac242ffd14d97c4fb4b303 Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 17:16:57 +0800 Subject: [PATCH 07/52] fix: Apply responsive height to the Trash2 icon on large screens. 
--- webui/src/components/AccountManager.jsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui/src/components/AccountManager.jsx b/webui/src/components/AccountManager.jsx index 773b84e..3a8fa76 100644 --- a/webui/src/components/AccountManager.jsx +++ b/webui/src/components/AccountManager.jsx @@ -419,7 +419,7 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch onClick={() => deleteAccount(id)} className="p-1 lg:p-1.5 text-muted-foreground hover:text-destructive hover:bg-destructive/10 rounded-md transition-colors" > - + From 7fc10573aba6f9a6df5123946f86a92a70b7edb8 Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 17:24:43 +0800 Subject: [PATCH 08/52] feat: Improve tool sieve to correctly preserve trailing text within the same chunk as a tool call. --- api/helpers/stream-tool-sieve.js | 2 +- api/helpers/stream-tool-sieve.test.js | 12 ++++++ .../adapter/openai/handler_toolcall_test.go | 41 +++++++++++++++++++ internal/adapter/openai/tool_sieve.go | 2 +- 4 files changed, 55 insertions(+), 2 deletions(-) diff --git a/api/helpers/stream-tool-sieve.js b/api/helpers/stream-tool-sieve.js index 4a713e5..dc40f3a 100644 --- a/api/helpers/stream-tool-sieve.js +++ b/api/helpers/stream-tool-sieve.js @@ -220,7 +220,7 @@ function consumeToolCapture(state, toolNames) { } const prefixPart = captured.slice(0, start); const suffixPart = captured.slice(obj.end); - if (!state.toolNameSent && (hasMeaningfulText(prefixPart) || hasMeaningfulText(suffixPart) || looksLikeToolExampleContext(state.recentTextTail))) { + if (!state.toolNameSent && (hasMeaningfulText(prefixPart) || looksLikeToolExampleContext(state.recentTextTail) || looksLikeToolExampleContext(suffixPart))) { return { ready: true, prefix: captured, diff --git a/api/helpers/stream-tool-sieve.test.js b/api/helpers/stream-tool-sieve.test.js index c085436..fea891f 100644 --- a/api/helpers/stream-tool-sieve.test.js +++ b/api/helpers/stream-tool-sieve.test.js @@ -183,3 
+183,15 @@ test('sieve still intercepts tool call after leading plain text without suffix', assert.equal(leakedText.includes('我将调用工具。'), true); assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); }); + +test('sieve intercepts tool call and preserves trailing same-chunk text', () => { + const events = runSieve( + ['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}然后继续解释。'], + ['read_file'], + ); + const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0)); + const leakedText = collectText(events); + assert.equal(hasTool, true); + assert.equal(leakedText.includes('然后继续解释。'), true); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); +}); diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index 3cab68c..c987991 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -539,6 +539,47 @@ func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) { } } +func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid7c", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + content := strings.Builder{} + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for 
_, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + if c, ok := delta["content"].(string); ok { + content.WriteString(c) + } + } + } + got := content.String() + if !strings.Contains(got, "接下来我会继续说明。") { + t.Fatalf("expected trailing plain text to be preserved, got=%q", got) + } + if strings.Contains(strings.ToLower(got), "tool_calls") { + t.Fatalf("unexpected raw tool json leak, got=%q", got) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + } +} + func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) { h := &Handler{} spaces := strings.Repeat(" ", 200) diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve.go index e5d6b77..b737ff6 100644 --- a/internal/adapter/openai/tool_sieve.go +++ b/internal/adapter/openai/tool_sieve.go @@ -227,7 +227,7 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix } prefixPart := captured[:start] suffixPart := captured[end:] - if !state.toolNameSent && (strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "" || looksLikeToolExampleContext(state.recentTextTail)) { + if !state.toolNameSent && (strings.TrimSpace(prefixPart) != "" || looksLikeToolExampleContext(state.recentTextTail) || looksLikeToolExampleContext(suffixPart)) { return captured, nil, "", true } parsed := util.ParseStandaloneToolCalls(obj, toolNames) From 0348fa8a22084b2aaa68a1ee16b32904ab9460ee Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 20:39:38 +0800 Subject: [PATCH 09/52] feat: Enhance account identification to support email, mobile, and token-only synthetic IDs across API, UI, and documentation. 
--- API.en.md | 6 +- API.md | 6 +- internal/admin/handler_accounts.go | 15 +- .../admin/handler_accounts_identifier_test.go | 138 ++++++++++++++++++ internal/admin/handler_config.go | 1 + internal/admin/helpers.go | 31 ++++ webui/src/components/AccountManager.jsx | 37 +++-- webui/src/components/ApiTester.jsx | 18 ++- webui/src/locales/en.json | 1 + webui/src/locales/zh.json | 1 + 10 files changed, 232 insertions(+), 22 deletions(-) create mode 100644 internal/admin/handler_accounts_identifier_test.go diff --git a/API.en.md b/API.en.md index 1203e12..09149b2 100644 --- a/API.en.md +++ b/API.en.md @@ -439,6 +439,7 @@ Returns sanitized config. "keys": ["k1", "k2"], "accounts": [ { + "identifier": "user@example.com", "email": "user@example.com", "mobile": "", "has_password": true, @@ -499,6 +500,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`. { "items": [ { + "identifier": "user@example.com", "email": "user@example.com", "mobile": "", "has_password": true, @@ -523,7 +525,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`. ### `DELETE /admin/accounts/{identifier}` -`identifier` is email or mobile. +`identifier` can be email, mobile, or the synthetic id for token-only accounts (`token:`). **Response**: `{"success": true, "total_accounts": 5}` @@ -553,7 +555,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`. 
| Field | Required | Notes | | --- | --- | --- | -| `identifier` | ✅ | email or mobile | +| `identifier` | ✅ | email / mobile / token-only synthetic id | | `model` | ❌ | default `deepseek-chat` | | `message` | ❌ | if empty, only session creation is tested | diff --git a/API.md b/API.md index f57f0a8..02cbf9b 100644 --- a/API.md +++ b/API.md @@ -439,6 +439,7 @@ data: {"type":"message_stop"} "keys": ["k1", "k2"], "accounts": [ { + "identifier": "user@example.com", "email": "user@example.com", "mobile": "", "has_password": true, @@ -499,6 +500,7 @@ data: {"type":"message_stop"} { "items": [ { + "identifier": "user@example.com", "email": "user@example.com", "mobile": "", "has_password": true, @@ -523,7 +525,7 @@ data: {"type":"message_stop"} ### `DELETE /admin/accounts/{identifier}` -`identifier` 为 email 或 mobile。 +`identifier` 可为 email、mobile,或 token-only 账号的合成标识(`token:`)。 **响应**:`{"success": true, "total_accounts": 5}` @@ -553,7 +555,7 @@ data: {"type":"message_stop"} | 字段 | 必填 | 说明 | | --- | --- | --- | -| `identifier` | ✅ | email 或 mobile | +| `identifier` | ✅ | email / mobile / token-only 合成标识 | | `model` | ❌ | 默认 `deepseek-chat` | | `message` | ❌ | 空字符串时仅测试会话创建 | diff --git a/internal/admin/handler_accounts.go b/internal/admin/handler_accounts.go index b95077d..5cb88cc 100644 --- a/internal/admin/handler_accounts.go +++ b/internal/admin/handler_accounts.go @@ -56,7 +56,14 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { preview = token } } - items = append(items, map[string]any{"email": acc.Email, "mobile": acc.Mobile, "has_password": acc.Password != "", "has_token": token != "", "token_preview": preview}) + items = append(items, map[string]any{ + "identifier": acc.Identifier(), + "email": acc.Email, + "mobile": acc.Mobile, + "has_password": acc.Password != "", + "has_token": token != "", + "token_preview": preview, + }) } writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, 
"total_pages": totalPages}) } @@ -94,7 +101,7 @@ func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) { err := h.Store.Update(func(c *config.Config) error { idx := -1 for i, a := range c.Accounts { - if a.Email == identifier || a.Mobile == identifier { + if accountMatchesIdentifier(a, identifier) { idx = i break } @@ -122,10 +129,10 @@ func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) { _ = json.NewDecoder(r.Body).Decode(&req) identifier, _ := req["identifier"].(string) if strings.TrimSpace(identifier) == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(email 或 mobile)"}) + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"}) return } - acc, ok := h.Store.FindAccount(identifier) + acc, ok := findAccountByIdentifier(h.Store, identifier) if !ok { writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"}) return diff --git a/internal/admin/handler_accounts_identifier_test.go b/internal/admin/handler_accounts_identifier_test.go new file mode 100644 index 0000000..591d43a --- /dev/null +++ b/internal/admin/handler_accounts_identifier_test.go @@ -0,0 +1,138 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/account" + "ds2api/internal/config" +) + +func newAdminTestHandler(t *testing.T, raw string) *Handler { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", raw) + t.Setenv("CONFIG_JSON", "") + store := config.LoadStore() + return &Handler{ + Store: store, + Pool: account.NewPool(store), + } +} + +func TestListAccountsIncludesTokenOnlyIdentifier(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[{"token":"token-only-account"}] + }`) + + req := httptest.NewRequest(http.MethodGet, "/admin/accounts?page=1&page_size=10", nil) + rec := httptest.NewRecorder() + h.listAccounts(rec, req) + + if rec.Code != http.StatusOK { + 
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String()) + } + + var payload map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode response failed: %v", err) + } + items, _ := payload["items"].([]any) + if len(items) != 1 { + t.Fatalf("expected 1 item, got %d", len(items)) + } + first, _ := items[0].(map[string]any) + identifier, _ := first["identifier"].(string) + if identifier == "" { + t.Fatalf("expected non-empty identifier: %#v", first) + } + if !strings.HasPrefix(identifier, "token:") { + t.Fatalf("expected token synthetic identifier, got %q", identifier) + } +} + +func TestDeleteAccountSupportsTokenOnlyIdentifier(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[{"token":"token-only-account"}] + }`) + accounts := h.Store.Accounts() + if len(accounts) != 1 { + t.Fatalf("expected 1 account, got %d", len(accounts)) + } + id := accounts[0].Identifier() + if id == "" { + t.Fatal("expected token-only synthetic identifier") + } + + r := chi.NewRouter() + r.Delete("/admin/accounts/{identifier}", h.deleteAccount) + req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape(id), nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String()) + } + if got := len(h.Store.Accounts()); got != 0 { + t.Fatalf("expected account removed, remaining=%d", got) + } +} + +func TestDeleteAccountSupportsMobileAlias(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[{"email":"u@example.com","mobile":"13800138000","password":"pwd"}] + }`) + + r := chi.NewRouter() + r.Delete("/admin/accounts/{identifier}", h.deleteAccount) + req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/13800138000", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String()) + } + if 
got := len(h.Store.Accounts()); got != 0 { + t.Fatalf("expected account removed, remaining=%d", got) + } +} + +func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[ + {"email":"u@example.com","mobile":"13800138000","password":"pwd"}, + {"token":"token-only-account"} + ] + }`) + + accByMobile, ok := findAccountByIdentifier(h.Store, "13800138000") + if !ok { + t.Fatal("expected find by mobile") + } + if accByMobile.Email != "u@example.com" { + t.Fatalf("unexpected account by mobile: %#v", accByMobile) + } + + tokenOnlyID := "" + for _, acc := range h.Store.Accounts() { + if strings.TrimSpace(acc.Email) == "" && strings.TrimSpace(acc.Mobile) == "" { + tokenOnlyID = acc.Identifier() + break + } + } + if tokenOnlyID == "" { + t.Fatal("expected token-only account identifier") + } + accByTokenOnly, ok := findAccountByIdentifier(h.Store, tokenOnlyID) + if !ok { + t.Fatalf("expected find by token-only id=%q", tokenOnlyID) + } + if accByTokenOnly.Token != "token-only-account" { + t.Fatalf("unexpected token-only account: %#v", accByTokenOnly) + } +} diff --git a/internal/admin/handler_config.go b/internal/admin/handler_config.go index 7627602..2b672c3 100644 --- a/internal/admin/handler_config.go +++ b/internal/admin/handler_config.go @@ -37,6 +37,7 @@ func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) { } } accounts = append(accounts, map[string]any{ + "identifier": acc.Identifier(), "email": acc.Email, "mobile": acc.Mobile, "has_password": strings.TrimSpace(acc.Password) != "", diff --git a/internal/admin/helpers.go b/internal/admin/helpers.go index fa75b59..d7d1198 100644 --- a/internal/admin/helpers.go +++ b/internal/admin/helpers.go @@ -81,3 +81,34 @@ func statusOr(v int, d int) int { } return v } + +func accountMatchesIdentifier(acc config.Account, identifier string) bool { + id := strings.TrimSpace(identifier) + if id == "" { + return false + } + if strings.TrimSpace(acc.Email) == id 
{ + return true + } + if strings.TrimSpace(acc.Mobile) == id { + return true + } + return acc.Identifier() == id +} + +func findAccountByIdentifier(store *config.Store, identifier string) (config.Account, bool) { + id := strings.TrimSpace(identifier) + if id == "" { + return config.Account{}, false + } + if acc, ok := store.FindAccount(id); ok { + return acc, true + } + accounts := store.Snapshot().Accounts + for _, acc := range accounts { + if accountMatchesIdentifier(acc, id) { + return acc, true + } + } + return config.Account{}, false +} diff --git a/webui/src/components/AccountManager.jsx b/webui/src/components/AccountManager.jsx index 3a8fa76..7ee3b97 100644 --- a/webui/src/components/AccountManager.jsx +++ b/webui/src/components/AccountManager.jsx @@ -39,6 +39,10 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch const [loadingAccounts, setLoadingAccounts] = useState(false) const apiFetch = authFetch || fetch + const resolveAccountIdentifier = (acc) => { + if (!acc || typeof acc !== 'object') return '' + return String(acc.identifier || acc.email || acc.mobile || '').trim() + } const fetchAccounts = async (targetPage = page) => { setLoadingAccounts(true) @@ -147,9 +151,14 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch } const deleteAccount = async (id) => { + const identifier = String(id || '').trim() + if (!identifier) { + onMessage('error', t('accountManager.invalidIdentifier')) + return + } if (!confirm(t('accountManager.deleteAccountConfirm'))) return try { - const res = await apiFetch(`/admin/accounts/${encodeURIComponent(id)}`, { method: 'DELETE' }) + const res = await apiFetch(`/admin/accounts/${encodeURIComponent(identifier)}`, { method: 'DELETE' }) if (res.ok) { onMessage('success', t('messages.deleted')) fetchAccounts() // 刷新当前页 @@ -163,24 +172,29 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch } const testAccount = async (identifier) => { - 
setTesting(prev => ({ ...prev, [identifier]: true })) + const accountID = String(identifier || '').trim() + if (!accountID) { + onMessage('error', t('accountManager.invalidIdentifier')) + return + } + setTesting(prev => ({ ...prev, [accountID]: true })) try { const res = await apiFetch('/admin/accounts/test', { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ identifier }), + body: JSON.stringify({ identifier: accountID }), }) const data = await res.json() const statusMessage = data.success - ? t('apiTester.testSuccess', { account: identifier, time: data.response_time }) - : `${identifier}: ${data.message}` + ? t('apiTester.testSuccess', { account: accountID, time: data.response_time }) + : `${accountID}: ${data.message}` onMessage(data.success ? 'success' : 'error', statusMessage) fetchAccounts() // 刷新当前页 onRefresh() } catch (e) { onMessage('error', t('accountManager.testFailed', { error: e.message })) } finally { - setTesting(prev => ({ ...prev, [identifier]: false })) + setTesting(prev => ({ ...prev, [accountID]: false })) } } @@ -197,7 +211,12 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch for (let i = 0; i < allAccounts.length; i++) { const acc = allAccounts[i] - const id = acc.email || acc.mobile + const id = resolveAccountIdentifier(acc) + if (!id) { + results.push({ id: '-', success: false, message: t('accountManager.invalidIdentifier') }) + setBatchProgress({ current: i + 1, total: allAccounts.length, results: [...results] }) + continue + } try { const res = await apiFetch('/admin/accounts/test', { @@ -387,7 +406,7 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
{t('actions.loading')}
) : accounts.length > 0 ? ( accounts.map((acc, i) => { - const id = acc.email || acc.mobile + const id = resolveAccountIdentifier(acc) return (
@@ -396,7 +415,7 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch acc.has_token ? "bg-emerald-500 shadow-[0_0_8px_rgba(16,185,129,0.5)]" : "bg-amber-500" )} />
-
{id}
+
{id || '-'}
{acc.has_token ? t('accountManager.sessionActive') : t('accountManager.reauthRequired')} {acc.token_preview && ( diff --git a/webui/src/components/ApiTester.jsx b/webui/src/components/ApiTester.jsx index 7d49982..75af1c0 100644 --- a/webui/src/components/ApiTester.jsx +++ b/webui/src/components/ApiTester.jsx @@ -42,6 +42,10 @@ export default function ApiTester({ config, onMessage, authFetch }) { const apiFetch = authFetch || fetch const accounts = config.accounts || [] + const resolveAccountIdentifier = (acc) => { + if (!acc || typeof acc !== 'object') return '' + return String(acc.identifier || acc.email || acc.mobile || '').trim() + } const configuredKeys = config.keys || [] const trimmedApiKey = apiKey.trim() const defaultKey = configuredKeys[0] || '' @@ -297,11 +301,15 @@ return ( onChange={e => setSelectedAccount(e.target.value)} > - {accounts.map((acc, i) => ( - - ))} + {accounts.map((acc, i) => { + const id = resolveAccountIdentifier(acc) + if (!id) return null + return ( + + ) + })}
diff --git a/webui/src/locales/en.json b/webui/src/locales/en.json index 0daf15f..07610f5 100644 --- a/webui/src/locales/en.json +++ b/webui/src/locales/en.json @@ -86,6 +86,7 @@ "requiredFields": "Password and email/mobile are required.", "deleteKeyConfirm": "Are you sure you want to delete this API key?", "deleteAccountConfirm": "Are you sure you want to delete this account?", + "invalidIdentifier": "Invalid account identifier. Operation aborted.", "testAllConfirm": "Test API connectivity for all accounts?", "testAllCompleted": "Completed: {success}/{total} available", "testFailed": "Test failed: {error}", diff --git a/webui/src/locales/zh.json b/webui/src/locales/zh.json index b405ee4..d0780dd 100644 --- a/webui/src/locales/zh.json +++ b/webui/src/locales/zh.json @@ -86,6 +86,7 @@ "requiredFields": "需要填写密码以及邮箱或手机号", "deleteKeyConfirm": "确定要删除此 API 密钥吗?", "deleteAccountConfirm": "确定要删除此账号吗?", + "invalidIdentifier": "账号标识无效,无法执行操作", "testAllConfirm": "测试所有账号的 API 连通性?", "testAllCompleted": "完成:{success}/{total} 可用", "testFailed": "测试失败: {error}", From 27ecb4b69b5e9b41abc103ad0585aea2d01959db Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 21:42:25 +0800 Subject: [PATCH 10/52] feat: Implement response storage and retrieval, add embeddings API, and enhance tool call extraction logic. 
--- .github/workflows/release-artifacts.yml | 48 +++++++------------------ 1 file changed, 12 insertions(+), 36 deletions(-) diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index 504cf7b..67689cc 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -86,11 +86,19 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Extract Docker metadata - id: meta + id: meta_release uses: docker/metadata-action@v5 with: - images: ghcr.io/${{ github.repository }} + images: | + ghcr.io/${{ github.repository }} + cjackhwang/ds2api tags: | type=raw,value=${{ github.event.release.tag_name }} type=raw,value=latest @@ -102,8 +110,8 @@ jobs: file: ./Dockerfile push: true platforms: linux/amd64,linux/arm64 - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} + tags: ${{ steps.meta_release.outputs.tags }} + labels: ${{ steps.meta_release.outputs.labels }} - name: Export Docker image archives for release assets run: | @@ -135,35 +143,3 @@ jobs: dist/*.tar.gz dist/*.zip dist/sha256sums.txt - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Log in to GHCR - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Extract Docker metadata - id: meta - uses: docker/metadata-action@v5 - with: - images: ghcr.io/${{ github.repository }} - tags: | - type=raw,value=${{ github.event.release.tag_name }} - type=raw,value=latest - - - name: Build and Push Docker Image - uses: docker/build-push-action@v6 - with: - context: . 
- file: ./Dockerfile - push: true - platforms: linux/amd64,linux/arm64 - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} From 3a75b75ae0a8ce5d7f056513ffa7b92507ebfbb9 Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 23:06:18 +0800 Subject: [PATCH 11/52] feat: Introduce model alias resolution, enhanced configuration options, and improved OpenAI/Claude adapter handling for responses, embeddings, and tool calls. --- API.en.md | 128 +++++- API.md | 128 +++++- README.MD | 47 +- README.en.md | 47 +- api/chat-stream.js | 4 +- api/helpers/stream-tool-sieve.js | 65 ++- api/helpers/stream-tool-sieve.test.js | 4 +- config.example.json | 26 +- internal/adapter/claude/error_shape_test.go | 35 ++ internal/adapter/claude/handler.go | 52 ++- internal/adapter/openai/embeddings_handler.go | 138 ++++++ internal/adapter/openai/error_shape_test.go | 35 ++ internal/adapter/openai/handler.go | 74 +++- .../adapter/openai/handler_toolcall_test.go | 28 +- internal/adapter/openai/models_route_test.go | 46 ++ internal/adapter/openai/response_store.go | 91 ++++ .../openai/responses_embeddings_test.go | 65 +++ internal/adapter/openai/responses_handler.go | 407 ++++++++++++++++++ internal/adapter/openai/tool_sieve.go | 79 ++-- internal/adapter/openai/vercel_stream.go | 12 +- internal/config/config.go | 118 +++++ internal/config/model_alias_test.go | 44 ++ internal/config/models.go | 113 ++++- internal/server/router.go | 2 +- internal/util/toolcalls.go | 33 +- internal/util/toolcalls_test.go | 7 +- internal/util/util_edge_test.go | 12 - opencode.json.example | 8 +- 28 files changed, 1665 insertions(+), 183 deletions(-) create mode 100644 internal/adapter/claude/error_shape_test.go create mode 100644 internal/adapter/openai/embeddings_handler.go create mode 100644 internal/adapter/openai/error_shape_test.go create mode 100644 internal/adapter/openai/models_route_test.go create mode 100644 internal/adapter/openai/response_store.go create mode 100644 
internal/adapter/openai/responses_embeddings_test.go create mode 100644 internal/adapter/openai/responses_handler.go create mode 100644 internal/config/model_alias_test.go diff --git a/API.en.md b/API.en.md index 09149b2..babd1dc 100644 --- a/API.en.md +++ b/API.en.md @@ -28,7 +28,7 @@ This document describes the actual behavior of the current Go codebase. | Base URL | `http://localhost:5001` or your deployment domain | | Default Content-Type | `application/json` | | Health probes | `GET /healthz`, `GET /readyz` | -| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`) | +| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) | --- @@ -89,7 +89,11 @@ Two header formats accepted: | GET | `/healthz` | None | Liveness probe | | GET | `/readyz` | None | Readiness probe | | GET | `/v1/models` | None | OpenAI model list | +| GET | `/v1/models/{id}` | None | OpenAI single-model query (alias accepted) | | POST | `/v1/chat/completions` | Business | OpenAI chat completions | +| POST | `/v1/responses` | Business | OpenAI Responses API (stream/non-stream) | +| GET | `/v1/responses/{response_id}` | Business | Query stored response (in-memory TTL) | +| POST | `/v1/embeddings` | Business | OpenAI Embeddings API | | GET | `/anthropic/v1/models` | None | Claude model list | | POST | `/anthropic/v1/messages` | Business | Claude messages | | POST | `/anthropic/v1/messages/count_tokens` | Business | Claude token counting | @@ -150,6 +154,15 @@ No auth required. Returns supported models. } ``` +### Model Alias Resolution + +For `chat` / `responses` / `embeddings`, DS2API follows a wide-input/strict-output policy: + +1. Match DeepSeek native model IDs first. +2. Then match exact keys in `model_aliases`. +3. If still unmatched, fall back by known family heuristics (`o*`, `gpt-*`, `claude-*`, etc.). +4. 
If still unmatched, return `invalid_request_error`. + ### `POST /v1/chat/completions` **Headers**: @@ -163,7 +176,7 @@ Content-Type: application/json | Field | Type | Required | Notes | | --- | --- | --- | --- | -| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` | +| `model` | string | ✅ | DeepSeek native models + common aliases (`gpt-4o`, `gpt-5-codex`, `o3`, `claude-sonnet-4-5`, etc.) | | `messages` | array | ✅ | OpenAI-style messages | | `stream` | boolean | ❌ | Default `false` | | `tools` | array | ❌ | Function calling schema | @@ -253,7 +266,63 @@ When `tools` is present, DS2API performs anti-leak handling: } ``` -**Stream**: DS2API buffers text first. If tool call detected → only structured `delta.tool_calls` (each with `index`); otherwise emits buffered text at once. +**Stream**: Once high-confidence toolcall features are matched, DS2API emits `delta.tool_calls` immediately (without waiting for full JSON closure), then keeps sending argument deltas; confirmed raw tool JSON is never forwarded as `delta.content`. + +--- + +### `GET /v1/models/{id}` + +No auth required. Alias values are accepted as path params (for example `gpt-4o`), and the returned object is the mapped DeepSeek model. + +### `POST /v1/responses` + +OpenAI Responses-style endpoint, accepting either `input` or `messages`. + +| Field | Type | Required | Notes | +| --- | --- | --- | --- | +| `model` | string | ✅ | Supports native models + alias mapping | +| `input` | string/array/object | ❌ | One of `input` or `messages` is required | +| `messages` | array | ❌ | One of `input` or `messages` is required | +| `instructions` | string | ❌ | Prepended as a system message | +| `stream` | boolean | ❌ | Default `false` | +| `tools` | array | ❌ | Same tool detection/translation policy as chat | + +**Non-stream**: Returns a standard `response` object with an ID like `resp_xxx`, and stores it in in-memory TTL cache. 
+ +**Stream (SSE)**: minimal event sequence: + +```text +event: response.created +data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...} + +event: response.output_text.delta +data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."} + +event: response.output_tool_call.delta +data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]} + +event: response.completed +data: {"type":"response.completed","response":{...}} + +data: [DONE] +``` + +### `GET /v1/responses/{response_id}` + +Business auth required. Fetches cached responses created by `POST /v1/responses`. + +> Backed by in-memory TTL store. Default TTL is `900s` (configurable via `responses.store_ttl_seconds`). + +### `POST /v1/embeddings` + +Business auth required. Returns OpenAI-compatible embeddings shape. + +| Field | Type | Required | Notes | +| --- | --- | --- | --- | +| `model` | string | ✅ | Supports native models + alias mapping | +| `input` | string/array | ✅ | Supports string, string array, token array | + +> Requires `embeddings.provider`. Current supported values: `mock` / `deterministic` / `builtin`. If missing/unsupported, returns standard error shape with HTTP 501. --- @@ -272,7 +341,10 @@ No auth required. {"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"}, {"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"}, {"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"} - ] + ], + "first_id": "claude-opus-4-6", + "last_id": "claude-instant-1.0", + "has_more": false } ``` @@ -288,13 +360,15 @@ Content-Type: application/json anthropic-version: 2023-06-01 ``` +> `anthropic-version` is optional; DS2API auto-fills `2023-06-01` when absent. 
+ **Request body**: | Field | Type | Required | Notes | | --- | --- | --- | --- | | `model` | string | ✅ | For example `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5` (compatible with `claude-3-5-haiku-latest`), plus historical Claude model IDs | | `messages` | array | ✅ | Claude-style messages | -| `max_tokens` | number | ❌ | Not strictly enforced by upstream bridge | +| `max_tokens` | number | ❌ | Auto-filled to `8192` when omitted; not strictly enforced by upstream bridge | | `stream` | boolean | ❌ | Default `false` | | `system` | string | ❌ | Optional system prompt | | `tools` | array | ❌ | Claude tool schema | @@ -684,13 +758,20 @@ Or manual deploy required: ## Error Payloads -Error formats vary by module: +Compatible routes (`/v1/*`, `/anthropic/*`) use the same error envelope: -| Module | Format | -| --- | --- | -| OpenAI routes | `{"error": {"message": "...", "type": "..."}}` | -| Claude routes | `{"error": {"type": "...", "message": "..."}}` | -| Admin routes | `{"detail": "..."}` | +```json +{ + "error": { + "message": "...", + "type": "invalid_request_error", + "code": "invalid_request", + "param": null + } +} +``` + +Admin routes keep `{"detail":"..."}`. Clients should handle HTTP status code plus `error` / `detail` fields. 
@@ -732,6 +813,31 @@ curl http://localhost:5001/v1/chat/completions \ }' ``` +### OpenAI Responses (Stream) + +```bash +curl http://localhost:5001/v1/responses \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-5-codex", + "input": "Write a hello world in golang", + "stream": true + }' +``` + +### OpenAI Embeddings + +```bash +curl http://localhost:5001/v1/embeddings \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4o", + "input": ["first text", "second text"] + }' +``` + ### OpenAI with Search ```bash diff --git a/API.md b/API.md index 02cbf9b..fa07cfa 100644 --- a/API.md +++ b/API.md @@ -28,7 +28,7 @@ | Base URL | `http://localhost:5001` 或你的部署域名 | | 默认 Content-Type | `application/json` | | 健康检查 | `GET /healthz`、`GET /readyz` | -| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`) | +| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) | --- @@ -89,7 +89,11 @@ Vercel 一键部署可先只填 `DS2API_ADMIN_KEY`,部署后在 `/admin` 导 | GET | `/healthz` | 无 | 存活探针 | | GET | `/readyz` | 无 | 就绪探针 | | GET | `/v1/models` | 无 | OpenAI 模型列表 | +| GET | `/v1/models/{id}` | 无 | OpenAI 单模型查询(支持 alias 入参) | | POST | `/v1/chat/completions` | 业务 | OpenAI 对话补全 | +| POST | `/v1/responses` | 业务 | OpenAI Responses 接口(流式/非流式) | +| GET | `/v1/responses/{response_id}` | 业务 | 查询已生成 response(内存 TTL) | +| POST | `/v1/embeddings` | 业务 | OpenAI Embeddings 接口 | | GET | `/anthropic/v1/models` | 无 | Claude 模型列表 | | POST | `/anthropic/v1/messages` | 业务 | Claude 消息接口 | | POST | `/anthropic/v1/messages/count_tokens` | 业务 | Claude token 计数 | @@ -150,6 +154,15 @@ Vercel 一键部署可先只填 `DS2API_ADMIN_KEY`,部署后在 `/admin` 导 } ``` +### 模型 alias 解析策略 + +对 `chat` / `responses` / `embeddings` 的 `model` 字段采用“宽进严出”: + +1. 先匹配 DeepSeek 原生模型。 +2. 再匹配 `model_aliases` 精确映射。 +3. 
未命中时按模型家族规则回退(如 `o*`、`gpt-*`、`claude-*`)。 +4. 仍未命中则返回 `invalid_request_error`。 + ### `POST /v1/chat/completions` **请求头**: @@ -163,7 +176,7 @@ Content-Type: application/json | 字段 | 类型 | 必填 | 说明 | | --- | --- | --- | --- | -| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` | +| `model` | string | ✅ | 支持 DeepSeek 原生模型 + 常见 alias(如 `gpt-4o`、`gpt-5-codex`、`o3`、`claude-sonnet-4-5`) | | `messages` | array | ✅ | OpenAI 风格消息数组 | | `stream` | boolean | ❌ | 默认 `false` | | `tools` | array | ❌ | Function Calling 定义 | @@ -253,7 +266,63 @@ data: [DONE] } ``` -**流式**:先缓冲正文片段。识别到工具调用 → 仅输出结构化 `delta.tool_calls`(每个 tool call 带 `index`);否则一次性输出普通文本。 +**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。 + +--- + +### `GET /v1/models/{id}` + +无需鉴权。入参支持 alias(例如 `gpt-4o`),返回的是映射后的 DeepSeek 模型对象。 + +### `POST /v1/responses` + +OpenAI Responses 风格接口,兼容 `input` 或 `messages`。 + +| 字段 | 类型 | 必填 | 说明 | +| --- | --- | --- | --- | +| `model` | string | ✅ | 支持原生模型 + alias 自动映射 | +| `input` | string/array/object | ❌ | 与 `messages` 二选一 | +| `messages` | array | ❌ | 与 `input` 二选一 | +| `instructions` | string | ❌ | 自动前置为 system 消息 | +| `stream` | boolean | ❌ | 默认 `false` | +| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 | + +**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。 + +**流式响应(SSE)**:最小事件序列如下。 + +```text +event: response.created +data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...} + +event: response.output_text.delta +data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."} + +event: response.output_tool_call.delta +data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]} + +event: response.completed +data: {"type":"response.completed","response":{...}} + +data: [DONE] +``` + +### `GET /v1/responses/{response_id}` + +需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象。 + +> 当前为内存 TTL 
存储,默认过期时间 `900s`(可用 `responses.store_ttl_seconds` 调整)。 + +### `POST /v1/embeddings` + +需要业务鉴权。返回 OpenAI Embeddings 兼容结构。 + +| 字段 | 类型 | 必填 | 说明 | +| --- | --- | --- | --- | +| `model` | string | ✅ | 支持原生模型 + alias 自动映射 | +| `input` | string/array | ✅ | 支持字符串、字符串数组、token 数组 | + +> 需配置 `embeddings.provider`。当前支持:`mock` / `deterministic` / `builtin`。未配置或不支持时返回标准错误结构(HTTP 501)。 --- @@ -272,7 +341,10 @@ data: [DONE] {"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"}, {"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"}, {"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"} - ] + ], + "first_id": "claude-opus-4-6", + "last_id": "claude-instant-1.0", + "has_more": false } ``` @@ -288,13 +360,15 @@ Content-Type: application/json anthropic-version: 2023-06-01 ``` +> `anthropic-version` 可省略,服务端会自动补为 `2023-06-01`。 + **请求体**: | 字段 | 类型 | 必填 | 说明 | | --- | --- | --- | --- | | `model` | string | ✅ | 例如 `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5`(兼容 `claude-3-5-haiku-latest`),并支持历史 Claude 模型 ID | | `messages` | array | ✅ | Claude 风格消息数组 | -| `max_tokens` | number | ❌ | 当前实现不会硬性截断上游输出 | +| `max_tokens` | number | ❌ | 缺省自动补 `8192`;当前实现不会硬性截断上游输出 | | `stream` | boolean | ❌ | 默认 `false` | | `system` | string | ❌ | 可选系统提示 | | `tools` | array | ❌ | Claude tool 定义 | @@ -684,13 +758,20 @@ data: {"type":"message_stop"} ## 错误响应格式 -不同模块的错误格式略有差异: +兼容路由(`/v1/*`、`/anthropic/*`)统一使用以下结构: -| 模块 | 格式 | -| --- | --- | -| OpenAI 接口 | `{"error": {"message": "...", "type": "..."}}` | -| Claude 接口 | `{"error": {"type": "...", "message": "..."}}` | -| Admin 接口 | `{"detail": "..."}` | +```json +{ + "error": { + "message": "...", + "type": "invalid_request_error", + "code": "invalid_request", + "param": null + } +} +``` + +Admin 接口保持 `{"detail":"..."}`。 建议客户端处理逻辑:检查 HTTP 状态码 + 解析 `error` 或 `detail` 字段。 @@ -732,6 +813,31 @@ curl 
http://localhost:5001/v1/chat/completions \ }' ``` +### OpenAI Responses(流式) + +```bash +curl http://localhost:5001/v1/responses \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-5-codex", + "input": "写一个 golang 的 hello world", + "stream": true + }' +``` + +### OpenAI Embeddings + +```bash +curl http://localhost:5001/v1/embeddings \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4o", + "input": ["第一段文本", "第二段文本"] + }' +``` + ### OpenAI 带搜索 ```bash diff --git a/README.MD b/README.MD index 3517a55..261e34a 100644 --- a/README.MD +++ b/README.MD @@ -54,16 +54,27 @@ flowchart LR | 能力 | 说明 | | --- | --- | -| OpenAI 兼容 | `GET /v1/models`、`POST /v1/chat/completions`(流式/非流式) | +| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings` | | Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens` | | 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 | | 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 | | DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 | -| Tool Calling | 防泄漏处理:自动缓冲、识别、结构化输出 | +| Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 | | Admin API | 配置管理、账号测试 / 批量测试、导入导出、Vercel 同步 | | WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) | | 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) | +## 平台兼容矩阵 + +| 级别 | 平台 | 当前状态 | +| --- | --- | --- | +| P0 | Codex CLI/SDK(`wire_api=chat` / `wire_api=responses`) | ✅ | +| P0 | OpenAI SDK(JS/Python,chat + responses) | ✅ | +| P0 | Vercel AI SDK(openai-compatible) | ✅ | +| P0 | Anthropic SDK(messages) | ✅ | +| P1 | LangChain / LlamaIndex / OpenWebUI(OpenAI 兼容接入) | ✅ | +| P2 | MCP 独立桥接层 | 规划中 | + ## 模型支持 ### OpenAI 接口 @@ -196,6 +207,7 @@ cp opencode.json.example opencode.json 3. 
在项目目录启动 OpenCode CLI(按你的安装方式运行 `opencode`)。 > 建议优先使用 OpenAI 兼容路径(`/v1/*`),即示例里的 `@ai-sdk/openai-compatible` provider。 +> 若客户端支持 `wire_api`,可分别测试 `responses` 与 `chat`,DS2API 两条链路都兼容。 ## 配置说明 @@ -216,6 +228,24 @@ cp opencode.json.example opencode.json "token": "" } ], + "model_aliases": { + "gpt-4o": "deepseek-chat", + "gpt-5-codex": "deepseek-reasoner", + "o3": "deepseek-reasoner" + }, + "compat": { + "wide_input_strict_output": true + }, + "toolcall": { + "mode": "feature_match", + "early_emit_confidence": "high" + }, + "responses": { + "store_ttl_seconds": 900 + }, + "embeddings": { + "provider": "deterministic" + }, "claude_model_mapping": { "fast": "deepseek-chat", "slow": "deepseek-reasoner" @@ -226,6 +256,11 @@ cp opencode.json.example opencode.json - `keys`:API 访问密钥列表,客户端通过 `Authorization: Bearer ` 鉴权 - `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录 - `token`:留空则首次请求时自动登录获取;也可预填已有 token +- `model_aliases`:常见模型名(如 GPT/Codex/Claude)到 DeepSeek 模型的映射 +- `compat.wide_input_strict_output`:建议保持 `true`(当前实现默认宽进严出) +- `toolcall`:固定采用特征匹配 + 高置信早发策略 +- `responses.store_ttl_seconds`:`/v1/responses/{id}` 的内存缓存 TTL +- `embeddings.provider`:embedding 提供方(当前内置 `deterministic/mock/builtin`) - `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型 ### 环境变量 @@ -281,10 +316,10 @@ cp opencode.json.example opencode.json 当请求中带 `tools` 时,DS2API 会做防泄漏处理: -1. `stream=true` 时先**缓冲**正文片段 -2. 若识别到工具调用 → 仅输出结构化 `tool_calls`,不透传原始 JSON 文本 -3. 若最终不是工具调用 → 一次性输出普通文本 -4. 解析器支持混合文本、fenced JSON、`function.arguments` 字符串等格式 +1. 只在**非代码块上下文**启用 toolcall 特征识别(代码块示例不会触发) +2. 一旦命中高置信特征(`tool_calls` + `name` + `arguments/input` 起始)就立即输出 `delta.tool_calls` +3. 已确认的 toolcall JSON 片段不会泄漏到 `delta.content` +4. 
前文/后文自然语言保持顺序透传,支持混合文本与增量参数输出 ## 项目结构 diff --git a/README.en.md b/README.en.md index d1a91a1..5d2f326 100644 --- a/README.en.md +++ b/README.en.md @@ -54,16 +54,27 @@ flowchart LR | Capability | Details | | --- | --- | -| OpenAI compatible | `GET /v1/models`, `POST /v1/chat/completions` (stream/non-stream) | +| OpenAI compatible | `GET /v1/models`, `GET /v1/models/{id}`, `POST /v1/chat/completions`, `POST /v1/responses`, `GET /v1/responses/{response_id}`, `POST /v1/embeddings` | | Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` | | Multi-account rotation | Auto token refresh, email/mobile dual login | | Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency | | DeepSeek PoW | WASM solving via `wazero`, no external Node.js dependency | -| Tool Calling | Anti-leak handling: auto buffer, detect, structured output | +| Tool Calling | Anti-leak handling: non-code-block feature match, early `delta.tool_calls`, structured incremental output | | Admin API | Config management, account testing/batch test, import/export, Vercel sync | | WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) | | Health Probes | `GET /healthz` (liveness), `GET /readyz` (readiness) | +## Platform Compatibility Matrix + +| Tier | Platform | Status | +| --- | --- | --- | +| P0 | Codex CLI/SDK (`wire_api=chat` / `wire_api=responses`) | ✅ | +| P0 | OpenAI SDK (JS/Python, chat + responses) | ✅ | +| P0 | Vercel AI SDK (openai-compatible) | ✅ | +| P0 | Anthropic SDK (messages) | ✅ | +| P1 | LangChain / LlamaIndex / OpenWebUI (OpenAI-compatible integration) | ✅ | +| P2 | MCP standalone bridge | Planned | + ## Model Support ### OpenAI Endpoint @@ -196,6 +207,7 @@ cp opencode.json.example opencode.json 3. Start OpenCode CLI in the project directory (run `opencode` using your installed method). 
> Recommended: use the OpenAI-compatible path (`/v1/*`) via `@ai-sdk/openai-compatible` as shown in the example. +> If your client supports `wire_api`, test both `responses` and `chat`; DS2API supports both paths. ## Configuration @@ -216,6 +228,24 @@ cp opencode.json.example opencode.json "token": "" } ], + "model_aliases": { + "gpt-4o": "deepseek-chat", + "gpt-5-codex": "deepseek-reasoner", + "o3": "deepseek-reasoner" + }, + "compat": { + "wide_input_strict_output": true + }, + "toolcall": { + "mode": "feature_match", + "early_emit_confidence": "high" + }, + "responses": { + "store_ttl_seconds": 900 + }, + "embeddings": { + "provider": "deterministic" + }, "claude_model_mapping": { "fast": "deepseek-chat", "slow": "deepseek-reasoner" @@ -226,6 +256,11 @@ cp opencode.json.example opencode.json - `keys`: API access keys; clients authenticate via `Authorization: Bearer ` - `accounts`: DeepSeek account list, supports `email` or `mobile` login - `token`: Leave empty for auto-login on first request; or pre-fill an existing token +- `model_aliases`: Map common model names (GPT/Codex/Claude) to DeepSeek models +- `compat.wide_input_strict_output`: Keep `true` (current default policy) +- `toolcall`: Fixed to feature matching + high-confidence early emit +- `responses.store_ttl_seconds`: In-memory TTL for `/v1/responses/{id}` +- `embeddings.provider`: Embeddings provider (`deterministic/mock/builtin` built-in) - `claude_model_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models ### Environment Variables @@ -281,10 +316,10 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency) When `tools` is present in the request, DS2API performs anti-leak handling: -1. With `stream=true`, DS2API **buffers** text deltas first -2. If a tool call is detected → only structured `tool_calls` are emitted, raw JSON is not leaked -3. If no tool call → buffered text is emitted at once -4. 
Parser supports mixed text, fenced JSON, and `function.arguments` payloads +1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored) +2. Once high-confidence features are matched (`tool_calls` + `name` + `arguments/input` start), `delta.tool_calls` is emitted immediately +3. Confirmed toolcall JSON fragments are never leaked into `delta.content` +4. Natural language before/after toolcalls keeps original order, with incremental argument output supported ## Project Structure diff --git a/api/chat-stream.js b/api/chat-stream.js index 309c473..680651d 100644 --- a/api/chat-stream.js +++ b/api/chat-stream.js @@ -7,7 +7,7 @@ const { createToolSieveState, processToolSieveChunk, flushToolSieve, - parseStandaloneToolCalls, + parseToolCalls, formatOpenAIStreamToolCalls, } = require('./helpers/stream-tool-sieve'); @@ -199,7 +199,7 @@ module.exports = async function handler(req, res) { await releaseLease(); return; } - const detected = parseStandaloneToolCalls(outputText, toolNames); + const detected = parseToolCalls(outputText, toolNames); if (detected.length > 0 && !toolCallsEmitted) { toolCallsEmitted = true; sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) }); diff --git a/api/helpers/stream-tool-sieve.js b/api/helpers/stream-tool-sieve.js index dc40f3a..44e31cd 100644 --- a/api/helpers/stream-tool-sieve.js +++ b/api/helpers/stream-tool-sieve.js @@ -28,7 +28,6 @@ function createToolSieveState() { pending: '', capture: '', capturing: false, - hasMeaningfulText: false, recentTextTail: '', toolNameSent: false, toolName: '', @@ -192,12 +191,21 @@ function findToolSegmentStart(s) { return -1; } const lower = s.toLowerCase(); - const keyIdx = lower.indexOf('tool_calls'); - if (keyIdx < 0) { - return -1; + let offset = 0; + // eslint-disable-next-line no-constant-condition + while (true) { + const keyRel = lower.indexOf('tool_calls', offset); + if (keyRel < 0) { + return -1; + } + const keyIdx = keyRel; + const 
start = s.slice(0, keyIdx).lastIndexOf('{'); + const candidateStart = start >= 0 ? start : keyIdx; + if (!insideCodeFence(s.slice(0, candidateStart))) { + return candidateStart; + } + offset = keyIdx + 'tool_calls'.length; } - const start = s.slice(0, keyIdx).lastIndexOf('{'); - return start >= 0 ? start : keyIdx; } function consumeToolCapture(state, toolNames) { @@ -220,7 +228,7 @@ function consumeToolCapture(state, toolNames) { } const prefixPart = captured.slice(0, start); const suffixPart = captured.slice(obj.end); - if (!state.toolNameSent && (hasMeaningfulText(prefixPart) || looksLikeToolExampleContext(state.recentTextTail) || looksLikeToolExampleContext(suffixPart))) { + if (insideCodeFence((state.recentTextTail || '') + prefixPart)) { return { ready: true, prefix: captured, @@ -283,7 +291,10 @@ function buildIncrementalToolDeltas(state) { return []; } const start = captured.slice(0, keyIdx).lastIndexOf('{'); - if (start < 0 || hasMeaningfulText(captured.slice(0, start))) { + if (start < 0) { + return []; + } + if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) { return []; } const callStart = findFirstToolCallObjectStart(captured, keyIdx); @@ -621,7 +632,11 @@ function parseToolCalls(text, toolNames) { if (!toStringSafe(text)) { return []; } - const candidates = buildToolCallCandidates(text); + const sanitized = stripFencedCodeBlocks(text); + if (!toStringSafe(sanitized)) { + return []; + } + const candidates = buildToolCallCandidates(sanitized); let parsed = []; for (const c of candidates) { parsed = parseToolCallsPayload(c); @@ -635,11 +650,22 @@ function parseToolCalls(text, toolNames) { return filterToolCalls(parsed, toolNames); } +function stripFencedCodeBlocks(text) { + const t = typeof text === 'string' ? 
text : ''; + if (!t) { + return ''; + } + return t.replace(/```[\s\S]*?```/g, ' '); +} + function parseStandaloneToolCalls(text, toolNames) { const trimmed = toStringSafe(text); if (!trimmed) { return []; } + if ((trimmed.startsWith('```') && trimmed.endsWith('```')) || trimmed.includes('```')) { + return []; + } if (looksLikeToolExampleContext(trimmed)) { return []; } @@ -852,7 +878,6 @@ function noteText(state, text) { if (!state || !hasMeaningfulText(text)) { return; } - state.hasMeaningfulText = true; state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT); } @@ -870,22 +895,16 @@ function appendTail(prev, next, max) { } function looksLikeToolExampleContext(text) { - const t = toStringSafe(text).toLowerCase(); + return insideCodeFence(text); +} + +function insideCodeFence(text) { + const t = typeof text === 'string' ? text : ''; if (!t) { return false; } - const cues = [ - '示例', - '例子', - 'for example', - 'example', - 'demo', - '请勿执行', - '不要执行', - 'do not execute', - '```', - ]; - return cues.some((cue) => t.includes(cue)); + const ticks = (t.match(/```/g) || []).length; + return ticks % 2 === 1; } function hasMeaningfulText(text) { diff --git a/api/helpers/stream-tool-sieve.test.js b/api/helpers/stream-tool-sieve.test.js index fea891f..7f532f1 100644 --- a/api/helpers/stream-tool-sieve.test.js +++ b/api/helpers/stream-tool-sieve.test.js @@ -69,9 +69,7 @@ test('parseToolCalls supports fenced json and function.arguments string payload' '```', ].join('\n'); const calls = parseToolCalls(text, ['read_file']); - assert.equal(calls.length, 1); - assert.equal(calls[0].name, 'read_file'); - assert.deepEqual(calls[0].input, { path: 'README.md' }); + assert.equal(calls.length, 0); }); test('parseStandaloneToolCalls only matches standalone payload and ignores mixed prose', () => { diff --git a/config.example.json b/config.example.json index 7614e77..97161f7 100644 --- a/config.example.json +++ b/config.example.json @@ -24,5 +24,27 @@ 
"password": "your-password-3", "token": "" } - ] -} \ No newline at end of file + ], + "model_aliases": { + "gpt-4o": "deepseek-chat", + "gpt-5-codex": "deepseek-reasoner", + "o3": "deepseek-reasoner" + }, + "compat": { + "wide_input_strict_output": true + }, + "toolcall": { + "mode": "feature_match", + "early_emit_confidence": "high" + }, + "responses": { + "store_ttl_seconds": 900 + }, + "embeddings": { + "provider": "deterministic" + }, + "claude_model_mapping": { + "fast": "deepseek-chat", + "slow": "deepseek-reasoner" + } +} diff --git a/internal/adapter/claude/error_shape_test.go b/internal/adapter/claude/error_shape_test.go new file mode 100644 index 0000000..910fce8 --- /dev/null +++ b/internal/adapter/claude/error_shape_test.go @@ -0,0 +1,35 @@ +package claude + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) { + rec := httptest.NewRecorder() + writeClaudeError(rec, http.StatusUnauthorized, "bad token") + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", rec.Code) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode body: %v", err) + } + errObj, _ := body["error"].(map[string]any) + if errObj["message"] != "bad token" { + t.Fatalf("unexpected message: %v", errObj["message"]) + } + if errObj["type"] != "invalid_request_error" { + t.Fatalf("unexpected type: %v", errObj["type"]) + } + if errObj["code"] != "authentication_failed" { + t.Fatalf("unexpected code: %v", errObj["code"]) + } + if _, ok := errObj["param"]; !ok { + t.Fatal("expected param field") + } +} + diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index b9ecd27..a7d3431 100644 --- a/internal/adapter/claude/handler.go +++ b/internal/adapter/claude/handler.go @@ -43,6 +43,9 @@ func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { } func (h *Handler) Messages(w 
http.ResponseWriter, r *http.Request) { + if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" { + r.Header.Set("anthropic-version", "2023-06-01") + } a, err := h.Auth.Determine(r) if err != nil { status := http.StatusUnauthorized @@ -50,22 +53,25 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { if err == auth.ErrNoAccount { status = http.StatusTooManyRequests } - writeJSON(w, status, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": detail}}) + writeClaudeError(w, status, detail) return } defer h.Auth.Release(a) var req map[string]any if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "invalid json"}}) + writeClaudeError(w, http.StatusBadRequest, "invalid json") return } model, _ := req["model"].(string) messagesRaw, _ := req["messages"].([]any) if model == "" || len(messagesRaw) == 0 { - writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "Request must include 'model' and 'messages'."}}) + writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") return } + if _, ok := req["max_tokens"]; !ok { + req["max_tokens"] = 8192 + } normalized := normalizeClaudeMessages(messagesRaw) payload := cloneMap(req) @@ -86,12 +92,12 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { - writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "invalid token."}}) + writeClaudeError(w, http.StatusUnauthorized, "invalid token.") return } pow, err := h.DS.GetPow(r.Context(), a, 3) if err != nil { - writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get PoW"}}) + writeClaudeError(w, 
http.StatusUnauthorized, "Failed to get PoW") return } requestPayload := map[string]any{ @@ -104,13 +110,13 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { } resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3) if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get Claude response."}}) + writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.") return } if resp.StatusCode != http.StatusOK { defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) - writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}}) + writeClaudeError(w, http.StatusInternalServerError, string(body)) return } @@ -162,20 +168,20 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { a, err := h.Auth.Determine(r) if err != nil { - writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()}) + writeClaudeError(w, http.StatusUnauthorized, err.Error()) return } defer h.Auth.Release(a) var req map[string]any if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"}) + writeClaudeError(w, http.StatusBadRequest, "invalid json") return } model, _ := req["model"].(string) messages, _ := req["messages"].([]any) if model == "" || len(messages) == 0 { - writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."}) + writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") return } inputTokens := 0 @@ -206,7 +212,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - writeJSON(w, 
http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}}) + writeClaudeError(w, http.StatusInternalServerError, string(body)) return } @@ -241,6 +247,8 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ "error": map[string]any{ "type": "api_error", "message": msg, + "code": "internal_error", + "param": nil, }, }) } @@ -492,6 +500,28 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ } } +func writeClaudeError(w http.ResponseWriter, status int, message string) { + code := "invalid_request" + switch status { + case http.StatusUnauthorized: + code = "authentication_failed" + case http.StatusTooManyRequests: + code = "rate_limit_exceeded" + case http.StatusNotFound: + code = "not_found" + case http.StatusInternalServerError: + code = "internal_error" + } + writeJSON(w, status, map[string]any{ + "error": map[string]any{ + "type": "invalid_request_error", + "message": message, + "code": code, + "param": nil, + }, + }) +} + func normalizeClaudeMessages(messages []any) []any { out := make([]any, 0, len(messages)) for _, m := range messages { diff --git a/internal/adapter/openai/embeddings_handler.go b/internal/adapter/openai/embeddings_handler.go new file mode 100644 index 0000000..ff61be0 --- /dev/null +++ b/internal/adapter/openai/embeddings_handler.go @@ -0,0 +1,138 @@ +package openai + +import ( + "crypto/sha256" + "encoding/binary" + "encoding/json" + "fmt" + "net/http" + "strings" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/util" +) + +func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeOpenAIError(w, status, detail) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := 
json.NewDecoder(r.Body).Decode(&req); err != nil { + writeOpenAIError(w, http.StatusBadRequest, "invalid json") + return + } + model, _ := req["model"].(string) + model = strings.TrimSpace(model) + if model == "" { + writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.") + return + } + if _, ok := config.ResolveModel(h.Store, model); !ok { + writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model)) + return + } + + inputs := extractEmbeddingInputs(req["input"]) + if len(inputs) == 0 { + writeOpenAIError(w, http.StatusBadRequest, "Request must include non-empty 'input'.") + return + } + + provider := "" + if h.Store != nil { + provider = strings.ToLower(strings.TrimSpace(h.Store.EmbeddingsProvider())) + } + if provider == "" { + writeOpenAIError(w, http.StatusNotImplemented, "Embeddings provider is not configured. Set embeddings.provider in config.") + return + } + switch provider { + case "mock", "deterministic", "builtin": + // supported local deterministic provider + default: + writeOpenAIError(w, http.StatusNotImplemented, fmt.Sprintf("Embeddings provider '%s' is not supported.", provider)) + return + } + + data := make([]map[string]any, 0, len(inputs)) + totalTokens := 0 + for i, input := range inputs { + totalTokens += util.EstimateTokens(input) + data = append(data, map[string]any{ + "object": "embedding", + "index": i, + "embedding": deterministicEmbedding(input), + }) + } + writeJSON(w, http.StatusOK, map[string]any{ + "object": "list", + "data": data, + "model": model, + "usage": map[string]any{ + "prompt_tokens": totalTokens, + "total_tokens": totalTokens, + }, + }) +} + +func extractEmbeddingInputs(raw any) []string { + switch v := raw.(type) { + case string: + s := strings.TrimSpace(v) + if s == "" { + return nil + } + return []string{s} + case []any: + out := make([]string, 0, len(v)) + for _, item := range v { + switch iv := item.(type) { + case string: + s := strings.TrimSpace(iv) + if s != "" 
{ + out = append(out, s) + } + case []any: + // Token array input support: convert to stable string form. + out = append(out, fmt.Sprintf("%v", iv)) + default: + s := strings.TrimSpace(fmt.Sprintf("%v", iv)) + if s != "" { + out = append(out, s) + } + } + } + return out + default: + return nil + } +} + +func deterministicEmbedding(input string) []float64 { + // Keep response shape stable without external dependencies. + const dims = 64 + out := make([]float64, dims) + seed := sha256.Sum256([]byte(input)) + buf := seed[:] + for i := 0; i < dims; i++ { + if len(buf) < 4 { + next := sha256.Sum256(buf) + buf = next[:] + } + v := binary.BigEndian.Uint32(buf[:4]) + buf = buf[4:] + // map [0, 2^32) -> [-1, 1] + out[i] = (float64(v)/2147483647.5 - 1.0) + } + return out +} diff --git a/internal/adapter/openai/error_shape_test.go b/internal/adapter/openai/error_shape_test.go new file mode 100644 index 0000000..c169e04 --- /dev/null +++ b/internal/adapter/openai/error_shape_test.go @@ -0,0 +1,35 @@ +package openai + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestWriteOpenAIErrorIncludesUnifiedFields(t *testing.T) { + rec := httptest.NewRecorder() + writeOpenAIError(rec, http.StatusBadRequest, "invalid input") + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", rec.Code) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode body: %v", err) + } + errObj, _ := body["error"].(map[string]any) + if errObj["message"] != "invalid input" { + t.Fatalf("unexpected message: %v", errObj["message"]) + } + if errObj["type"] != "invalid_request_error" { + t.Fatalf("unexpected type: %v", errObj["type"]) + } + if errObj["code"] != "invalid_request" { + t.Fatalf("unexpected code: %v", errObj["code"]) + } + if _, ok := errObj["param"]; !ok { + t.Fatal("expected param field") + } +} + diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go 
index 4de28b7..a2a1c4d 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -31,6 +31,8 @@ type Handler struct { leaseMu sync.Mutex streamLeases map[string]streamLease + responsesMu sync.Mutex + responses *responseStore } type streamLease struct { @@ -40,13 +42,27 @@ type streamLease struct { func RegisterRoutes(r chi.Router, h *Handler) { r.Get("/v1/models", h.ListModels) + r.Get("/v1/models/{model_id}", h.GetModel) r.Post("/v1/chat/completions", h.ChatCompletions) + r.Post("/v1/responses", h.Responses) + r.Get("/v1/responses/{response_id}", h.GetResponseByID) + r.Post("/v1/embeddings", h.Embeddings) } func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { writeJSON(w, http.StatusOK, config.OpenAIModelsResponse()) } +func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) { + modelID := strings.TrimSpace(chi.URLParam(r, "model_id")) + model, ok := config.OpenAIModelByID(h.Store, modelID) + if !ok { + writeOpenAIError(w, http.StatusNotFound, "Model not found.") + return + } + writeJSON(w, http.StatusOK, model) +} + func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { if isVercelStreamReleaseRequest(r) { h.handleVercelStreamRelease(w, r) @@ -81,11 +97,16 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") return } - thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model) + resolvedModel, ok := config.ResolveModel(h.Store, model) if !ok { - writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model)) + writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model)) return } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + responseModel := strings.TrimSpace(model) + if responseModel == "" { + responseModel = resolvedModel + } finalPrompt, toolNames := 
buildOpenAIFinalPrompt(messagesRaw, req["tools"]) @@ -111,16 +132,17 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { "thinking_enabled": thinkingEnabled, "search_enabled": searchEnabled, } + applyOpenAIChatPassThrough(req, payload) resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) if err != nil { writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") return } if util.ToBool(req["stream"]) { - h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + h.handleStream(w, r, resp, sessionID, responseModel, finalPrompt, thinkingEnabled, searchEnabled, toolNames) return } - h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, toolNames) + h.handleNonStream(w, r.Context(), resp, sessionID, responseModel, finalPrompt, thinkingEnabled, toolNames) } func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { @@ -135,7 +157,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re finalThinking := result.Thinking finalText := result.Text - detected := util.ParseStandaloneToolCalls(finalText, toolNames) + detected := util.ParseToolCalls(finalText, toolNames) finishReason := "stop" messageObj := map[string]any{"role": "assistant", "content": finalText} if thinkingEnabled && finalThinking != "" { @@ -222,7 +244,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt finalize := func(finishReason string) { finalThinking := thinking.String() finalText := text.String() - detected := util.ParseStandaloneToolCalls(finalText, toolNames) + detected := util.ParseToolCalls(finalText, toolNames) if len(detected) > 0 && !toolCallsEmitted { finishReason = "tool_calls" delta := map[string]any{ @@ -497,6 +519,8 @@ func writeOpenAIError(w http.ResponseWriter, status 
int, message string) { "error": map[string]any{ "message": message, "type": openAIErrorType(status), + "code": openAIErrorCode(status), + "param": nil, }, }) } @@ -520,3 +544,41 @@ func openAIErrorType(status int) string { return "invalid_request_error" } } + +func openAIErrorCode(status int) string { + switch status { + case http.StatusBadRequest: + return "invalid_request" + case http.StatusUnauthorized: + return "authentication_failed" + case http.StatusForbidden: + return "forbidden" + case http.StatusTooManyRequests: + return "rate_limit_exceeded" + case http.StatusNotFound: + return "not_found" + case http.StatusServiceUnavailable: + return "service_unavailable" + default: + if status >= 500 { + return "internal_error" + } + return "invalid_request" + } +} + +func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) { + for _, k := range []string{ + "temperature", + "top_p", + "max_tokens", + "max_completion_tokens", + "presence_penalty", + "frequency_penalty", + "stop", + } { + if v, ok := req[k]; ok { + payload[k] = v + } + } +} diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index c987991..dd2bb0f 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -210,7 +210,7 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { } } -func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) { +func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"下面是示例:"}`, @@ -228,16 +228,16 @@ func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) { out := decodeJSONBody(t, rec.Body.String()) choices, _ := out["choices"].([]any) choice, _ := choices[0].(map[string]any) - if choice["finish_reason"] != "stop" { - t.Fatalf("expected finish_reason=stop, got %#v", 
choice["finish_reason"]) + if choice["finish_reason"] != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"]) } msg, _ := choice["message"].(map[string]any) - if _, ok := msg["tool_calls"]; ok { - t.Fatalf("did not expect tool_calls field for embedded example: %#v", msg["tool_calls"]) + toolCalls, _ := msg["tool_calls"].([]any) + if len(toolCalls) == 0 { + t.Fatalf("expected tool_calls field for embedded example: %#v", msg["tool_calls"]) } - content, _ := msg["content"].(string) - if !strings.Contains(content, "示例") || !strings.Contains(content, `"tool_calls"`) { - t.Fatalf("expected embedded example to pass through as text, got %q", content) + if msg["content"] != nil { + t.Fatalf("expected content nil when tool_calls detected, got %#v", msg["content"]) } } @@ -471,8 +471,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { if !done { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } - if streamHasToolCallsDelta(frames) { - t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String()) + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String()) } content := strings.Builder{} for _, frame := range frames { @@ -489,11 +489,11 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") { t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got) } - if !strings.Contains(got, `"tool_calls"`) { - t.Fatalf("expected mixed stream to preserve embedded tool_calls example text, got=%q", got) + if strings.Contains(strings.ToLower(got), `"tool_calls"`) { + t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got) } - if streamFinishReason(frames) != "stop" { - t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String()) + if streamFinishReason(frames) != "tool_calls" { + 
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String()) } } diff --git a/internal/adapter/openai/models_route_test.go b/internal/adapter/openai/models_route_test.go new file mode 100644 index 0000000..1ba3382 --- /dev/null +++ b/internal/adapter/openai/models_route_test.go @@ -0,0 +1,46 @@ +package openai + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" +) + +func TestGetModelRouteDirectAndAlias(t *testing.T) { + h := &Handler{} + r := chi.NewRouter() + RegisterRoutes(r, h) + + t.Run("direct", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/models/deepseek-chat", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("alias", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/models/gpt-4.1", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 for alias, got %d body=%s", rec.Code, rec.Body.String()) + } + }) +} + +func TestGetModelRouteNotFound(t *testing.T) { + h := &Handler{} + r := chi.NewRouter() + RegisterRoutes(r, h) + + req := httptest.NewRequest(http.MethodGet, "/v1/models/not-exists", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String()) + } +} diff --git a/internal/adapter/openai/response_store.go b/internal/adapter/openai/response_store.go new file mode 100644 index 0000000..4f51dfa --- /dev/null +++ b/internal/adapter/openai/response_store.go @@ -0,0 +1,91 @@ +package openai + +import ( + "sync" + "time" +) + +type storedResponse struct { + Value map[string]any + ExpiresAt time.Time +} + +type responseStore struct { + mu sync.Mutex + ttl time.Duration + items map[string]storedResponse +} + +func newResponseStore(ttl 
time.Duration) *responseStore { + if ttl <= 0 { + ttl = 15 * time.Minute + } + return &responseStore{ + ttl: ttl, + items: make(map[string]storedResponse), + } +} + +func (s *responseStore) put(id string, value map[string]any) { + if s == nil || id == "" || value == nil { + return + } + now := time.Now() + s.mu.Lock() + defer s.mu.Unlock() + s.sweepLocked(now) + s.items[id] = storedResponse{ + Value: cloneAnyMap(value), + ExpiresAt: now.Add(s.ttl), + } +} + +func (s *responseStore) get(id string) (map[string]any, bool) { + if s == nil || id == "" { + return nil, false + } + now := time.Now() + s.mu.Lock() + defer s.mu.Unlock() + s.sweepLocked(now) + item, ok := s.items[id] + if !ok { + return nil, false + } + return cloneAnyMap(item.Value), true +} + +func (s *responseStore) sweepLocked(now time.Time) { + for k, v := range s.items { + if now.After(v.ExpiresAt) { + delete(s.items, k) + } + } +} + +func cloneAnyMap(in map[string]any) map[string]any { + if in == nil { + return nil + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func (h *Handler) getResponseStore() *responseStore { + if h == nil { + return nil + } + h.responsesMu.Lock() + defer h.responsesMu.Unlock() + if h.responses == nil { + ttl := 15 * time.Minute + if h.Store != nil { + ttl = time.Duration(h.Store.ResponsesStoreTTLSeconds()) * time.Second + } + h.responses = newResponseStore(ttl) + } + return h.responses +} diff --git a/internal/adapter/openai/responses_embeddings_test.go b/internal/adapter/openai/responses_embeddings_test.go new file mode 100644 index 0000000..b23597d --- /dev/null +++ b/internal/adapter/openai/responses_embeddings_test.go @@ -0,0 +1,65 @@ +package openai + +import ( + "testing" + "time" +) + +func TestNormalizeResponsesInputAsMessagesString(t *testing.T) { + msgs := normalizeResponsesInputAsMessages("hello") + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + 
if m["role"] != "user" || m["content"] != "hello" { + t.Fatalf("unexpected message: %#v", m) + } +} + +func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) { + req := map[string]any{ + "model": "gpt-4.1", + "input": "ping", + "instructions": "system text", + } + msgs := responsesMessagesFromRequest(req) + if len(msgs) != 2 { + t.Fatalf("expected two messages, got %d", len(msgs)) + } + sys, _ := msgs[0].(map[string]any) + if sys["role"] != "system" { + t.Fatalf("unexpected first message: %#v", sys) + } +} + +func TestExtractEmbeddingInputs(t *testing.T) { + got := extractEmbeddingInputs([]any{"a", "b"}) + if len(got) != 2 || got[0] != "a" || got[1] != "b" { + t.Fatalf("unexpected inputs: %#v", got) + } +} + +func TestDeterministicEmbeddingStable(t *testing.T) { + a := deterministicEmbedding("hello") + b := deterministicEmbedding("hello") + if len(a) != 64 || len(b) != 64 { + t.Fatalf("expected 64 dims, got %d and %d", len(a), len(b)) + } + for i := range a { + if a[i] != b[i] { + t.Fatalf("expected stable embedding at %d: %v != %v", i, a[i], b[i]) + } + } +} + +func TestResponseStorePutGet(t *testing.T) { + st := newResponseStore(100 * time.Millisecond) + st.put("resp_1", map[string]any{"id": "resp_1"}) + got, ok := st.get("resp_1") + if !ok { + t.Fatal("expected stored response") + } + if got["id"] != "resp_1" { + t.Fatalf("unexpected response payload: %#v", got) + } +} diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go new file mode 100644 index 0000000..8fbb132 --- /dev/null +++ b/internal/adapter/openai/responses_handler.go @@ -0,0 +1,407 @@ +package openai + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/sse" + "ds2api/internal/util" +) + +func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) { + id := 
strings.TrimSpace(chi.URLParam(r, "response_id")) + if id == "" { + writeOpenAIError(w, http.StatusBadRequest, "response_id is required.") + return + } + st := h.getResponseStore() + item, ok := st.get(id) + if !ok { + writeOpenAIError(w, http.StatusNotFound, "Response not found.") + return + } + writeJSON(w, http.StatusOK, item) +} + +func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeOpenAIError(w, status, detail) + return + } + defer h.Auth.Release(a) + r = r.WithContext(auth.WithAuth(r.Context(), a)) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeOpenAIError(w, http.StatusBadRequest, "invalid json") + return + } + + model, _ := req["model"].(string) + model = strings.TrimSpace(model) + if model == "" { + writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.") + return + } + resolvedModel, ok := config.ResolveModel(h.Store, model) + if !ok { + writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model)) + return + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + + messagesRaw := responsesMessagesFromRequest(req) + if len(messagesRaw) == 0 { + writeOpenAIError(w, http.StatusBadRequest, "Request must include 'input' or 'messages'.") + return + } + finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + if a.UseConfigToken { + writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") + } else { + writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. 
If this should be a DS2API key, add it to config.keys first.") + } + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") + return + } + payload := map[string]any{ + "chat_session_id": sessionID, + "parent_message_id": nil, + "prompt": finalPrompt, + "ref_file_ids": []any{}, + "thinking_enabled": thinkingEnabled, + "search_enabled": searchEnabled, + } + applyOpenAIChatPassThrough(req, payload) + resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) + if err != nil { + writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") + return + } + + responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "") + if util.ToBool(req["stream"]) { + h.handleResponsesStream(w, r, resp, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + return + } + h.handleResponsesNonStream(w, resp, responseID, model, finalPrompt, thinkingEnabled, toolNames) +} + +func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + result := sse.CollectStream(resp, thinkingEnabled, true) + responseObj := buildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames) + h.getResponseStore().put(responseID, responseObj) + writeJSON(w, http.StatusOK, responseObj) +} + +func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeOpenAIError(w, 
resp.StatusCode, strings.TrimSpace(string(body))) + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + rc := http.NewResponseController(w) + canFlush := rc.Flush() == nil + + sendEvent := func(event string, payload map[string]any) { + b, _ := json.Marshal(payload) + _, _ = w.Write([]byte("event: " + event + "\n")) + _, _ = w.Write([]byte("data: ")) + _, _ = w.Write(b) + _, _ = w.Write([]byte("\n\n")) + if canFlush { + _ = rc.Flush() + } + } + + sendEvent("response.created", map[string]any{ + "type": "response.created", + "id": responseID, + "object": "response", + "model": model, + "status": "in_progress", + }) + + initialType := "text" + if thinkingEnabled { + initialType = "thinking" + } + parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) + bufferToolContent := len(toolNames) > 0 + var sieve toolStreamSieveState + thinking := strings.Builder{} + text := strings.Builder{} + toolCallsEmitted := false + streamToolCallIDs := map[int]string{} + + finalize := func() { + finalThinking := thinking.String() + finalText := text.String() + if bufferToolContent { + for _, evt := range flushToolSieve(&sieve, toolNames) { + if evt.Content != "" { + finalText += evt.Content + sendEvent("response.output_text.delta", map[string]any{ + "type": "response.output_text.delta", + "id": responseID, + "delta": evt.Content, + }) + } + if len(evt.ToolCalls) > 0 { + toolCallsEmitted = true + sendEvent("response.output_tool_call.done", map[string]any{ + "type": "response.output_tool_call.done", + "id": responseID, + "tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls), + }) + } + } + } + obj := buildResponseObject(responseID, model, finalPrompt, finalThinking, finalText, toolNames) + if toolCallsEmitted { + obj["status"] = "completed" + } + 
h.getResponseStore().put(responseID, obj) + sendEvent("response.completed", map[string]any{ + "type": "response.completed", + "response": obj, + }) + _, _ = w.Write([]byte("data: [DONE]\n\n")) + if canFlush { + _ = rc.Flush() + } + } + + for { + select { + case <-r.Context().Done(): + return + case parsed, ok := <-parsedLines: + if !ok { + _ = <-done + finalize() + return + } + if !parsed.Parsed { + continue + } + if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { + finalize() + return + } + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) { + continue + } + if p.Type == "thinking" { + if !thinkingEnabled { + continue + } + thinking.WriteString(p.Text) + sendEvent("response.reasoning.delta", map[string]any{ + "type": "response.reasoning.delta", + "id": responseID, + "delta": p.Text, + }) + continue + } + text.WriteString(p.Text) + if !bufferToolContent { + sendEvent("response.output_text.delta", map[string]any{ + "type": "response.output_text.delta", + "id": responseID, + "delta": p.Text, + }) + continue + } + for _, evt := range processToolSieveChunk(&sieve, p.Text, toolNames) { + if evt.Content != "" { + sendEvent("response.output_text.delta", map[string]any{ + "type": "response.output_text.delta", + "id": responseID, + "delta": evt.Content, + }) + } + if len(evt.ToolCallDeltas) > 0 { + toolCallsEmitted = true + sendEvent("response.output_tool_call.delta", map[string]any{ + "type": "response.output_tool_call.delta", + "id": responseID, + "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs), + }) + } + if len(evt.ToolCalls) > 0 { + toolCallsEmitted = true + sendEvent("response.output_tool_call.done", map[string]any{ + "type": "response.output_tool_call.done", + "id": responseID, + "tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls), + }) + } + } + } + } + } +} + +func buildResponseObject(responseID, model, 
finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + output := make([]any, 0, 2) + if len(detected) > 0 { + toolCalls := make([]any, 0, len(detected)) + for _, tc := range detected { + toolCalls = append(toolCalls, map[string]any{ + "type": "tool_call", + "name": tc.Name, + "arguments": tc.Input, + }) + } + output = append(output, map[string]any{ + "type": "tool_calls", + "tool_calls": toolCalls, + }) + } else { + content := []any{ + map[string]any{ + "type": "output_text", + "text": finalText, + }, + } + if finalThinking != "" { + content = append([]any{map[string]any{ + "type": "reasoning", + "text": finalThinking, + }}, content...) + } + output = append(output, map[string]any{ + "type": "message", + "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "role": "assistant", + "content": content, + }) + } + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "id": responseID, + "type": "response", + "object": "response", + "created_at": time.Now().Unix(), + "status": "completed", + "model": model, + "output": output, + "output_text": finalText, + "usage": map[string]any{ + "input_tokens": promptTokens, + "output_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + }, + } +} + +func responsesMessagesFromRequest(req map[string]any) []any { + if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 { + return prependInstructionMessage(msgs, req["instructions"]) + } + if rawInput, ok := req["input"]; ok { + if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 { + return prependInstructionMessage(msgs, req["instructions"]) + } + } + return nil +} + +func prependInstructionMessage(messages []any, instructions any) []any { + sys, _ := instructions.(string) + sys 
= strings.TrimSpace(sys) + if sys == "" { + return messages + } + out := make([]any, 0, len(messages)+1) + out = append(out, map[string]any{"role": "system", "content": sys}) + out = append(out, messages...) + return out +} + +func normalizeResponsesInputAsMessages(input any) []any { + switch v := input.(type) { + case string: + if strings.TrimSpace(v) == "" { + return nil + } + return []any{map[string]any{"role": "user", "content": v}} + case []any: + if len(v) == 0 { + return nil + } + // If caller already provides role-shaped items, keep as-is. + if first, ok := v[0].(map[string]any); ok { + if _, hasRole := first["role"]; hasRole { + return v + } + } + parts := make([]string, 0, len(v)) + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + parts = append(parts, txt) + continue + } + } + } + if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" { + parts = append(parts, s) + } + } + if len(parts) == 0 { + return nil + } + return []any{map[string]any{"role": "user", "content": strings.Join(parts, "\n")}} + case map[string]any: + if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" { + return []any{map[string]any{"role": "user", "content": txt}} + } + if content, ok := v["content"].(string); ok && strings.TrimSpace(content) != "" { + return []any{map[string]any{"role": "user", "content": content}} + } + } + return nil +} diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve.go index b737ff6..fd7222b 100644 --- a/internal/adapter/openai/tool_sieve.go +++ b/internal/adapter/openai/tool_sieve.go @@ -7,17 +7,16 @@ import ( ) type toolStreamSieveState struct { - pending strings.Builder - capture strings.Builder - capturing bool - hasMeaningfulText bool - recentTextTail string - toolNameSent bool - toolName string - toolArgsStart int - toolArgsSent int 
- toolArgsString bool - toolArgsDone bool + pending strings.Builder + capture strings.Builder + capturing bool + recentTextTail string + toolNameSent bool + toolName string + toolArgsStart int + toolArgsSent int + toolArgsString bool + toolArgsDone bool } type toolStreamEvent struct { @@ -197,14 +196,22 @@ func findToolSegmentStart(s string) int { return -1 } lower := strings.ToLower(s) - keyIdx := strings.Index(lower, "tool_calls") - if keyIdx < 0 { - return -1 + offset := 0 + for { + keyRel := strings.Index(lower[offset:], "tool_calls") + if keyRel < 0 { + return -1 + } + keyIdx := offset + keyRel + start := strings.LastIndex(s[:keyIdx], "{") + if start < 0 { + start = keyIdx + } + if !insideCodeFence(s[:start]) { + return start + } + offset = keyIdx + len("tool_calls") } - if start := strings.LastIndex(s[:keyIdx], "{"); start >= 0 { - return start - } - return keyIdx } func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { @@ -227,7 +234,7 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix } prefixPart := captured[:start] suffixPart := captured[end:] - if !state.toolNameSent && (strings.TrimSpace(prefixPart) != "" || looksLikeToolExampleContext(state.recentTextTail) || looksLikeToolExampleContext(suffixPart)) { + if insideCodeFence(state.recentTextTail + prefixPart) { return captured, nil, "", true } parsed := util.ParseStandaloneToolCalls(obj, toolNames) @@ -293,16 +300,16 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta { if captured == "" { return nil } - if looksLikeToolExampleContext(state.recentTextTail) { - return nil - } lower := strings.ToLower(captured) keyIdx := strings.Index(lower, "tool_calls") if keyIdx < 0 { return nil } start := strings.LastIndex(captured[:keyIdx], "{") - if start < 0 || strings.TrimSpace(captured[:start]) != "" { + if start < 0 { + return nil + } + if 
insideCodeFence(state.recentTextTail + captured[:start]) { return nil } callStart, ok := findFirstToolCallObjectStart(captured, keyIdx) @@ -612,7 +619,6 @@ func (s *toolStreamSieveState) noteText(content string) { if strings.TrimSpace(content) == "" { return } - s.hasMeaningfulText = true s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit) } @@ -628,25 +634,12 @@ func appendTail(prev, next string, max int) string { } func looksLikeToolExampleContext(text string) bool { - t := strings.ToLower(strings.TrimSpace(text)) - if t == "" { + return insideCodeFence(text) +} + +func insideCodeFence(text string) bool { + if text == "" { return false } - cues := []string{ - "示例", - "例子", - "for example", - "example", - "demo", - "请勿执行", - "不要执行", - "do not execute", - "```", - } - for _, cue := range cues { - if strings.Contains(t, cue) { - return true - } - } - return false + return strings.Count(text, "```")%2 == 1 } diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go index 85c9cd8..be8a590 100644 --- a/internal/adapter/openai/vercel_stream.go +++ b/internal/adapter/openai/vercel_stream.go @@ -62,11 +62,16 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") return } - thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model) + resolvedModel, ok := config.ResolveModel(h.Store, model) if !ok { - writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model)) + writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model)) return } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + responseModel := strings.TrimSpace(model) + if responseModel == "" { + responseModel = resolvedModel + } finalPrompt, _ := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) @@ -97,6 +102,7 @@ func (h 
*Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque "thinking_enabled": thinkingEnabled, "search_enabled": searchEnabled, } + applyOpenAIChatPassThrough(req, payload) leaseID := h.holdStreamLease(a) if leaseID == "" { writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease") @@ -106,7 +112,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque writeJSON(w, http.StatusOK, map[string]any{ "session_id": sessionID, "lease_id": leaseID, - "model": model, + "model": responseModel, "final_prompt": finalPrompt, "thinking_enabled": thinkingEnabled, "search_enabled": searchEnabled, diff --git a/internal/config/config.go b/internal/config/config.go index b4058c6..d583159 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -62,11 +62,33 @@ type Config struct { Accounts []Account `json:"accounts,omitempty"` ClaudeMapping map[string]string `json:"claude_mapping,omitempty"` ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"` + ModelAliases map[string]string `json:"model_aliases,omitempty"` + Compat CompatConfig `json:"compat,omitempty"` + Toolcall ToolcallConfig `json:"toolcall,omitempty"` + Responses ResponsesConfig `json:"responses,omitempty"` + Embeddings EmbeddingsConfig `json:"embeddings,omitempty"` VercelSyncHash string `json:"_vercel_sync_hash,omitempty"` VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"` AdditionalFields map[string]any `json:"-"` } +type CompatConfig struct { + WideInputStrictOutput bool `json:"wide_input_strict_output,omitempty"` +} + +type ToolcallConfig struct { + Mode string `json:"mode,omitempty"` + EarlyEmitConfidence string `json:"early_emit_confidence,omitempty"` +} + +type ResponsesConfig struct { + StoreTTLSeconds int `json:"store_ttl_seconds,omitempty"` +} + +type EmbeddingsConfig struct { + Provider string `json:"provider,omitempty"` +} + func (c Config) MarshalJSON() ([]byte, error) { m := map[string]any{} for k, v 
:= range c.AdditionalFields { @@ -84,6 +106,21 @@ func (c Config) MarshalJSON() ([]byte, error) { if len(c.ClaudeModelMap) > 0 { m["claude_model_mapping"] = c.ClaudeModelMap } + if len(c.ModelAliases) > 0 { + m["model_aliases"] = c.ModelAliases + } + if c.Compat.WideInputStrictOutput { + m["compat"] = c.Compat + } + if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" { + m["toolcall"] = c.Toolcall + } + if c.Responses.StoreTTLSeconds > 0 { + m["responses"] = c.Responses + } + if strings.TrimSpace(c.Embeddings.Provider) != "" { + m["embeddings"] = c.Embeddings + } if c.VercelSyncHash != "" { m["_vercel_sync_hash"] = c.VercelSyncHash } @@ -117,6 +154,26 @@ func (c *Config) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil { return fmt.Errorf("invalid field %q: %w", k, err) } + case "model_aliases": + if err := json.Unmarshal(v, &c.ModelAliases); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "compat": + if err := json.Unmarshal(v, &c.Compat); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "toolcall": + if err := json.Unmarshal(v, &c.Toolcall); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "responses": + if err := json.Unmarshal(v, &c.Responses); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "embeddings": + if err := json.Unmarshal(v, &c.Embeddings); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "_vercel_sync_hash": if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil { return fmt.Errorf("invalid field %q: %w", k, err) @@ -141,6 +198,11 @@ func (c Config) Clone() Config { Accounts: slices.Clone(c.Accounts), ClaudeMapping: cloneStringMap(c.ClaudeMapping), ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), + ModelAliases: cloneStringMap(c.ModelAliases), + Compat: c.Compat, + Toolcall: c.Toolcall, + Responses: c.Responses, + 
Embeddings: c.Embeddings, VercelSyncHash: c.VercelSyncHash, VercelSyncTime: c.VercelSyncTime, AdditionalFields: map[string]any{}, @@ -490,3 +552,59 @@ func (s *Store) ClaudeMapping() map[string]string { } return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"} } + +func (s *Store) ModelAliases() map[string]string { + s.mu.RLock() + defer s.mu.RUnlock() + out := DefaultModelAliases() + for k, v := range s.cfg.ModelAliases { + key := strings.TrimSpace(lower(k)) + val := strings.TrimSpace(lower(v)) + if key == "" || val == "" { + continue + } + out[key] = val + } + return out +} + +func (s *Store) CompatWideInputStrictOutput() bool { + // Current default policy is always wide-input / strict-output. + // Kept as a method so callers do not depend on storage shape. + return true +} + +func (s *Store) ToolcallMode() string { + s.mu.RLock() + defer s.mu.RUnlock() + mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode)) + if mode == "" { + return "feature_match" + } + return mode +} + +func (s *Store) ToolcallEarlyEmitConfidence() string { + s.mu.RLock() + defer s.mu.RUnlock() + level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence)) + if level == "" { + return "high" + } + return level +} + +func (s *Store) ResponsesStoreTTLSeconds() int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Responses.StoreTTLSeconds > 0 { + return s.cfg.Responses.StoreTTLSeconds + } + return 900 +} + +func (s *Store) EmbeddingsProvider() string { + s.mu.RLock() + defer s.mu.RUnlock() + return strings.TrimSpace(s.cfg.Embeddings.Provider) +} diff --git a/internal/config/model_alias_test.go b/internal/config/model_alias_test.go new file mode 100644 index 0000000..89e74b0 --- /dev/null +++ b/internal/config/model_alias_test.go @@ -0,0 +1,44 @@ +package config + +import "testing" + +func TestResolveModelDirectDeepSeek(t *testing.T) { + got, ok := ResolveModel(nil, "deepseek-chat") + if !ok || got != "deepseek-chat" { + t.Fatalf("expected 
deepseek-chat, got ok=%v model=%q", ok, got) + } +} + +func TestResolveModelAlias(t *testing.T) { + got, ok := ResolveModel(nil, "gpt-4.1") + if !ok || got != "deepseek-chat" { + t.Fatalf("expected alias gpt-4.1 -> deepseek-chat, got ok=%v model=%q", ok, got) + } +} + +func TestResolveModelHeuristicReasoner(t *testing.T) { + got, ok := ResolveModel(nil, "o3-super") + if !ok || got != "deepseek-reasoner" { + t.Fatalf("expected heuristic reasoner, got ok=%v model=%q", ok, got) + } +} + +func TestResolveModelUnknown(t *testing.T) { + _, ok := ResolveModel(nil, "totally-custom-model") + if ok { + t.Fatal("expected unknown model to fail resolve") + } +} + +func TestClaudeModelsResponsePaginationFields(t *testing.T) { + resp := ClaudeModelsResponse() + if _, ok := resp["first_id"]; !ok { + t.Fatalf("expected first_id in response: %#v", resp) + } + if _, ok := resp["last_id"]; !ok { + t.Fatalf("expected last_id in response: %#v", resp) + } + if _, ok := resp["has_more"]; !ok { + t.Fatalf("expected has_more in response: %#v", resp) + } +} diff --git a/internal/config/models.go b/internal/config/models.go index 13fa63d..017a2ee 100644 --- a/internal/config/models.go +++ b/internal/config/models.go @@ -1,5 +1,7 @@ package config +import "strings" + type ModelInfo struct { ID string `json:"id"` Object string `json:"object"` @@ -71,6 +73,91 @@ func GetModelConfig(model string) (thinking bool, search bool, ok bool) { } } +func IsSupportedDeepSeekModel(model string) bool { + _, _, ok := GetModelConfig(model) + return ok +} + +func DefaultModelAliases() map[string]string { + return map[string]string{ + "gpt-4o": "deepseek-chat", + "gpt-4.1": "deepseek-chat", + "gpt-4.1-mini": "deepseek-chat", + "gpt-4.1-nano": "deepseek-chat", + "gpt-5": "deepseek-chat", + "gpt-5-mini": "deepseek-chat", + "gpt-5-codex": "deepseek-reasoner", + "o1": "deepseek-reasoner", + "o1-mini": "deepseek-reasoner", + "o3": "deepseek-reasoner", + "o3-mini": "deepseek-reasoner", + "claude-sonnet-4-5": 
"deepseek-chat", + "claude-haiku-4-5": "deepseek-chat", + "claude-opus-4-6": "deepseek-reasoner", + "claude-3-5-sonnet": "deepseek-chat", + "claude-3-5-haiku": "deepseek-chat", + "claude-3-opus": "deepseek-reasoner", + "gemini-2.5-pro": "deepseek-chat", + "gemini-2.5-flash": "deepseek-chat", + "llama-3.1-70b-instruct": "deepseek-chat", + "qwen-max": "deepseek-chat", + } +} + +func ResolveModel(store *Store, requested string) (string, bool) { + model := lower(strings.TrimSpace(requested)) + if model == "" { + return "", false + } + if IsSupportedDeepSeekModel(model) { + return model, true + } + aliases := DefaultModelAliases() + if store != nil { + for k, v := range store.ModelAliases() { + aliases[lower(strings.TrimSpace(k))] = lower(strings.TrimSpace(v)) + } + } + if mapped, ok := aliases[model]; ok && IsSupportedDeepSeekModel(mapped) { + return mapped, true + } + if strings.HasPrefix(model, "deepseek-") { + return "", false + } + + knownFamily := false + for _, prefix := range []string{ + "gpt-", "o1", "o3", "claude-", "gemini-", "llama-", "qwen-", "mistral-", "command-", + } { + if strings.HasPrefix(model, prefix) { + knownFamily = true + break + } + } + if !knownFamily { + return "", false + } + + useReasoner := strings.Contains(model, "reason") || + strings.Contains(model, "reasoner") || + strings.HasPrefix(model, "o1") || + strings.HasPrefix(model, "o3") || + strings.Contains(model, "opus") || + strings.Contains(model, "r1") + useSearch := strings.Contains(model, "search") + + switch { + case useReasoner && useSearch: + return "deepseek-reasoner-search", true + case useReasoner: + return "deepseek-reasoner", true + case useSearch: + return "deepseek-chat-search", true + default: + return "deepseek-chat", true + } +} + func lower(s string) string { b := []byte(s) for i, c := range b { @@ -85,6 +172,28 @@ func OpenAIModelsResponse() map[string]any { return map[string]any{"object": "list", "data": DeepSeekModels} } -func ClaudeModelsResponse() map[string]any { - 
return map[string]any{"object": "list", "data": ClaudeModels} +func OpenAIModelByID(store *Store, id string) (ModelInfo, bool) { + canonical, ok := ResolveModel(store, id) + if !ok { + return ModelInfo{}, false + } + for _, model := range DeepSeekModels { + if model.ID == canonical { + return model, true + } + } + return ModelInfo{}, false +} + +func ClaudeModelsResponse() map[string]any { + resp := map[string]any{"object": "list", "data": ClaudeModels} + if len(ClaudeModels) > 0 { + resp["first_id"] = ClaudeModels[0].ID + resp["last_id"] = ClaudeModels[len(ClaudeModels)-1].ID + } else { + resp["first_id"] = nil + resp["last_id"] = nil + } + resp["has_more"] = false + return resp } diff --git a/internal/server/router.go b/internal/server/router.go index c6339fb..a81f0cb 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -92,7 +92,7 @@ func cors(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return diff --git a/internal/util/toolcalls.go b/internal/util/toolcalls.go index decb96e..9e44b94 100644 --- a/internal/util/toolcalls.go +++ b/internal/util/toolcalls.go @@ -10,6 +10,7 @@ import ( var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`) var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```") +var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```") type ParsedToolCall struct { Name string `json:"name"` @@ -20,6 +21,10 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall { if 
strings.TrimSpace(text) == "" { return nil } + text = stripFencedCodeBlocks(text) + if strings.TrimSpace(text) == "" { + return nil + } candidates := buildToolCallCandidates(text) var parsed []ParsedToolCall @@ -45,11 +50,6 @@ func ParseStandaloneToolCalls(text string, availableToolNames []string) []Parsed return nil } candidates := []string{trimmed} - if strings.HasPrefix(trimmed, "```") && strings.HasSuffix(trimmed, "```") { - if m := fencedJSONPattern.FindStringSubmatch(trimmed); len(m) >= 2 { - candidates = append(candidates, strings.TrimSpace(m[1])) - } - } for _, candidate := range candidates { candidate = strings.TrimSpace(candidate) if candidate == "" { @@ -321,23 +321,14 @@ func looksLikeToolExampleContext(text string) bool { if t == "" { return false } - cues := []string{ - "```", - "示例", - "例子", - "for example", - "example", - "demo", - "请勿执行", - "不要执行", - "do not execute", + return strings.Contains(t, "```") +} + +func stripFencedCodeBlocks(text string) string { + if strings.TrimSpace(text) == "" { + return "" } - for _, cue := range cues { - if strings.Contains(t, cue) { - return true - } - } - return false + return fencedBlockPattern.ReplaceAllString(text, " ") } func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any { diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index 509299c..f7c82d2 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -19,11 +19,8 @@ func TestParseToolCalls(t *testing.T) { func TestParseToolCallsFromFencedJSON(t *testing.T) { text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```" calls := ParseToolCalls(text, []string{"search"}) - if len(calls) != 1 { - t.Fatalf("expected 1 call, got %d", len(calls)) - } - if calls[0].Input["q"] != "news" { - t.Fatalf("unexpected args: %#v", calls[0].Input) + if len(calls) != 0 { + t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls) } } 
diff --git a/internal/util/util_edge_test.go b/internal/util/util_edge_test.go index 393aa88..cba0ceb 100644 --- a/internal/util/util_edge_test.go +++ b/internal/util/util_edge_test.go @@ -416,18 +416,6 @@ func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) { // ─── looksLikeToolExampleContext ───────────────────────────────────── -func TestLooksLikeToolExampleContextChinese(t *testing.T) { - if !looksLikeToolExampleContext("下面是示例") { - t.Fatal("expected true for Chinese example context") - } -} - -func TestLooksLikeToolExampleContextEnglish(t *testing.T) { - if !looksLikeToolExampleContext("here is an example of") { - t.Fatal("expected true for English example context") - } -} - func TestLooksLikeToolExampleContextNone(t *testing.T) { if looksLikeToolExampleContext("I will call the tool now") { t.Fatal("expected false for non-example context") diff --git a/opencode.json.example b/opencode.json.example index 2933e9f..ed18a63 100644 --- a/opencode.json.example +++ b/opencode.json.example @@ -9,6 +9,12 @@ "apiKey": "your-api-key" }, "models": { + "gpt-4o": { + "name": "GPT-4o (aliased to deepseek-chat)" + }, + "gpt-5-codex": { + "name": "GPT-5 Codex (aliased to deepseek-reasoner)" + }, "deepseek-chat": { "name": "DeepSeek Chat (DS2API)" }, @@ -18,5 +24,5 @@ } } }, - "model": "ds2api/deepseek-chat" + "model": "ds2api/gpt-5-codex" } From eb253a9d3aefd7e097e803a86560f36ad71063a5 Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 23:35:17 +0800 Subject: [PATCH 12/52] feat: Introduce standard request normalization and response building for OpenAI and Claude, enhance tool call streaming, and improve caller identification. 
--- API.en.md | 2 +- API.md | 2 +- internal/adapter/claude/handler.go | 50 ++----- internal/adapter/claude/standard_request.go | 58 ++++++++ .../adapter/claude/standard_request_test.go | 38 ++++++ .../adapter/openai/embeddings_route_test.go | 96 +++++++++++++ internal/adapter/openai/handler.go | 71 +++++----- internal/adapter/openai/response_store.go | 30 ++++- .../openai/responses_embeddings_test.go | 12 +- internal/adapter/openai/responses_handler.go | 76 ++++++----- .../adapter/openai/responses_route_test.go | 125 +++++++++++++++++ internal/adapter/openai/standard_request.go | 104 +++++++++++++++ .../adapter/openai/standard_request_test.go | 60 +++++++++ internal/adapter/openai/vercel_stream.go | 37 ++--- internal/auth/request.go | 22 ++- internal/auth/request_test.go | 21 +++ internal/testsuite/runner.go | 126 ++++++++++++++++++ internal/util/standard_request.go | 30 +++++ 18 files changed, 805 insertions(+), 155 deletions(-) create mode 100644 internal/adapter/claude/standard_request.go create mode 100644 internal/adapter/claude/standard_request_test.go create mode 100644 internal/adapter/openai/embeddings_route_test.go create mode 100644 internal/adapter/openai/responses_route_test.go create mode 100644 internal/adapter/openai/standard_request.go create mode 100644 internal/adapter/openai/standard_request_test.go create mode 100644 internal/util/standard_request.go diff --git a/API.en.md b/API.en.md index babd1dc..ef1a6f3 100644 --- a/API.en.md +++ b/API.en.md @@ -309,7 +309,7 @@ data: [DONE] ### `GET /v1/responses/{response_id}` -Business auth required. Fetches cached responses created by `POST /v1/responses`. +Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read). > Backed by in-memory TTL store. Default TTL is `900s` (configurable via `responses.store_ttl_seconds`). 
diff --git a/API.md b/API.md index fa07cfa..3770924 100644 --- a/API.md +++ b/API.md @@ -309,7 +309,7 @@ data: [DONE] ### `GET /v1/responses/{response_id}` -需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象。 +需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)。 > 当前为内存 TTL 存储,默认过期时间 `900s`(可用 `responses.store_ttl_seconds` 调整)。 diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index a7d3431..44240af 100644 --- a/internal/adapter/claude/handler.go +++ b/internal/adapter/claude/handler.go @@ -63,32 +63,12 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { writeClaudeError(w, http.StatusBadRequest, "invalid json") return } - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == "" || len(messagesRaw) == 0 { - writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") + norm, err := normalizeClaudeRequest(h.Store, req) + if err != nil { + writeClaudeError(w, http.StatusBadRequest, err.Error()) return } - if _, ok := req["max_tokens"]; !ok { - req["max_tokens"] = 8192 - } - - normalized := normalizeClaudeMessages(messagesRaw) - payload := cloneMap(req) - payload["messages"] = normalized - toolsRequested, _ := req["tools"].([]any) - if len(toolsRequested) > 0 && !hasSystemMessage(normalized) { - payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...) 
- } - - dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store) - dsModel, _ := dsPayload["model"].(string) - thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel) - if !ok { - thinkingEnabled = false - searchEnabled = false - } - finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"])) + stdReq := norm.Standard sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { @@ -100,14 +80,7 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW") return } - requestPayload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } + requestPayload := stdReq.CompletionPayload(sessionID) resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3) if err != nil { writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.") @@ -120,15 +93,14 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { return } - toolNames := extractClaudeToolNames(toolsRequested) - if util.ToBool(req["stream"]) { - h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames) + if stdReq.Stream { + h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) return } - result := sse.CollectStream(resp, thinkingEnabled, true) + result := sse.CollectStream(resp, stdReq.Thinking, true) fullText := result.Text fullThinking := result.Thinking - detected := util.ParseToolCalls(fullText, toolNames) + detected := util.ParseToolCalls(fullText, stdReq.ToolNames) content := make([]map[string]any, 0, 4) if fullThinking != "" { content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking}) @@ -154,12 +126,12 @@ func (h *Handler) Messages(w 
http.ResponseWriter, r *http.Request) { "id": fmt.Sprintf("msg_%d", time.Now().UnixNano()), "type": "message", "role": "assistant", - "model": model, + "model": stdReq.ResponseModel, "content": content, "stop_reason": stopReason, "stop_sequence": nil, "usage": map[string]any{ - "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)), + "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", norm.NormalizedMessages)), "output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText), }, }) diff --git a/internal/adapter/claude/standard_request.go b/internal/adapter/claude/standard_request.go new file mode 100644 index 0000000..de97c6a --- /dev/null +++ b/internal/adapter/claude/standard_request.go @@ -0,0 +1,58 @@ +package claude + +import ( + "fmt" + "strings" + + "ds2api/internal/config" + "ds2api/internal/util" +) + +type claudeNormalizedRequest struct { + Standard util.StandardRequest + NormalizedMessages []any +} + +func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNormalizedRequest, error) { + model, _ := req["model"].(string) + messagesRaw, _ := req["messages"].([]any) + if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { + return claudeNormalizedRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.") + } + if _, ok := req["max_tokens"]; !ok { + req["max_tokens"] = 8192 + } + normalizedMessages := normalizeClaudeMessages(messagesRaw) + payload := cloneMap(req) + payload["messages"] = normalizedMessages + toolsRequested, _ := req["tools"].([]any) + if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) { + payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...) 
+ } + + dsPayload := util.ConvertClaudeToDeepSeek(payload, store) + dsModel, _ := dsPayload["model"].(string) + thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel) + if !ok { + thinkingEnabled = false + searchEnabled = false + } + finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"])) + toolNames := extractClaudeToolNames(toolsRequested) + + return claudeNormalizedRequest{ + Standard: util.StandardRequest{ + Surface: "anthropic_messages", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: dsModel, + ResponseModel: strings.TrimSpace(model), + Messages: payload["messages"].([]any), + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + }, + NormalizedMessages: normalizedMessages, + }, nil +} diff --git a/internal/adapter/claude/standard_request_test.go b/internal/adapter/claude/standard_request_test.go new file mode 100644 index 0000000..7ffdfb8 --- /dev/null +++ b/internal/adapter/claude/standard_request_test.go @@ -0,0 +1,38 @@ +package claude + +import ( + "testing" + + "ds2api/internal/config" +) + +func TestNormalizeClaudeRequest(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-opus-4-6", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "stream": true, + "tools": []any{ + map[string]any{"name": "search", "description": "Search"}, + }, + } + norm, err := normalizeClaudeRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if norm.Standard.ResolvedModel == "" { + t.Fatalf("expected resolved model") + } + if !norm.Standard.Stream { + t.Fatalf("expected stream=true") + } + if len(norm.Standard.ToolNames) == 0 { + t.Fatalf("expected tool names") + } + if norm.Standard.FinalPrompt == "" { + t.Fatalf("expected non-empty final prompt") + } +} diff --git 
a/internal/adapter/openai/embeddings_route_test.go b/internal/adapter/openai/embeddings_route_test.go new file mode 100644 index 0000000..4395d16 --- /dev/null +++ b/internal/adapter/openai/embeddings_route_test.go @@ -0,0 +1,96 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/account" + "ds2api/internal/auth" + "ds2api/internal/config" +) + +func newResolverWithConfigJSON(t *testing.T, cfgJSON string) (*config.Store, *auth.Resolver) { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", cfgJSON) + store := config.LoadStore() + pool := account.NewPool(store) + resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "unused", nil + }) + return store, resolver +} + +func TestEmbeddingsRouteContract(t *testing.T) { + store, resolver := newResolverWithConfigJSON(t, `{"embeddings":{"provider":"deterministic"}}`) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + t.Run("unauthorized", func(t *testing.T) { + body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`) + req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("ok", func(t *testing.T) { + body := bytes.NewBufferString(`{"model":"gpt-4o","input":["a","b"]}`) + req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body) + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := 
json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + if out["object"] != "list" { + t.Fatalf("unexpected object: %#v", out["object"]) + } + data, _ := out["data"].([]any) + if len(data) != 2 { + t.Fatalf("expected 2 embeddings, got %d", len(data)) + } + }) +} + +func TestEmbeddingsRouteProviderMissing(t *testing.T) { + store, resolver := newResolverWithConfigJSON(t, `{}`) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`) + req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body) + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusNotImplemented { + t.Fatalf("expected 501, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + errObj, _ := out["error"].(map[string]any) + if _, ok := errObj["code"]; !ok { + t.Fatalf("expected error.code in response: %#v", out) + } + if _, ok := errObj["param"]; !ok { + t.Fatalf("expected error.param in response: %#v", out) + } +} diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index a2a1c4d..fadca38 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -91,24 +91,11 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusBadRequest, "invalid json") return } - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == "" || len(messagesRaw) == 0 { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") + stdReq, err := normalizeOpenAIChatRequest(h.Store, req) + if err != nil { + 
writeOpenAIError(w, http.StatusBadRequest, err.Error()) return } - resolvedModel, ok := config.ResolveModel(h.Store, model) - if !ok { - writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model)) - return - } - thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) - responseModel := strings.TrimSpace(model) - if responseModel == "" { - responseModel = resolvedModel - } - - finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { @@ -124,25 +111,17 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") return } - payload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } - applyOpenAIChatPassThrough(req, payload) + payload := stdReq.CompletionPayload(sessionID) resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) if err != nil { writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") return } - if util.ToBool(req["stream"]) { - h.handleStream(w, r, resp, sessionID, responseModel, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + if stdReq.Stream { + h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) return } - h.handleNonStream(w, r.Context(), resp, sessionID, responseModel, finalPrompt, thinkingEnabled, toolNames) + h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) } func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) 
{ @@ -208,7 +187,8 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt created := time.Now().Unix() firstChunkSent := false - bufferToolContent := len(toolNames) > 0 + bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() + emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() var toolSieve toolStreamSieveState toolCallsEmitted := false streamToolCallIDs := map[int]string{} @@ -377,6 +357,9 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt } for _, evt := range events { if len(evt.ToolCallDeltas) > 0 { + if !emitEarlyToolDeltas { + continue + } toolCallsEmitted = true tcDelta := map[string]any{ "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs), @@ -568,17 +551,23 @@ func openAIErrorCode(status int) string { } func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) { - for _, k := range []string{ - "temperature", - "top_p", - "max_tokens", - "max_completion_tokens", - "presence_penalty", - "frequency_penalty", - "stop", - } { - if v, ok := req[k]; ok { - payload[k] = v - } + for k, v := range collectOpenAIChatPassThrough(req) { + payload[k] = v } } + +func (h *Handler) toolcallFeatureMatchEnabled() bool { + if h == nil || h.Store == nil { + return true + } + mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode())) + return mode == "" || mode == "feature_match" +} + +func (h *Handler) toolcallEarlyEmitHighConfidence() bool { + if h == nil || h.Store == nil { + return true + } + level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence())) + return level == "" || level == "high" +} diff --git a/internal/adapter/openai/response_store.go b/internal/adapter/openai/response_store.go index 4f51dfa..63ebbaa 100644 --- a/internal/adapter/openai/response_store.go +++ b/internal/adapter/openai/response_store.go @@ -3,9 +3,12 @@ package openai import ( "sync" "time" + + "ds2api/internal/auth" ) 
type storedResponse struct { + Owner string Value map[string]any ExpiresAt time.Time } @@ -26,32 +29,47 @@ func newResponseStore(ttl time.Duration) *responseStore { } } -func (s *responseStore) put(id string, value map[string]any) { - if s == nil || id == "" || value == nil { +func responseStoreKey(owner, id string) string { + return owner + "\x00" + id +} + +func responseStoreOwner(a *auth.RequestAuth) string { + if a == nil { + return "" + } + return a.CallerID +} + +func (s *responseStore) put(owner, id string, value map[string]any) { + if s == nil || owner == "" || id == "" || value == nil { return } now := time.Now() s.mu.Lock() defer s.mu.Unlock() s.sweepLocked(now) - s.items[id] = storedResponse{ + s.items[responseStoreKey(owner, id)] = storedResponse{ + Owner: owner, Value: cloneAnyMap(value), ExpiresAt: now.Add(s.ttl), } } -func (s *responseStore) get(id string) (map[string]any, bool) { - if s == nil || id == "" { +func (s *responseStore) get(owner, id string) (map[string]any, bool) { + if s == nil || owner == "" || id == "" { return nil, false } now := time.Now() s.mu.Lock() defer s.mu.Unlock() s.sweepLocked(now) - item, ok := s.items[id] + item, ok := s.items[responseStoreKey(owner, id)] if !ok { return nil, false } + if item.Owner != owner { + return nil, false + } return cloneAnyMap(item.Value), true } diff --git a/internal/adapter/openai/responses_embeddings_test.go b/internal/adapter/openai/responses_embeddings_test.go index b23597d..d270e1a 100644 --- a/internal/adapter/openai/responses_embeddings_test.go +++ b/internal/adapter/openai/responses_embeddings_test.go @@ -54,8 +54,8 @@ func TestDeterministicEmbeddingStable(t *testing.T) { func TestResponseStorePutGet(t *testing.T) { st := newResponseStore(100 * time.Millisecond) - st.put("resp_1", map[string]any{"id": "resp_1"}) - got, ok := st.get("resp_1") + st.put("owner_1", "resp_1", map[string]any{"id": "resp_1"}) + got, ok := st.get("owner_1", "resp_1") if !ok { t.Fatal("expected stored response") 
} @@ -63,3 +63,11 @@ func TestResponseStorePutGet(t *testing.T) { t.Fatalf("unexpected response payload: %#v", got) } } + +func TestResponseStoreTenantIsolation(t *testing.T) { + st := newResponseStore(100 * time.Millisecond) + st.put("owner_a", "resp_1", map[string]any{"id": "resp_1"}) + if _, ok := st.get("owner_b", "resp_1"); ok { + t.Fatal("expected owner_b to be isolated from owner_a response") + } +} diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index 8fbb132..b70fe0b 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -12,19 +12,35 @@ import ( "github.com/google/uuid" "ds2api/internal/auth" - "ds2api/internal/config" "ds2api/internal/sse" "ds2api/internal/util" ) func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeOpenAIError(w, status, detail) + return + } + defer h.Auth.Release(a) + id := strings.TrimSpace(chi.URLParam(r, "response_id")) if id == "" { writeOpenAIError(w, http.StatusBadRequest, "response_id is required.") return } + owner := responseStoreOwner(a) + if owner == "" { + writeOpenAIError(w, http.StatusUnauthorized, "unauthorized") + return + } st := h.getResponseStore() - item, ok := st.get(id) + item, ok := st.get(owner, id) if !ok { writeOpenAIError(w, http.StatusNotFound, "Response not found.") return @@ -45,32 +61,22 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { } defer h.Auth.Release(a) r = r.WithContext(auth.WithAuth(r.Context(), a)) + owner := responseStoreOwner(a) + if owner == "" { + writeOpenAIError(w, http.StatusUnauthorized, "unauthorized") + return + } var req map[string]any if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeOpenAIError(w, http.StatusBadRequest, 
"invalid json") return } - - model, _ := req["model"].(string) - model = strings.TrimSpace(model) - if model == "" { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.") + stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) return } - resolvedModel, ok := config.ResolveModel(h.Store, model) - if !ok { - writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model)) - return - } - thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) - - messagesRaw := responsesMessagesFromRequest(req) - if len(messagesRaw) == 0 { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'input' or 'messages'.") - return - } - finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { @@ -86,15 +92,7 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") return } - payload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } - applyOpenAIChatPassThrough(req, payload) + payload := stdReq.CompletionPayload(sessionID) resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) if err != nil { writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") @@ -102,14 +100,14 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { } responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "") - if util.ToBool(req["stream"]) { - h.handleResponsesStream(w, r, resp, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + if stdReq.Stream { + h.handleResponsesStream(w, r, resp, owner, responseID, 
stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) return } - h.handleResponsesNonStream(w, resp, responseID, model, finalPrompt, thinkingEnabled, toolNames) + h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) } -func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { +func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -118,11 +116,11 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res } result := sse.CollectStream(resp, thinkingEnabled, true) responseObj := buildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames) - h.getResponseStore().put(responseID, responseObj) + h.getResponseStore().put(owner, responseID, responseObj) writeJSON(w, http.StatusOK, responseObj) } -func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { +func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -160,7 +158,8 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, initialType = "thinking" } parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) - bufferToolContent := len(toolNames) > 0 + bufferToolContent := 
len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() + emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() var sieve toolStreamSieveState thinking := strings.Builder{} text := strings.Builder{} @@ -194,7 +193,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, if toolCallsEmitted { obj["status"] = "completed" } - h.getResponseStore().put(responseID, obj) + h.getResponseStore().put(owner, responseID, obj) sendEvent("response.completed", map[string]any{ "type": "response.completed", "response": obj, @@ -259,6 +258,9 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, }) } if len(evt.ToolCallDeltas) > 0 { + if !emitEarlyToolDeltas { + continue + } toolCallsEmitted = true sendEvent("response.output_tool_call.delta", map[string]any{ "type": "response.output_tool_call.delta", diff --git a/internal/adapter/openai/responses_route_test.go b/internal/adapter/openai/responses_route_test.go new file mode 100644 index 0000000..6db0c23 --- /dev/null +++ b/internal/adapter/openai/responses_route_test.go @@ -0,0 +1,125 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/account" + "ds2api/internal/auth" + "ds2api/internal/config" +) + +func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`) + store := config.LoadStore() + pool := account.NewPool(store) + resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "unused", nil + }) + return store, resolver +} + +func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth { + t.Helper() + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer "+token) + a, err := resolver.Determine(req) + if err != nil { + 
t.Fatalf("determine auth failed: %v", err) + } + return a +} + +func TestGetResponseByIDRequiresAuthAndIsTenantIsolated(t *testing.T) { + store, resolver := newDirectTokenResolver(t) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + ownerA := responseStoreOwner(authForToken(t, resolver, "token-a")) + h.getResponseStore().put(ownerA, "resp_test", map[string]any{ + "id": "resp_test", + "object": "response", + }) + + t.Run("unauthorized", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("cross-tenant-not-found", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer token-b") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("same-tenant-ok", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer token-a") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode body failed: %v", err) + } + if body["id"] != "resp_test" { + t.Fatalf("unexpected body: %#v", body) + } + }) +} + +func TestResponsesRouteValidationContract(t *testing.T) { + store, resolver := newDirectTokenResolver(t) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + tests := []struct { + name string + body string + }{ + {name: "missing_model", body: 
`{"input":"hello"}`}, + {name: "missing_input_and_messages", body: `{"model":"gpt-4o"}`}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewBufferString(tc.body)) + req.Header.Set("Authorization", "Bearer token-a") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + errObj, _ := out["error"].(map[string]any) + if _, ok := errObj["code"]; !ok { + t.Fatalf("expected error.code: %#v", out) + } + if _, ok := errObj["param"]; !ok { + t.Fatalf("expected error.param: %#v", out) + } + }) + } +} diff --git a/internal/adapter/openai/standard_request.go b/internal/adapter/openai/standard_request.go new file mode 100644 index 0000000..52344d4 --- /dev/null +++ b/internal/adapter/openai/standard_request.go @@ -0,0 +1,104 @@ +package openai + +import ( + "fmt" + "strings" + + "ds2api/internal/config" + "ds2api/internal/util" +) + +func normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) { + model, _ := req["model"].(string) + messagesRaw, _ := req["messages"].([]any) + if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { + return util.StandardRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.") + } + resolvedModel, ok := config.ResolveModel(store, model) + if !ok { + return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model) + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + responseModel := strings.TrimSpace(model) + if responseModel == "" { + responseModel = resolvedModel + } + finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) + 
passThrough := collectOpenAIChatPassThrough(req) + + return util.StandardRequest{ + Surface: "openai_chat", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: resolvedModel, + ResponseModel: responseModel, + Messages: messagesRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, + }, nil +} + +func normalizeOpenAIResponsesRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) { + model, _ := req["model"].(string) + model = strings.TrimSpace(model) + if model == "" { + return util.StandardRequest{}, fmt.Errorf("Request must include 'model'.") + } + resolvedModel, ok := config.ResolveModel(store, model) + if !ok { + return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model) + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + + // Keep width-control as an explicit policy hook even if current default is true. 
+ allowWideInput := true + if store != nil { + allowWideInput = store.CompatWideInputStrictOutput() + } + var messagesRaw []any + if allowWideInput { + messagesRaw = responsesMessagesFromRequest(req) + } else if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 { + messagesRaw = msgs + } + if len(messagesRaw) == 0 { + return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.") + } + finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) + passThrough := collectOpenAIChatPassThrough(req) + + return util.StandardRequest{ + Surface: "openai_responses", + RequestedModel: model, + ResolvedModel: resolvedModel, + ResponseModel: model, + Messages: messagesRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, + }, nil +} + +func collectOpenAIChatPassThrough(req map[string]any) map[string]any { + out := map[string]any{} + for _, k := range []string{ + "temperature", + "top_p", + "max_tokens", + "max_completion_tokens", + "presence_penalty", + "frequency_penalty", + "stop", + } { + if v, ok := req[k]; ok { + out[k] = v + } + } + return out +} diff --git a/internal/adapter/openai/standard_request_test.go b/internal/adapter/openai/standard_request_test.go new file mode 100644 index 0000000..f3453a2 --- /dev/null +++ b/internal/adapter/openai/standard_request_test.go @@ -0,0 +1,60 @@ +package openai + +import ( + "testing" + + "ds2api/internal/config" +) + +func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", `{}`) + return config.LoadStore() +} + +func TestNormalizeOpenAIChatRequest(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-5-codex", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "temperature": 0.3, + "stream": true, + } + n, err := 
normalizeOpenAIChatRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ResolvedModel != "deepseek-reasoner" { + t.Fatalf("unexpected resolved model: %s", n.ResolvedModel) + } + if !n.Stream { + t.Fatalf("expected stream=true") + } + if _, ok := n.PassThrough["temperature"]; !ok { + t.Fatalf("expected temperature passthrough") + } + if n.FinalPrompt == "" { + t.Fatalf("expected non-empty final prompt") + } +} + +func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-4o", + "input": "ping", + "instructions": "system", + } + n, err := normalizeOpenAIResponsesRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ResolvedModel != "deepseek-chat" { + t.Fatalf("unexpected resolved model: %s", n.ResolvedModel) + } + if len(n.Messages) != 2 { + t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages)) + } +} diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go index be8a590..c8bd6d0 100644 --- a/internal/adapter/openai/vercel_stream.go +++ b/internal/adapter/openai/vercel_stream.go @@ -56,24 +56,15 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque writeOpenAIError(w, http.StatusBadRequest, "stream must be true") return } - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == "" || len(messagesRaw) == 0 { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") + stdReq, err := normalizeOpenAIChatRequest(h.Store, req) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) return } - resolvedModel, ok := config.ResolveModel(h.Store, model) - if !ok { - writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model)) + if !stdReq.Stream { + writeOpenAIError(w, http.StatusBadRequest, "stream must be true") 
return } - thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) - responseModel := strings.TrimSpace(model) - if responseModel == "" { - responseModel = resolvedModel - } - - finalPrompt, _ := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { @@ -94,15 +85,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque return } - payload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } - applyOpenAIChatPassThrough(req, payload) + payload := stdReq.CompletionPayload(sessionID) leaseID := h.holdStreamLease(a) if leaseID == "" { writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease") @@ -112,10 +95,10 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque writeJSON(w, http.StatusOK, map[string]any{ "session_id": sessionID, "lease_id": leaseID, - "model": responseModel, - "final_prompt": finalPrompt, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, + "model": stdReq.ResponseModel, + "final_prompt": stdReq.FinalPrompt, + "thinking_enabled": stdReq.Thinking, + "search_enabled": stdReq.Search, "deepseek_token": a.DeepSeekToken, "pow_header": powHeader, "payload": payload, diff --git a/internal/auth/request.go b/internal/auth/request.go index ea3d7f1..d7faf8d 100644 --- a/internal/auth/request.go +++ b/internal/auth/request.go @@ -2,6 +2,8 @@ package auth import ( "context" + "crypto/sha256" + "encoding/hex" "errors" "net/http" "strings" @@ -22,6 +24,7 @@ var ( type RequestAuth struct { UseConfigToken bool DeepSeekToken string + CallerID string AccountID string Account config.Account TriedAccounts map[string]bool @@ -45,9 +48,16 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { if callerKey == "" 
{ return nil, ErrUnauthorized } + callerID := callerTokenID(callerKey) ctx := req.Context() if !r.Store.HasAPIKey(callerKey) { - return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil + return &RequestAuth{ + UseConfigToken: false, + DeepSeekToken: callerKey, + CallerID: callerID, + resolver: r, + TriedAccounts: map[string]bool{}, + }, nil } target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account")) acc, ok := r.Pool.AcquireWait(ctx, target, nil) @@ -56,6 +66,7 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { } a := &RequestAuth{ UseConfigToken: true, + CallerID: callerID, AccountID: acc.Identifier(), Account: acc, TriedAccounts: map[string]bool{}, @@ -158,3 +169,12 @@ func extractCallerToken(req *http.Request) string { } return strings.TrimSpace(req.Header.Get("x-api-key")) } + +func callerTokenID(token string) string { + token = strings.TrimSpace(token) + if token == "" { + return "" + } + sum := sha256.Sum256([]byte(token)) + return "caller:" + hex.EncodeToString(sum[:8]) +} diff --git a/internal/auth/request_test.go b/internal/auth/request_test.go index 1d568f3..ee74092 100644 --- a/internal/auth/request_test.go +++ b/internal/auth/request_test.go @@ -37,6 +37,9 @@ func TestDetermineWithXAPIKeyUsesDirectToken(t *testing.T) { if auth.DeepSeekToken != "direct-token" { t.Fatalf("unexpected token: %q", auth.DeepSeekToken) } + if auth.CallerID == "" { + t.Fatalf("expected caller id to be populated") + } } func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) { @@ -58,6 +61,24 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) { if auth.DeepSeekToken != "account-token" { t.Fatalf("unexpected account token: %q", auth.DeepSeekToken) } + if auth.CallerID == "" { + t.Fatalf("expected caller id to be populated") + } +} + +func TestCallerTokenIDStable(t *testing.T) { + a := callerTokenID("token-a") + b := callerTokenID("token-a") + c := 
callerTokenID("token-b") + if a == "" || b == "" || c == "" { + t.Fatalf("expected non-empty caller ids") + } + if a != b { + t.Fatalf("expected stable caller id, got %q and %q", a, b) + } + if a == c { + t.Fatalf("expected different caller id for different tokens") + } } func TestDetermineMissingToken(t *testing.T) { diff --git a/internal/testsuite/runner.go b/internal/testsuite/runner.go index b48bce5..e6ae9a6 100644 --- a/internal/testsuite/runner.go +++ b/internal/testsuite/runner.go @@ -755,11 +755,15 @@ func (r *Runner) cases() []caseDef { {ID: "healthz_ok", Run: r.caseHealthz}, {ID: "readyz_ok", Run: r.caseReadyz}, {ID: "models_openai", Run: r.caseModelsOpenAI}, + {ID: "model_openai_by_id", Run: r.caseModelOpenAIByID}, {ID: "models_claude", Run: r.caseModelsClaude}, {ID: "admin_login_verify", Run: r.caseAdminLoginVerify}, {ID: "admin_queue_status", Run: r.caseAdminQueueStatus}, {ID: "chat_nonstream_basic", Run: r.caseChatNonstream}, {ID: "chat_stream_basic", Run: r.caseChatStream}, + {ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream}, + {ID: "responses_stream_basic", Run: r.caseResponsesStream}, + {ID: "embeddings_contract", Run: r.caseEmbeddings}, {ID: "reasoner_stream", Run: r.caseReasonerStream}, {ID: "toolcall_nonstream", Run: r.caseToolcallNonstream}, {ID: "toolcall_stream", Run: r.caseToolcallStream}, @@ -817,6 +821,19 @@ func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error { return nil } +func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("object_model", asString(m["object"]) == "model", fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("id_deepseek_chat", 
asString(m["id"]) == "deepseek-chat", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error { resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true}) if err != nil { @@ -942,6 +959,115 @@ func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error { return nil } +func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/responses", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": "请简要回答 hello", + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("object_response", asString(m["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body))) + responseID := asString(m["id"]) + cc.assert("response_id_present", responseID != "", fmt.Sprintf("body=%s", string(resp.Body))) + if responseID != "" { + getResp, getErr := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/v1/responses/" + responseID, + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Retryable: true, + }) + if getErr != nil { + return getErr + } + cc.assert("get_status_200", getResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", getResp.StatusCode)) + } + return nil +} + +func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/responses", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": "请流式回答 hello", + "stream": true, + }, + Stream: true, 
+ Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + frames, done := parseSSEFrames(resp.Body) + cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames))) + hasCreated := false + hasCompleted := false + for _, f := range frames { + switch asString(f["type"]) { + case "response.created": + hasCreated = true + case "response.completed": + hasCompleted = true + } + } + cc.assert("has_response_created", hasCreated, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("has_response_completed", hasCompleted, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("done_terminated", done, "expected [DONE]") + return nil +} + +func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/embeddings", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": []string{"hello", "world"}, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200_or_501", resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + if resp.StatusCode == http.StatusOK { + cc.assert("object_list", asString(m["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body))) + data, _ := m["data"].([]any) + cc.assert("data_non_empty", len(data) > 0, fmt.Sprintf("body=%s", string(resp.Body))) + return nil + } + errObj, _ := m["error"].(map[string]any) + _, hasCode := errObj["code"] + _, hasParam := errObj["param"] + cc.assert("error_has_code", hasCode, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + func (r *Runner) caseReasonerStream(ctx 
context.Context, cc *caseContext) error { resp, err := cc.request(ctx, requestSpec{ Method: http.MethodPost, diff --git a/internal/util/standard_request.go b/internal/util/standard_request.go new file mode 100644 index 0000000..af73acf --- /dev/null +++ b/internal/util/standard_request.go @@ -0,0 +1,30 @@ +package util + +type StandardRequest struct { + Surface string + RequestedModel string + ResolvedModel string + ResponseModel string + Messages []any + FinalPrompt string + ToolNames []string + Stream bool + Thinking bool + Search bool + PassThrough map[string]any +} + +func (r StandardRequest) CompletionPayload(sessionID string) map[string]any { + payload := map[string]any{ + "chat_session_id": sessionID, + "parent_message_id": nil, + "prompt": r.FinalPrompt, + "ref_file_ids": []any{}, + "thinking_enabled": r.Thinking, + "search_enabled": r.Search, + } + for k, v := range r.PassThrough { + payload[k] = v + } + return payload +} From 895423852f3751a38f0e822da8bd2163180fb32b Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 23:35:37 +0800 Subject: [PATCH 13/52] refactor: extract Claude and OpenAI response rendering into new `util/render` package --- internal/adapter/claude/handler.go | 46 ++----- internal/adapter/openai/handler.go | 32 +---- internal/adapter/openai/responses_handler.go | 61 +-------- internal/util/render.go | 136 +++++++++++++++++++ internal/util/render_test.go | 72 ++++++++++ 5 files changed, 221 insertions(+), 126 deletions(-) create mode 100644 internal/util/render.go create mode 100644 internal/util/render_test.go diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index 44240af..bac315f 100644 --- a/internal/adapter/claude/handler.go +++ b/internal/adapter/claude/handler.go @@ -98,43 +98,15 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { return } result := sse.CollectStream(resp, stdReq.Thinking, true) - fullText := result.Text - fullThinking := result.Thinking - detected := 
util.ParseToolCalls(fullText, stdReq.ToolNames) - content := make([]map[string]any, 0, 4) - if fullThinking != "" { - content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking}) - } - stopReason := "end_turn" - if len(detected) > 0 { - stopReason = "tool_use" - for i, tc := range detected { - content = append(content, map[string]any{ - "type": "tool_use", - "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), - "name": tc.Name, - "input": tc.Input, - }) - } - } else { - if fullText == "" { - fullText = "抱歉,没有生成有效的响应内容。" - } - content = append(content, map[string]any{"type": "text", "text": fullText}) - } - writeJSON(w, http.StatusOK, map[string]any{ - "id": fmt.Sprintf("msg_%d", time.Now().UnixNano()), - "type": "message", - "role": "assistant", - "model": stdReq.ResponseModel, - "content": content, - "stop_reason": stopReason, - "stop_sequence": nil, - "usage": map[string]any{ - "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", norm.NormalizedMessages)), - "output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText), - }, - }) + respBody := util.BuildClaudeMessageResponse( + fmt.Sprintf("msg_%d", time.Now().UnixNano()), + stdReq.ResponseModel, + norm.NormalizedMessages, + result.Thinking, + result.Text, + stdReq.ToolNames, + ) + writeJSON(w, http.StatusOK, respBody) } func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index fadca38..ce90804 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -136,36 +136,8 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re finalThinking := result.Thinking finalText := result.Text - detected := util.ParseToolCalls(finalText, toolNames) - finishReason := "stop" - messageObj := map[string]any{"role": "assistant", "content": finalText} - if thinkingEnabled && finalThinking != "" { - 
messageObj["reasoning_content"] = finalThinking - } - if len(detected) > 0 { - finishReason = "tool_calls" - messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected) - messageObj["content"] = nil - } - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) - - writeJSON(w, http.StatusOK, map[string]any{ - "id": completionID, - "object": "chat.completion", - "created": time.Now().Unix(), - "model": model, - "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}}, - "usage": map[string]any{ - "prompt_tokens": promptTokens, - "completion_tokens": reasoningTokens + completionTokens, - "total_tokens": promptTokens + reasoningTokens + completionTokens, - "completion_tokens_details": map[string]any{ - "reasoning_tokens": reasoningTokens, - }, - }, - }) + respBody := util.BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames) + writeJSON(w, http.StatusOK, respBody) } func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index b70fe0b..92dd891 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -6,7 +6,6 @@ import ( "io" "net/http" "strings" - "time" "github.com/go-chi/chi/v5" "github.com/google/uuid" @@ -115,7 +114,7 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res return } result := sse.CollectStream(resp, thinkingEnabled, true) - responseObj := buildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames) + responseObj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, 
toolNames) h.getResponseStore().put(owner, responseID, responseObj) writeJSON(w, http.StatusOK, responseObj) } @@ -189,7 +188,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, } } } - obj := buildResponseObject(responseID, model, finalPrompt, finalThinking, finalText, toolNames) + obj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText, toolNames) if toolCallsEmitted { obj["status"] = "completed" } @@ -282,62 +281,6 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, } } -func buildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { - detected := util.ParseToolCalls(finalText, toolNames) - output := make([]any, 0, 2) - if len(detected) > 0 { - toolCalls := make([]any, 0, len(detected)) - for _, tc := range detected { - toolCalls = append(toolCalls, map[string]any{ - "type": "tool_call", - "name": tc.Name, - "arguments": tc.Input, - }) - } - output = append(output, map[string]any{ - "type": "tool_calls", - "tool_calls": toolCalls, - }) - } else { - content := []any{ - map[string]any{ - "type": "output_text", - "text": finalText, - }, - } - if finalThinking != "" { - content = append([]any{map[string]any{ - "type": "reasoning", - "text": finalThinking, - }}, content...) 
- } - output = append(output, map[string]any{ - "type": "message", - "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""), - "role": "assistant", - "content": content, - }) - } - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) - return map[string]any{ - "id": responseID, - "type": "response", - "object": "response", - "created_at": time.Now().Unix(), - "status": "completed", - "model": model, - "output": output, - "output_text": finalText, - "usage": map[string]any{ - "input_tokens": promptTokens, - "output_tokens": reasoningTokens + completionTokens, - "total_tokens": promptTokens + reasoningTokens + completionTokens, - }, - } -} - func responsesMessagesFromRequest(req map[string]any) []any { if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 { return prependInstructionMessage(msgs, req["instructions"]) diff --git a/internal/util/render.go b/internal/util/render.go new file mode 100644 index 0000000..ffb8128 --- /dev/null +++ b/internal/util/render.go @@ -0,0 +1,136 @@ +package util + +import ( + "fmt" + "strings" + "time" + + "github.com/google/uuid" +) + +func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := ParseToolCalls(finalText, toolNames) + finishReason := "stop" + messageObj := map[string]any{"role": "assistant", "content": finalText} + if strings.TrimSpace(finalThinking) != "" { + messageObj["reasoning_content"] = finalThinking + } + if len(detected) > 0 { + finishReason = "tool_calls" + messageObj["tool_calls"] = FormatOpenAIToolCalls(detected) + messageObj["content"] = nil + } + promptTokens := EstimateTokens(finalPrompt) + reasoningTokens := EstimateTokens(finalThinking) + completionTokens := EstimateTokens(finalText) + + return map[string]any{ + "id": completionID, + "object": "chat.completion", + "created": time.Now().Unix(), + 
"model": model, + "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}}, + "usage": map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + "completion_tokens_details": map[string]any{ + "reasoning_tokens": reasoningTokens, + }, + }, + } +} + +func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := ParseToolCalls(finalText, toolNames) + output := make([]any, 0, 2) + if len(detected) > 0 { + toolCalls := make([]any, 0, len(detected)) + for _, tc := range detected { + toolCalls = append(toolCalls, map[string]any{ + "type": "tool_call", + "name": tc.Name, + "arguments": tc.Input, + }) + } + output = append(output, map[string]any{ + "type": "tool_calls", + "tool_calls": toolCalls, + }) + } else { + content := []any{ + map[string]any{ + "type": "output_text", + "text": finalText, + }, + } + if finalThinking != "" { + content = append([]any{map[string]any{ + "type": "reasoning", + "text": finalThinking, + }}, content...) 
+ } + output = append(output, map[string]any{ + "type": "message", + "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "role": "assistant", + "content": content, + }) + } + promptTokens := EstimateTokens(finalPrompt) + reasoningTokens := EstimateTokens(finalThinking) + completionTokens := EstimateTokens(finalText) + return map[string]any{ + "id": responseID, + "type": "response", + "object": "response", + "created_at": time.Now().Unix(), + "status": "completed", + "model": model, + "output": output, + "output_text": finalText, + "usage": map[string]any{ + "input_tokens": promptTokens, + "output_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + }, + } +} + +func BuildClaudeMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any { + detected := ParseToolCalls(finalText, toolNames) + content := make([]map[string]any, 0, 4) + if finalThinking != "" { + content = append(content, map[string]any{"type": "thinking", "thinking": finalThinking}) + } + stopReason := "end_turn" + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + content = append(content, map[string]any{ + "type": "tool_use", + "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), + "name": tc.Name, + "input": tc.Input, + }) + } + } else { + if finalText == "" { + finalText = "抱歉,没有生成有效的响应内容。" + } + content = append(content, map[string]any{"type": "text", "text": finalText}) + } + return map[string]any{ + "id": messageID, + "type": "message", + "role": "assistant", + "model": model, + "content": content, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": EstimateTokens(fmt.Sprintf("%v", normalizedMessages)), + "output_tokens": EstimateTokens(finalThinking) + EstimateTokens(finalText), + }, + } +} diff --git a/internal/util/render_test.go b/internal/util/render_test.go new 
file mode 100644 index 0000000..1ee296b --- /dev/null +++ b/internal/util/render_test.go @@ -0,0 +1,72 @@ +package util + +import "testing" + +func TestBuildOpenAIChatCompletionWithToolCalls(t *testing.T) { + out := BuildOpenAIChatCompletion( + "cid1", + "deepseek-chat", + "prompt", + "", + `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`, + []string{"search"}, + ) + if out["object"] != "chat.completion" { + t.Fatalf("unexpected object: %#v", out["object"]) + } + choices, _ := out["choices"].([]map[string]any) + if len(choices) == 0 { + // json-like map from generic marshalling may be []any in some paths + rawChoices, _ := out["choices"].([]any) + if len(rawChoices) == 0 { + t.Fatalf("expected choices") + } + c0, _ := rawChoices[0].(map[string]any) + if c0["finish_reason"] != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %#v", c0["finish_reason"]) + } + return + } + if choices[0]["finish_reason"] != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %#v", choices[0]["finish_reason"]) + } +} + +func TestBuildOpenAIResponseObjectWithText(t *testing.T) { + out := BuildOpenAIResponseObject( + "resp_1", + "gpt-4o", + "prompt", + "reasoning", + "text", + nil, + ) + if out["object"] != "response" { + t.Fatalf("unexpected object: %#v", out["object"]) + } + output, _ := out["output"].([]any) + if len(output) == 0 { + t.Fatalf("expected output entries") + } + first, _ := output[0].(map[string]any) + if first["type"] != "message" { + t.Fatalf("expected first output type message, got %#v", first["type"]) + } +} + +func TestBuildClaudeMessageResponseToolUse(t *testing.T) { + out := BuildClaudeMessageResponse( + "msg_1", + "claude-sonnet-4-5", + []any{map[string]any{"role": "user", "content": "hi"}}, + "", + `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`, + []string{"search"}, + ) + if out["type"] != "message" { + t.Fatalf("unexpected type: %#v", out["type"]) + } + if out["stop_reason"] != "tool_use" { + t.Fatalf("expected 
stop_reason=tool_use, got %#v", out["stop_reason"]) + } +} From 51c543631baafd37b0bbc3bb1b51d08e6290a6b3 Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 23:40:34 +0800 Subject: [PATCH 14/52] refactor: Extract OpenAI streaming response payload construction into dedicated utility functions. --- internal/adapter/openai/handler.go | 78 ++++++---------- internal/adapter/openai/responses_handler.go | 55 ++---------- internal/util/render_stream.go | 93 ++++++++++++++++++++ internal/util/render_stream_test.go | 48 ++++++++++ 4 files changed, 176 insertions(+), 98 deletions(-) create mode 100644 internal/util/render_stream.go create mode 100644 internal/util/render_stream_test.go diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index ce90804..5ef6e7b 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -206,13 +206,13 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt delta["role"] = "assistant" firstChunkSent = true } - sendChunk(map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]any{{"delta": delta, "index": 0}}, - }) + sendChunk(util.BuildOpenAIChatStreamChunk( + completionID, + created, + model, + []map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)}, + nil, + )) } else if bufferToolContent { for _, evt := range flushToolSieve(&toolSieve, toolNames) { if evt.Content == "" { @@ -225,36 +225,25 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt delta["role"] = "assistant" firstChunkSent = true } - sendChunk(map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]any{{"delta": delta, "index": 0}}, - }) + sendChunk(util.BuildOpenAIChatStreamChunk( + completionID, + created, + model, + 
[]map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)}, + nil, + )) } } if len(detected) > 0 || toolCallsEmitted { finishReason = "tool_calls" } - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) - sendChunk(map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]any{{"delta": map[string]any{}, "index": 0, "finish_reason": finishReason}}, - "usage": map[string]any{ - "prompt_tokens": promptTokens, - "completion_tokens": reasoningTokens + completionTokens, - "total_tokens": promptTokens + reasoningTokens + completionTokens, - "completion_tokens_details": map[string]any{ - "reasoning_tokens": reasoningTokens, - }, - }, - }) + sendChunk(util.BuildOpenAIChatStreamChunk( + completionID, + created, + model, + []map[string]any{util.BuildOpenAIChatStreamFinishChoice(0, finishReason)}, + util.BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText), + )) sendDone() } @@ -340,10 +329,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt tcDelta["role"] = "assistant" firstChunkSent = true } - newChoices = append(newChoices, map[string]any{ - "delta": tcDelta, - "index": 0, - }) + newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta)) continue } if len(evt.ToolCalls) > 0 { @@ -355,10 +341,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt tcDelta["role"] = "assistant" firstChunkSent = true } - newChoices = append(newChoices, map[string]any{ - "delta": tcDelta, - "index": 0, - }) + newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta)) continue } if evt.Content != "" { @@ -369,26 +352,17 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt contentDelta["role"] = "assistant" firstChunkSent = true } - newChoices = 
append(newChoices, map[string]any{ - "delta": contentDelta, - "index": 0, - }) + newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, contentDelta)) } } } } if len(delta) > 0 { - newChoices = append(newChoices, map[string]any{"delta": delta, "index": 0}) + newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, delta)) } } if len(newChoices) > 0 { - sendChunk(map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": newChoices, - }) + sendChunk(util.BuildOpenAIChatStreamChunk(completionID, created, model, newChoices, nil)) } } } diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index 92dd891..9aaa7cd 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -144,13 +144,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, } } - sendEvent("response.created", map[string]any{ - "type": "response.created", - "id": responseID, - "object": "response", - "model": model, - "status": "in_progress", - }) + sendEvent("response.created", util.BuildOpenAIResponsesCreatedPayload(responseID, model)) initialType := "text" if thinkingEnabled { @@ -172,19 +166,11 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, for _, evt := range flushToolSieve(&sieve, toolNames) { if evt.Content != "" { finalText += evt.Content - sendEvent("response.output_text.delta", map[string]any{ - "type": "response.output_text.delta", - "id": responseID, - "delta": evt.Content, - }) + sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content)) } if len(evt.ToolCalls) > 0 { toolCallsEmitted = true - sendEvent("response.output_tool_call.done", map[string]any{ - "type": "response.output_tool_call.done", - "id": responseID, - "tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls), 
- }) + sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) } } } @@ -193,10 +179,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, obj["status"] = "completed" } h.getResponseStore().put(owner, responseID, obj) - sendEvent("response.completed", map[string]any{ - "type": "response.completed", - "response": obj, - }) + sendEvent("response.completed", util.BuildOpenAIResponsesCompletedPayload(obj)) _, _ = w.Write([]byte("data: [DONE]\n\n")) if canFlush { _ = rc.Flush() @@ -232,48 +215,28 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, continue } thinking.WriteString(p.Text) - sendEvent("response.reasoning.delta", map[string]any{ - "type": "response.reasoning.delta", - "id": responseID, - "delta": p.Text, - }) + sendEvent("response.reasoning.delta", util.BuildOpenAIResponsesReasoningDeltaPayload(responseID, p.Text)) continue } text.WriteString(p.Text) if !bufferToolContent { - sendEvent("response.output_text.delta", map[string]any{ - "type": "response.output_text.delta", - "id": responseID, - "delta": p.Text, - }) + sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, p.Text)) continue } for _, evt := range processToolSieveChunk(&sieve, p.Text, toolNames) { if evt.Content != "" { - sendEvent("response.output_text.delta", map[string]any{ - "type": "response.output_text.delta", - "id": responseID, - "delta": evt.Content, - }) + sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content)) } if len(evt.ToolCallDeltas) > 0 { if !emitEarlyToolDeltas { continue } toolCallsEmitted = true - sendEvent("response.output_tool_call.delta", map[string]any{ - "type": "response.output_tool_call.delta", - "id": responseID, - "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs), - }) + 
sendEvent("response.output_tool_call.delta", util.BuildOpenAIResponsesToolCallDeltaPayload(responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs))) } if len(evt.ToolCalls) > 0 { toolCallsEmitted = true - sendEvent("response.output_tool_call.done", map[string]any{ - "type": "response.output_tool_call.done", - "id": responseID, - "tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls), - }) + sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) } } } diff --git a/internal/util/render_stream.go b/internal/util/render_stream.go new file mode 100644 index 0000000..716c158 --- /dev/null +++ b/internal/util/render_stream.go @@ -0,0 +1,93 @@ +package util + +func BuildOpenAIChatStreamDeltaChoice(index int, delta map[string]any) map[string]any { + return map[string]any{ + "delta": delta, + "index": index, + } +} + +func BuildOpenAIChatStreamFinishChoice(index int, finishReason string) map[string]any { + return map[string]any{ + "delta": map[string]any{}, + "index": index, + "finish_reason": finishReason, + } +} + +func BuildOpenAIChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any { + out := map[string]any{ + "id": completionID, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": choices, + } + if len(usage) > 0 { + out["usage"] = usage + } + return out +} + +func BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := EstimateTokens(finalPrompt) + reasoningTokens := EstimateTokens(finalThinking) + completionTokens := EstimateTokens(finalText) + return map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + "completion_tokens_details": map[string]any{ + 
"reasoning_tokens": reasoningTokens, + }, + } +} + +func BuildOpenAIResponsesCreatedPayload(responseID, model string) map[string]any { + return map[string]any{ + "type": "response.created", + "id": responseID, + "object": "response", + "model": model, + "status": "in_progress", + } +} + +func BuildOpenAIResponsesTextDeltaPayload(responseID, delta string) map[string]any { + return map[string]any{ + "type": "response.output_text.delta", + "id": responseID, + "delta": delta, + } +} + +func BuildOpenAIResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { + return map[string]any{ + "type": "response.reasoning.delta", + "id": responseID, + "delta": delta, + } +} + +func BuildOpenAIResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_tool_call.delta", + "id": responseID, + "tool_calls": toolCalls, + } +} + +func BuildOpenAIResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_tool_call.done", + "id": responseID, + "tool_calls": toolCalls, + } +} + +func BuildOpenAIResponsesCompletedPayload(response map[string]any) map[string]any { + return map[string]any{ + "type": "response.completed", + "response": response, + } +} diff --git a/internal/util/render_stream_test.go b/internal/util/render_stream_test.go new file mode 100644 index 0000000..420a311 --- /dev/null +++ b/internal/util/render_stream_test.go @@ -0,0 +1,48 @@ +package util + +import "testing" + +func TestBuildOpenAIChatStreamChunk(t *testing.T) { + chunk := BuildOpenAIChatStreamChunk( + "cid", + 123, + "deepseek-chat", + []map[string]any{BuildOpenAIChatStreamDeltaChoice(0, map[string]any{"role": "assistant"})}, + nil, + ) + if chunk["object"] != "chat.completion.chunk" { + t.Fatalf("unexpected object: %#v", chunk["object"]) + } + choices, _ := chunk["choices"].([]map[string]any) + if len(choices) == 0 { + 
rawChoices, _ := chunk["choices"].([]any) + if len(rawChoices) == 0 { + t.Fatalf("expected choices") + } + } +} + +func TestBuildOpenAIChatUsage(t *testing.T) { + usage := BuildOpenAIChatUsage("prompt", "think", "answer") + if _, ok := usage["prompt_tokens"]; !ok { + t.Fatalf("expected prompt_tokens") + } + if _, ok := usage["completion_tokens_details"]; !ok { + t.Fatalf("expected completion_tokens_details") + } +} + +func TestBuildOpenAIResponsesEventPayloads(t *testing.T) { + created := BuildOpenAIResponsesCreatedPayload("resp_1", "gpt-4o") + if created["type"] != "response.created" { + t.Fatalf("unexpected type: %#v", created["type"]) + } + done := BuildOpenAIResponsesToolCallDonePayload("resp_1", []map[string]any{{"index": 0}}) + if done["type"] != "response.output_tool_call.done" { + t.Fatalf("unexpected type: %#v", done["type"]) + } + completed := BuildOpenAIResponsesCompletedPayload(map[string]any{"id": "resp_1"}) + if completed["type"] != "response.completed" { + t.Fatalf("unexpected type: %#v", completed["type"]) + } +} From 2dcc230852580ea5c0bc05c5193f40267f303b4b Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 23:53:50 +0800 Subject: [PATCH 15/52] feat: Introduce `DetermineCaller` for auth without account pooling and make `wide_input_strict_output` configurable. 
--- api/chat-stream.js | 43 +++++++++++++++- api/chat-stream.test.js | 42 ++++++++++++++- internal/adapter/openai/responses_handler.go | 10 +--- .../adapter/openai/responses_route_test.go | 51 +++++++++++++++++++ internal/adapter/openai/vercel_stream.go | 21 ++++---- internal/auth/request.go | 20 ++++++++ internal/auth/request_test.go | 33 ++++++++++++ internal/config/config.go | 35 +++++++++---- internal/config/config_edge_test.go | 33 ++++++++++++ 9 files changed, 257 insertions(+), 31 deletions(-) diff --git a/api/chat-stream.js b/api/chat-stream.js index 680651d..1a8e896 100644 --- a/api/chat-stream.js +++ b/api/chat-stream.js @@ -85,7 +85,8 @@ module.exports = async function handler(req, res) { const finalPrompt = asString(prep.body.final_prompt); const thinkingEnabled = toBool(prep.body.thinking_enabled); const searchEnabled = toBool(prep.body.search_enabled); - const toolNames = extractToolNames(payload.tools); + const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools); + const toolNames = toolPolicy.toolNames; if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) { writeOpenAIError(res, 500, 'invalid vercel prepare response'); @@ -156,7 +157,8 @@ module.exports = async function handler(req, res) { let currentType = thinkingEnabled ? 
'thinking' : 'text'; let thinkingText = ''; let outputText = ''; - const toolSieveEnabled = toolNames.length > 0; + const toolSieveEnabled = toolPolicy.toolSieveEnabled; + const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas; const toolSieveState = createToolSieveState(); let toolCallsEmitted = false; const streamToolCallIDs = new Map(); @@ -297,6 +299,9 @@ module.exports = async function handler(req, res) { const events = processToolSieveChunk(toolSieveState, p.text, toolNames); for (const evt of events) { if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) { + if (!emitEarlyToolDeltas) { + continue; + } toolCallsEmitted = true; sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) }); continue; @@ -407,6 +412,37 @@ function relayPreparedFailure(res, prep) { writeOpenAIError(res, prep.status || 500, 'vercel prepare failed'); } +function resolveToolcallPolicy(prepBody, payloadTools) { + const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names); + const toolNames = preparedToolNames.length > 0 ? 
preparedToolNames : extractToolNames(payloadTools); + const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match); + const emitEarlyToolDeltas = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high); + return { + toolNames, + toolSieveEnabled: toolNames.length > 0 && featureMatchEnabled, + emitEarlyToolDeltas, + }; +} + +function normalizePreparedToolNames(v) { + if (!Array.isArray(v) || v.length === 0) { + return []; + } + const out = []; + for (const item of v) { + const name = asString(item); + if (!name) { + continue; + } + out.push(name); + } + return out; +} + +function boolDefaultTrue(v) { + return v !== false; +} + async function safeReadText(resp) { if (!resp) { return ''; @@ -933,4 +969,7 @@ module.exports.__test = { extractContentRecursive, shouldSkipPath, asString, + resolveToolcallPolicy, + normalizePreparedToolNames, + boolDefaultTrue, }; diff --git a/api/chat-stream.test.js b/api/chat-stream.test.js index c849f7c..7424df2 100644 --- a/api/chat-stream.test.js +++ b/api/chat-stream.test.js @@ -10,10 +10,50 @@ const { flushToolSieve, } = require('./helpers/stream-tool-sieve'); -const { parseChunkForContent } = handler.__test; +const { + parseChunkForContent, + resolveToolcallPolicy, + normalizePreparedToolNames, + boolDefaultTrue, +} = handler.__test; test('chat-stream exposes parser test hooks', () => { assert.equal(typeof parseChunkForContent, 'function'); + assert.equal(typeof resolveToolcallPolicy, 'function'); +}); + +test('resolveToolcallPolicy defaults to feature-match + early emit when prepare flags missing', () => { + const policy = resolveToolcallPolicy( + {}, + [{ type: 'function', function: { name: 'read_file', parameters: { type: 'object' } } }], + ); + assert.deepEqual(policy.toolNames, ['read_file']); + assert.equal(policy.toolSieveEnabled, true); + assert.equal(policy.emitEarlyToolDeltas, true); +}); + +test('resolveToolcallPolicy respects prepare flags and prepared tool names', () => { + const 
policy = resolveToolcallPolicy( + { + tool_names: [' prepped_tool ', '', null], + toolcall_feature_match: false, + toolcall_early_emit_high: false, + }, + [{ type: 'function', function: { name: 'fallback_tool', parameters: { type: 'object' } } }], + ); + assert.deepEqual(policy.toolNames, ['prepped_tool']); + assert.equal(policy.toolSieveEnabled, false); + assert.equal(policy.emitEarlyToolDeltas, false); +}); + +test('normalizePreparedToolNames filters empty values', () => { + assert.deepEqual(normalizePreparedToolNames([' a ', '', null, 'b']), ['a', 'b']); +}); + +test('boolDefaultTrue keeps false only when explicitly false', () => { + assert.equal(boolDefaultTrue(false), false); + assert.equal(boolDefaultTrue(true), true); + assert.equal(boolDefaultTrue(undefined), true); }); test('parseChunkForContent keeps split response/content fragments inside response array', () => { diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index 9aaa7cd..ff324b4 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -16,17 +16,11 @@ import ( ) func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) { - a, err := h.Auth.Determine(r) + a, err := h.Auth.DetermineCaller(r) if err != nil { - status := http.StatusUnauthorized - detail := err.Error() - if err == auth.ErrNoAccount { - status = http.StatusTooManyRequests - } - writeOpenAIError(w, status, detail) + writeOpenAIError(w, http.StatusUnauthorized, err.Error()) return } - defer h.Auth.Release(a) id := strings.TrimSpace(chi.URLParam(r, "response_id")) if id == "" { diff --git a/internal/adapter/openai/responses_route_test.go b/internal/adapter/openai/responses_route_test.go index 6db0c23..574c6fa 100644 --- a/internal/adapter/openai/responses_route_test.go +++ b/internal/adapter/openai/responses_route_test.go @@ -26,6 +26,22 @@ func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) { 
return store, resolver } +func newManagedKeyResolver(t *testing.T) (*config.Store, *auth.Resolver) { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[{"email":"acc@example.com","password":"pwd","token":"account-token"}] + }`) + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "0") + store := config.LoadStore() + pool := account.NewPool(store) + resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "unused", nil + }) + return store, resolver +} + func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth { t.Helper() req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) @@ -123,3 +139,38 @@ func TestResponsesRouteValidationContract(t *testing.T) { }) } } + +func TestGetResponseByIDManagedKeySkipsAccountPoolPressure(t *testing.T) { + store, resolver := newManagedKeyResolver(t) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + ownerReq := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + ownerReq.Header.Set("Authorization", "Bearer managed-key") + ownerAuth, err := resolver.DetermineCaller(ownerReq) + if err != nil { + t.Fatalf("determine caller failed: %v", err) + } + owner := responseStoreOwner(ownerAuth) + h.getResponseStore().put(owner, "resp_test", map[string]any{ + "id": "resp_test", + "object": "response", + }) + + occupyReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + occupyReq.Header.Set("Authorization", "Bearer managed-key") + occupied, err := resolver.Determine(occupyReq) + if err != nil { + t.Fatalf("expected first acquire to succeed: %v", err) + } + defer resolver.Release(occupied) + + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer managed-key") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != 
http.StatusOK { + t.Fatalf("expected 200 under pool pressure, got %d body=%s", rec.Code, rec.Body.String()) + } +} diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go index c8bd6d0..65006c4 100644 --- a/internal/adapter/openai/vercel_stream.go +++ b/internal/adapter/openai/vercel_stream.go @@ -93,15 +93,18 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque } leased = true writeJSON(w, http.StatusOK, map[string]any{ - "session_id": sessionID, - "lease_id": leaseID, - "model": stdReq.ResponseModel, - "final_prompt": stdReq.FinalPrompt, - "thinking_enabled": stdReq.Thinking, - "search_enabled": stdReq.Search, - "deepseek_token": a.DeepSeekToken, - "pow_header": powHeader, - "payload": payload, + "session_id": sessionID, + "lease_id": leaseID, + "model": stdReq.ResponseModel, + "final_prompt": stdReq.FinalPrompt, + "thinking_enabled": stdReq.Thinking, + "search_enabled": stdReq.Search, + "tool_names": stdReq.ToolNames, + "toolcall_feature_match": h.toolcallFeatureMatchEnabled(), + "toolcall_early_emit_high": h.toolcallEarlyEmitHighConfidence(), + "deepseek_token": a.DeepSeekToken, + "pow_header": powHeader, + "payload": payload, }) } diff --git a/internal/auth/request.go b/internal/auth/request.go index d7faf8d..25980cf 100644 --- a/internal/auth/request.go +++ b/internal/auth/request.go @@ -83,6 +83,26 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { return a, nil } +// DetermineCaller resolves caller identity without acquiring any pooled account. +// Use this for local-cache lookup routes that only need tenant isolation. 
+func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) { + callerKey := extractCallerToken(req) + if callerKey == "" { + return nil, ErrUnauthorized + } + callerID := callerTokenID(callerKey) + a := &RequestAuth{ + UseConfigToken: false, + CallerID: callerID, + resolver: r, + TriedAccounts: map[string]bool{}, + } + if r == nil || r.Store == nil || !r.Store.HasAPIKey(callerKey) { + a.DeepSeekToken = callerKey + } + return a, nil +} + func WithAuth(ctx context.Context, a *RequestAuth) context.Context { return context.WithValue(ctx, authCtxKey, a) } diff --git a/internal/auth/request_test.go b/internal/auth/request_test.go index ee74092..c292856 100644 --- a/internal/auth/request_test.go +++ b/internal/auth/request_test.go @@ -66,6 +66,26 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) { } } +func TestDetermineCallerWithManagedKeySkipsAccountAcquire(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil) + req.Header.Set("x-api-key", "managed-key") + + a, err := r.DetermineCaller(req) + if err != nil { + t.Fatalf("determine caller failed: %v", err) + } + if a.CallerID == "" { + t.Fatalf("expected caller id to be populated") + } + if a.UseConfigToken { + t.Fatalf("expected no config-token lease for caller-only auth") + } + if a.AccountID != "" { + t.Fatalf("expected empty account id, got %q", a.AccountID) + } +} + func TestCallerTokenIDStable(t *testing.T) { a := callerTokenID("token-a") b := callerTokenID("token-a") @@ -93,3 +113,16 @@ func TestDetermineMissingToken(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +func TestDetermineCallerMissingToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil) + + _, err := r.DetermineCaller(req) + if err == nil { + t.Fatal("expected unauthorized error") + } + if err != ErrUnauthorized { + t.Fatalf("unexpected error: %v", err) + } +} diff --git 
a/internal/config/config.go b/internal/config/config.go index d583159..d391462 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -73,7 +73,7 @@ type Config struct { } type CompatConfig struct { - WideInputStrictOutput bool `json:"wide_input_strict_output,omitempty"` + WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"` } type ToolcallConfig struct { @@ -109,7 +109,7 @@ func (c Config) MarshalJSON() ([]byte, error) { if len(c.ModelAliases) > 0 { m["model_aliases"] = c.ModelAliases } - if c.Compat.WideInputStrictOutput { + if c.Compat.WideInputStrictOutput != nil { m["compat"] = c.Compat } if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" { @@ -194,12 +194,14 @@ func (c *Config) UnmarshalJSON(b []byte) error { func (c Config) Clone() Config { clone := Config{ - Keys: slices.Clone(c.Keys), - Accounts: slices.Clone(c.Accounts), - ClaudeMapping: cloneStringMap(c.ClaudeMapping), - ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), - ModelAliases: cloneStringMap(c.ModelAliases), - Compat: c.Compat, + Keys: slices.Clone(c.Keys), + Accounts: slices.Clone(c.Accounts), + ClaudeMapping: cloneStringMap(c.ClaudeMapping), + ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), + ModelAliases: cloneStringMap(c.ModelAliases), + Compat: CompatConfig{ + WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput), + }, Toolcall: c.Toolcall, Responses: c.Responses, Embeddings: c.Embeddings, @@ -224,6 +226,14 @@ func cloneStringMap(in map[string]string) map[string]string { return out } +func cloneBoolPtr(in *bool) *bool { + if in == nil { + return nil + } + v := *in + return &v +} + type Store struct { mu sync.RWMutex cfg Config @@ -569,9 +579,12 @@ func (s *Store) ModelAliases() map[string]string { } func (s *Store) CompatWideInputStrictOutput() bool { - // Current default policy is always wide-input / strict-output. - // Kept as a method so callers do not depend on storage shape. 
- return true + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Compat.WideInputStrictOutput == nil { + return true + } + return *s.cfg.Compat.WideInputStrictOutput } func (s *Store) ToolcallMode() string { diff --git a/internal/config/config_edge_test.go b/internal/config/config_edge_test.go index 81cc7ec..1138867 100644 --- a/internal/config/config_edge_test.go +++ b/internal/config/config_edge_test.go @@ -320,6 +320,39 @@ func TestStoreFindAccountNotFound(t *testing.T) { } } +func TestStoreCompatWideInputStrictOutputDefaultTrue(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + if !store.CompatWideInputStrictOutput() { + t.Fatal("expected default wide_input_strict_output=true when unset") + } +} + +func TestStoreCompatWideInputStrictOutputCanDisable(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[],"compat":{"wide_input_strict_output":false}}`) + store := LoadStore() + if store.CompatWideInputStrictOutput() { + t.Fatal("expected wide_input_strict_output=false when explicitly configured") + } + + snap := store.Snapshot() + data, err := snap.MarshalJSON() + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + var out map[string]any + if err := json.Unmarshal(data, &out); err != nil { + t.Fatalf("decode failed: %v", err) + } + rawCompat, ok := out["compat"].(map[string]any) + if !ok { + t.Fatalf("expected compat in marshaled output, got %#v", out) + } + if rawCompat["wide_input_strict_output"] != false { + t.Fatalf("expected explicit false in compat, got %#v", rawCompat) + } +} + func TestStoreIsEnvBacked(t *testing.T) { t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) store := LoadStore() From df9aea194c8c0cf7d599327255fa36477666f7c2 Mon Sep 17 00:00:00 2001 From: CJACK Date: Thu, 19 Feb 2026 00:08:03 +0800 Subject: [PATCH 16/52] fix: Remove redundant text accumulation to prevent duplicate output in streamed responses and add a test for it. 
--- internal/adapter/openai/responses_handler.go | 1 - .../adapter/openai/responses_stream_test.go | 70 +++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 internal/adapter/openai/responses_stream_test.go diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index ff324b4..e04fb5f 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -159,7 +159,6 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, if bufferToolContent { for _, evt := range flushToolSieve(&sieve, toolNames) { if evt.Content != "" { - finalText += evt.Content sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content)) } if len(evt.ToolCalls) > 0 { diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go new file mode 100644 index 0000000..4633388 --- /dev/null +++ b/internal/adapter/openai/responses_stream_test.go @@ -0,0 +1,70 @@ +package openai + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestHandleResponsesStreamNoDuplicateTailInCompletedOutputText(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + tail := `{"tool_calls":[{"name":"read_file","input":` + streamBody := sseLine("Before ") + sseLine(tail) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}) + + completed, ok := 
extractSSEEventPayload(rec.Body.String(), "response.completed") + if !ok { + t.Fatalf("expected response.completed event, body=%s", rec.Body.String()) + } + responseObj, _ := completed["response"].(map[string]any) + outputText, _ := responseObj["output_text"].(string) + if strings.Count(outputText, tail) != 1 { + t.Fatalf("expected tail to appear once in output_text, got output_text=%q", outputText) + } +} + +func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) { + scanner := bufio.NewScanner(strings.NewReader(body)) + matched := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "event: ") { + evt := strings.TrimSpace(strings.TrimPrefix(line, "event: ")) + matched = evt == targetEvent + continue + } + if !matched || !strings.HasPrefix(line, "data: ") { + continue + } + raw := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) + if raw == "" || raw == "[DONE]" { + continue + } + var payload map[string]any + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return nil, false + } + return payload, true + } + return nil, false +} From d21aedac83a2145280097d3fc7b7853e9dcf62f6 Mon Sep 17 00:00:00 2001 From: CJACK Date: Thu, 19 Feb 2026 00:28:44 +0800 Subject: [PATCH 17/52] feat: Hide raw tool call JSON from `output_text` in OpenAI-style responses when structured tool calls are present. 
--- .../adapter/openai/responses_stream_test.go | 58 ++++++++++++++++++- internal/util/render.go | 6 +- internal/util/render_test.go | 22 +++++++ 3 files changed, 82 insertions(+), 4 deletions(-) diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go index 4633388..9b0a5ac 100644 --- a/internal/adapter/openai/responses_stream_test.go +++ b/internal/adapter/openai/responses_stream_test.go @@ -10,7 +10,59 @@ import ( "testing" ) -func TestHandleResponsesStreamNoDuplicateTailInCompletedOutputText(t *testing.T) { +func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + rawToolJSON := `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}` + streamBody := sseLine(rawToolJSON) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}) + + completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") + if !ok { + t.Fatalf("expected response.completed event, body=%s", rec.Body.String()) + } + responseObj, _ := completed["response"].(map[string]any) + outputText, _ := responseObj["output_text"].(string) + if outputText != "" { + t.Fatalf("expected empty output_text for tool_calls response, got output_text=%q", outputText) + } + output, _ := responseObj["output"].([]any) + if len(output) == 0 { + t.Fatalf("expected structured output entries, got %#v", responseObj["output"]) + } + first, _ := output[0].(map[string]any) + if first["type"] != "tool_calls" { + t.Fatalf("expected first 
output type tool_calls, got %#v", first["type"]) + } + toolCalls, _ := first["tool_calls"].([]any) + if len(toolCalls) == 0 { + t.Fatalf("expected at least one tool_call in output, got %#v", first["tool_calls"]) + } + call0, _ := toolCalls[0].(map[string]any) + if call0["name"] != "read_file" { + t.Fatalf("unexpected tool call name: %#v", call0["name"]) + } + if strings.Contains(outputText, `"tool_calls"`) { + t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText) + } +} + +func TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) rec := httptest.NewRecorder() @@ -38,8 +90,8 @@ func TestHandleResponsesStreamNoDuplicateTailInCompletedOutputText(t *testing.T) } responseObj, _ := completed["response"].(map[string]any) outputText, _ := responseObj["output_text"].(string) - if strings.Count(outputText, tail) != 1 { - t.Fatalf("expected tail to appear once in output_text, got output_text=%q", outputText) + if strings.Count(outputText, tail) > 1 { + t.Fatalf("expected incomplete tail not to be duplicated, got output_text=%q", outputText) } } diff --git a/internal/util/render.go b/internal/util/render.go index ffb8128..b5e0a79 100644 --- a/internal/util/render.go +++ b/internal/util/render.go @@ -43,8 +43,12 @@ func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { detected := ParseToolCalls(finalText, toolNames) + exposedOutputText := finalText output := make([]any, 0, 2) if len(detected) > 0 { + // Keep structured tool output only; avoid leaking raw tool-call JSON + // into response.output_text for clients reading completed responses. 
+ exposedOutputText = "" toolCalls := make([]any, 0, len(detected)) for _, tc := range detected { toolCalls = append(toolCalls, map[string]any{ @@ -88,7 +92,7 @@ func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, fi "status": "completed", "model": model, "output": output, - "output_text": finalText, + "output_text": exposedOutputText, "usage": map[string]any{ "input_tokens": promptTokens, "output_tokens": reasoningTokens + completionTokens, diff --git a/internal/util/render_test.go b/internal/util/render_test.go index 1ee296b..9d4feec 100644 --- a/internal/util/render_test.go +++ b/internal/util/render_test.go @@ -54,6 +54,28 @@ func TestBuildOpenAIResponseObjectWithText(t *testing.T) { } } +func TestBuildOpenAIResponseObjectToolCallsHidesRawOutputText(t *testing.T) { + out := BuildOpenAIResponseObject( + "resp_2", + "gpt-4o", + "prompt", + "", + `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`, + []string{"search"}, + ) + if out["output_text"] != "" { + t.Fatalf("expected empty output_text for tool_calls, got %#v", out["output_text"]) + } + output, _ := out["output"].([]any) + if len(output) == 0 { + t.Fatalf("expected output entries") + } + first, _ := output[0].(map[string]any) + if first["type"] != "tool_calls" { + t.Fatalf("expected first output type tool_calls, got %#v", first["type"]) + } +} + func TestBuildClaudeMessageResponseToolUse(t *testing.T) { out := BuildClaudeMessageResponse( "msg_1", From 7307a5cc9a7c7f6076ce8d05cfb55f23fc58fbaf Mon Sep 17 00:00:00 2001 From: CJACK Date: Thu, 19 Feb 2026 02:45:38 +0800 Subject: [PATCH 18/52] feat: Implement admin settings UI, enhance admin authentication with password hashing, and add new streaming runtime logic for Claude and OpenAI adapters with extensive compatibility tests. 
--- TESTING.md | 4 +- api/chat-stream.js | 41 +- api/compat/js_compat_test.js | 60 +++ api/shared/deepseek-constants.js | 66 +++ internal/account/pool.go | 67 +++- internal/adapter/claude/convert.go | 11 + internal/adapter/claude/deps.go | 29 ++ .../adapter/claude/deps_injection_test.go | 33 ++ internal/adapter/claude/handler.go | 303 ++------------ internal/adapter/claude/standard_request.go | 7 +- internal/adapter/claude/stream_runtime.go | 308 ++++++++++++++ .../adapter/openai/chat_stream_runtime.go | 237 +++++++++++ internal/adapter/openai/deps.go | 35 ++ .../adapter/openai/deps_injection_test.go | 70 ++++ internal/adapter/openai/handler.go | 241 ++--------- internal/adapter/openai/prompt_build.go | 6 +- internal/adapter/openai/responses_handler.go | 138 ++----- .../openai/responses_stream_runtime.go | 168 ++++++++ internal/adapter/openai/standard_request.go | 4 +- internal/admin/deps.go | 46 +++ internal/admin/handler.go | 15 +- internal/admin/handler_auth.go | 14 +- internal/admin/handler_config.go | 202 ++++++++-- internal/admin/handler_settings.go | 321 +++++++++++++++ internal/admin/handler_settings_test.go | 267 +++++++++++++ internal/admin/handler_vercel.go | 241 +++++++---- internal/admin/helpers.go | 2 +- internal/admin/request_error.go | 23 ++ internal/admin/settings_validation.go | 64 +++ internal/auth/admin.go | 127 +++++- internal/auth/admin_test.go | 57 +++ internal/claudeconv/convert.go | 48 +++ internal/compat/go_compat_test.go | 142 +++++++ internal/config/config.go | 120 ++++++ internal/config/models.go | 8 +- internal/deepseek/constants.go | 76 +++- internal/deepseek/constants_shared.json | 25 ++ internal/deepseek/constants_test.go | 15 + internal/deepseek/prompt.go | 7 + internal/format/claude/render.go | 46 +++ internal/format/openai/render.go | 193 +++++++++ internal/prompt/messages.go | 84 ++++ internal/sse/parser.go | 240 ++++++----- internal/stream/engine.go | 128 ++++++ internal/testsuite/runner.go | 2 +- internal/util/messages.go | 
112 +----- internal/util/render.go | 6 + internal/util/render_stream.go | 20 + .../compat/expected/sse_fragments_append.json | 8 + .../compat/expected/sse_nested_finished.json | 5 + .../compat/expected/sse_split_tool_json.json | 8 + tests/compat/expected/token_cases.json | 7 + .../expected/toolcalls_fenced_json.json | 3 + .../expected/toolcalls_unknown_name.json | 5 + .../fixtures/sse_chunks/fragments_append.json | 12 + .../fixtures/sse_chunks/nested_finished.json | 10 + .../fixtures/sse_chunks/split_tool_json.json | 11 + tests/compat/fixtures/token_cases.json | 7 + .../fixtures/toolcalls/fenced_json.json | 4 + .../fixtures/toolcalls/unknown_name.json | 4 + webui/src/App.jsx | 8 +- webui/src/components/Settings.jsx | 376 ++++++++++++++++++ webui/src/locales/en.json | 49 +++ webui/src/locales/zh.json | 49 +++ 64 files changed, 4078 insertions(+), 967 deletions(-) create mode 100644 api/compat/js_compat_test.js create mode 100644 api/shared/deepseek-constants.js create mode 100644 internal/adapter/claude/convert.go create mode 100644 internal/adapter/claude/deps.go create mode 100644 internal/adapter/claude/deps_injection_test.go create mode 100644 internal/adapter/claude/stream_runtime.go create mode 100644 internal/adapter/openai/chat_stream_runtime.go create mode 100644 internal/adapter/openai/deps.go create mode 100644 internal/adapter/openai/deps_injection_test.go create mode 100644 internal/adapter/openai/responses_stream_runtime.go create mode 100644 internal/admin/deps.go create mode 100644 internal/admin/handler_settings.go create mode 100644 internal/admin/handler_settings_test.go create mode 100644 internal/admin/request_error.go create mode 100644 internal/admin/settings_validation.go create mode 100644 internal/claudeconv/convert.go create mode 100644 internal/compat/go_compat_test.go create mode 100644 internal/deepseek/constants_shared.json create mode 100644 internal/deepseek/constants_test.go create mode 100644 internal/deepseek/prompt.go create mode 
100644 internal/format/claude/render.go create mode 100644 internal/format/openai/render.go create mode 100644 internal/prompt/messages.go create mode 100644 internal/stream/engine.go create mode 100644 tests/compat/expected/sse_fragments_append.json create mode 100644 tests/compat/expected/sse_nested_finished.json create mode 100644 tests/compat/expected/sse_split_tool_json.json create mode 100644 tests/compat/expected/token_cases.json create mode 100644 tests/compat/expected/toolcalls_fenced_json.json create mode 100644 tests/compat/expected/toolcalls_unknown_name.json create mode 100644 tests/compat/fixtures/sse_chunks/fragments_append.json create mode 100644 tests/compat/fixtures/sse_chunks/nested_finished.json create mode 100644 tests/compat/fixtures/sse_chunks/split_tool_json.json create mode 100644 tests/compat/fixtures/token_cases.json create mode 100644 tests/compat/fixtures/toolcalls/fenced_json.json create mode 100644 tests/compat/fixtures/toolcalls/unknown_name.json create mode 100644 webui/src/components/Settings.jsx diff --git a/TESTING.md b/TESTING.md index 5540592..ce349ec 100644 --- a/TESTING.md +++ b/TESTING.md @@ -24,7 +24,7 @@ go test ./... ``` ```bash -node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js +node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js ``` ### 端到端测试 | End-to-End Tests @@ -39,7 +39,7 @@ node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js - `go test ./... -count=1`(单元测试) - `node --check api/chat-stream.js`(语法检查) - `node --check api/helpers/stream-tool-sieve.js`(语法检查) - - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js`(Node 流式拦截单测) + - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js`(Node 流式拦截 + compat 单测) - `npm run build --prefix webui`(WebUI 构建检查) 2. 
**隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程 diff --git a/api/chat-stream.js b/api/chat-stream.js index 1a8e896..c4a4cd1 100644 --- a/api/chat-stream.js +++ b/api/chat-stream.js @@ -10,31 +10,14 @@ const { parseToolCalls, formatOpenAIStreamToolCalls, } = require('./helpers/stream-tool-sieve'); +const { + BASE_HEADERS, + SKIP_PATTERNS, + SKIP_EXACT_PATHS, +} = require('./shared/deepseek-constants'); const DEEPSEEK_COMPLETION_URL = 'https://chat.deepseek.com/api/v0/chat/completion'; -const BASE_HEADERS = { - Host: 'chat.deepseek.com', - 'User-Agent': 'DeepSeek/1.6.11 Android/35', - Accept: 'application/json', - 'Content-Type': 'application/json', - 'x-client-platform': 'android', - 'x-client-version': '1.6.11', - 'x-client-locale': 'zh_CN', - 'accept-charset': 'UTF-8', -}; - -const SKIP_PATTERNS = [ - 'quasi_status', - 'elapsed_secs', - 'token_usage', - 'pending_fragment', - 'conversation_mode', - 'fragments/-1/status', - 'fragments/-2/status', - 'fragments/-3/status', -]; - module.exports = async function handler(req, res) { setCorsHeaders(res); if (req.method === 'OPTIONS') { @@ -725,7 +708,7 @@ function extractContentRecursive(items, defaultType) { } function shouldSkipPath(pathValue) { - if (pathValue === 'response/search_status') { + if (SKIP_EXACT_PATHS.has(pathValue)) { return true; } for (const p of SKIP_PATTERNS) { @@ -808,7 +791,16 @@ function estimateTokens(text) { if (!t) { return 0; } - const n = Math.floor(Array.from(t).length / 4); + let asciiChars = 0; + let nonASCIIChars = 0; + for (const ch of Array.from(t)) { + if (ch.charCodeAt(0) < 128) { + asciiChars += 1; + } else { + nonASCIIChars += 1; + } + } + const n = Math.floor(asciiChars / 4) + Math.floor((nonASCIIChars * 10 + 7) / 13); return n < 1 ? 
1 : n; } @@ -972,4 +964,5 @@ module.exports.__test = { resolveToolcallPolicy, normalizePreparedToolNames, boolDefaultTrue, + estimateTokens, }; diff --git a/api/compat/js_compat_test.js b/api/compat/js_compat_test.js new file mode 100644 index 0000000..9b03b00 --- /dev/null +++ b/api/compat/js_compat_test.js @@ -0,0 +1,60 @@ +'use strict'; + +const test = require('node:test'); +const assert = require('node:assert/strict'); +const fs = require('node:fs'); +const path = require('node:path'); + +const chatStream = require('../chat-stream'); +const { parseToolCalls } = require('../helpers/stream-tool-sieve'); + +const { parseChunkForContent, estimateTokens } = chatStream.__test; + +const compatRoot = path.resolve(__dirname, '../../tests/compat'); + +function readJSON(filePath) { + return JSON.parse(fs.readFileSync(filePath, 'utf8')); +} + +test('js compat: sse fixtures', () => { + const fixtureDir = path.join(compatRoot, 'fixtures', 'sse_chunks'); + const expectedDir = path.join(compatRoot, 'expected'); + const files = fs.readdirSync(fixtureDir).filter((f) => f.endsWith('.json')).sort(); + assert.ok(files.length > 0); + + for (const file of files) { + const name = file.replace(/\.json$/i, ''); + const fixture = readJSON(path.join(fixtureDir, file)); + const expected = readJSON(path.join(expectedDir, `sse_${name}.json`)); + const got = parseChunkForContent(fixture.chunk, Boolean(fixture.thinking_enabled), fixture.current_type || 'text'); + assert.deepEqual(got.parts, expected.parts, `${name}: parts mismatch`); + assert.equal(got.finished, expected.finished, `${name}: finished mismatch`); + assert.equal(got.newType, expected.new_type, `${name}: newType mismatch`); + } +}); + +test('js compat: toolcall fixtures', () => { + const fixtureDir = path.join(compatRoot, 'fixtures', 'toolcalls'); + const expectedDir = path.join(compatRoot, 'expected'); + const files = fs.readdirSync(fixtureDir).filter((f) => f.endsWith('.json')).sort(); + assert.ok(files.length > 0); + + for 
(const file of files) { + const name = file.replace(/\.json$/i, ''); + const fixture = readJSON(path.join(fixtureDir, file)); + const expected = readJSON(path.join(expectedDir, `toolcalls_${name}.json`)); + const got = parseToolCalls(fixture.text, fixture.tool_names || []); + assert.deepEqual(got, expected.calls, `${name}: calls mismatch`); + } +}); + +test('js compat: token fixtures', () => { + const fixture = readJSON(path.join(compatRoot, 'fixtures', 'token_cases.json')); + const expected = readJSON(path.join(compatRoot, 'expected', 'token_cases.json')); + const expectedByName = new Map(expected.cases.map((c) => [c.name, c.tokens])); + for (const c of fixture.cases) { + assert.ok(expectedByName.has(c.name), `missing expected case: ${c.name}`); + const got = estimateTokens(c.text); + assert.equal(got, expectedByName.get(c.name), `${c.name}: tokens mismatch`); + } +}); diff --git a/api/shared/deepseek-constants.js b/api/shared/deepseek-constants.js new file mode 100644 index 0000000..1ec74f1 --- /dev/null +++ b/api/shared/deepseek-constants.js @@ -0,0 +1,66 @@ +'use strict'; + +const fs = require('fs'); +const path = require('path'); + +const DEFAULT_BASE_HEADERS = Object.freeze({ + Host: 'chat.deepseek.com', + 'User-Agent': 'DeepSeek/1.6.11 Android/35', + Accept: 'application/json', + 'Content-Type': 'application/json', + 'x-client-platform': 'android', + 'x-client-version': '1.6.11', + 'x-client-locale': 'zh_CN', + 'accept-charset': 'UTF-8', +}); + +const DEFAULT_SKIP_PATTERNS = Object.freeze([ + 'quasi_status', + 'elapsed_secs', + 'token_usage', + 'pending_fragment', + 'conversation_mode', + 'fragments/-1/status', + 'fragments/-2/status', + 'fragments/-3/status', +]); + +const DEFAULT_SKIP_EXACT_PATHS = Object.freeze([ + 'response/search_status', +]); + +function loadSharedConstants() { + const sharedPath = path.resolve(__dirname, '../../internal/deepseek/constants_shared.json'); + try { + const raw = fs.readFileSync(sharedPath, 'utf8'); + const parsed = 
JSON.parse(raw); + const baseHeaders = parsed && typeof parsed.base_headers === 'object' && !Array.isArray(parsed.base_headers) + ? { ...DEFAULT_BASE_HEADERS, ...parsed.base_headers } + : { ...DEFAULT_BASE_HEADERS }; + const skipPatterns = Array.isArray(parsed && parsed.skip_contains_patterns) + ? parsed.skip_contains_patterns.filter((v) => typeof v === 'string' && v !== '') + : [...DEFAULT_SKIP_PATTERNS]; + const skipExactPaths = Array.isArray(parsed && parsed.skip_exact_paths) + ? parsed.skip_exact_paths.filter((v) => typeof v === 'string' && v !== '') + : [...DEFAULT_SKIP_EXACT_PATHS]; + return { + baseHeaders, + skipPatterns, + skipExactPaths, + }; + } catch (_err) { + return { + baseHeaders: { ...DEFAULT_BASE_HEADERS }, + skipPatterns: [...DEFAULT_SKIP_PATTERNS], + skipExactPaths: [...DEFAULT_SKIP_EXACT_PATHS], + }; + } +} + +const shared = loadSharedConstants(); + +module.exports = { + BASE_HEADERS: Object.freeze(shared.baseHeaders), + SKIP_PATTERNS: Object.freeze(shared.skipPatterns), + SKIP_EXACT_PATHS: new Set(shared.skipExactPaths), +}; diff --git a/internal/account/pool.go b/internal/account/pool.go index 665bcee..12d8874 100644 --- a/internal/account/pool.go +++ b/internal/account/pool.go @@ -20,13 +20,18 @@ type Pool struct { maxInflightPerAccount int recommendedConcurrency int maxQueueSize int + globalMaxInflight int } func NewPool(store *config.Store) *Pool { + maxPer := 2 + if store != nil { + maxPer = store.RuntimeAccountMaxInflight() + } p := &Pool{ store: store, inUse: map[string]int{}, - maxInflightPerAccount: maxInflightFromEnv(), + maxInflightPerAccount: maxPer, } p.Reset() return p @@ -49,8 +54,18 @@ func (p *Pool) Reset() { ids = append(ids, id) } } + if p.store != nil { + p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight() + } else { + p.maxInflightPerAccount = maxInflightFromEnv() + } recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount) queueLimit := maxQueueFromEnv(recommended) + globalLimit := 
recommended + if p.store != nil { + queueLimit = p.store.RuntimeAccountMaxQueue(recommended) + globalLimit = p.store.RuntimeGlobalMaxInflight(recommended) + } p.mu.Lock() defer p.mu.Unlock() p.drainWaitersLocked() @@ -58,10 +73,12 @@ func (p *Pool) Reset() { p.inUse = map[string]int{} p.recommendedConcurrency = recommended p.maxQueueSize = queueLimit + p.globalMaxInflight = globalLimit config.Logger.Info( "[init_account_queue] initialized", "total", len(ids), "max_inflight_per_account", p.maxInflightPerAccount, + "global_max_inflight", p.globalMaxInflight, "recommended_concurrency", p.recommendedConcurrency, "max_queue_size", p.maxQueueSize, ) @@ -109,7 +126,7 @@ func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[strin func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) { if target != "" { - if exclude[target] || p.inUse[target] >= p.maxInflightPerAccount { + if exclude[target] || !p.canAcquireIDLocked(target) { return config.Account{}, false } acc, ok := p.store.FindAccount(target) @@ -133,7 +150,7 @@ func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Acc func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) { for i := 0; i < len(p.queue); i++ { id := p.queue[i] - if exclude[id] || p.inUse[id] >= p.maxInflightPerAccount { + if exclude[id] || !p.canAcquireIDLocked(id) { continue } acc, ok := p.store.FindAccount(id) @@ -205,12 +222,35 @@ func (p *Pool) Status() map[string]any { "available_accounts": available, "in_use_accounts": inUseAccounts, "max_inflight_per_account": p.maxInflightPerAccount, + "global_max_inflight": p.globalMaxInflight, "recommended_concurrency": p.recommendedConcurrency, "waiting": len(p.waiters), "max_queue_size": p.maxQueueSize, } } +func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) { + if maxInflightPerAccount <= 0 { + maxInflightPerAccount = 1 + } + if maxQueueSize < 0 { 
+ maxQueueSize = 0 + } + if globalMaxInflight <= 0 { + globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts()) + if globalMaxInflight <= 0 { + globalMaxInflight = maxInflightPerAccount + } + } + p.mu.Lock() + defer p.mu.Unlock() + p.maxInflightPerAccount = maxInflightPerAccount + p.maxQueueSize = maxQueueSize + p.globalMaxInflight = globalMaxInflight + p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount) + p.notifyWaiterLocked() +} + func maxInflightFromEnv() int { for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} { raw := strings.TrimSpace(os.Getenv(key)) @@ -300,3 +340,24 @@ func maxQueueFromEnv(defaultSize int) int { } return defaultSize } + +func (p *Pool) canAcquireIDLocked(accountID string) bool { + if accountID == "" { + return false + } + if p.inUse[accountID] >= p.maxInflightPerAccount { + return false + } + if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight { + return false + } + return true +} + +func (p *Pool) currentInUseLocked() int { + total := 0 + for _, n := range p.inUse { + total += n + } + return total +} diff --git a/internal/adapter/claude/convert.go b/internal/adapter/claude/convert.go new file mode 100644 index 0000000..dbb5e1a --- /dev/null +++ b/internal/adapter/claude/convert.go @@ -0,0 +1,11 @@ +package claude + +import ( + "ds2api/internal/claudeconv" +) + +const defaultClaudeModel = "claude-sonnet-4-5" + +func convertClaudeToDeepSeek(claudeReq map[string]any, store ConfigReader) map[string]any { + return claudeconv.ConvertClaudeToDeepSeek(claudeReq, store, defaultClaudeModel) +} diff --git a/internal/adapter/claude/deps.go b/internal/adapter/claude/deps.go new file mode 100644 index 0000000..73203b2 --- /dev/null +++ b/internal/adapter/claude/deps.go @@ -0,0 +1,29 @@ +package claude + +import ( + "context" + "net/http" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" +) + +type 
AuthResolver interface { + Determine(req *http.Request) (*auth.RequestAuth, error) + Release(a *auth.RequestAuth) +} + +type DeepSeekCaller interface { + CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +type ConfigReader interface { + ClaudeMapping() map[string]string +} + +var _ AuthResolver = (*auth.Resolver)(nil) +var _ DeepSeekCaller = (*deepseek.Client)(nil) +var _ ConfigReader = (*config.Store)(nil) diff --git a/internal/adapter/claude/deps_injection_test.go b/internal/adapter/claude/deps_injection_test.go new file mode 100644 index 0000000..39dfc2f --- /dev/null +++ b/internal/adapter/claude/deps_injection_test.go @@ -0,0 +1,33 @@ +package claude + +import "testing" + +type mockClaudeConfig struct { + m map[string]string +} + +func (m mockClaudeConfig) ClaudeMapping() map[string]string { return m.m } + +func TestNormalizeClaudeRequestUsesConfigInterfaceMapping(t *testing.T) { + req := map[string]any{ + "model": "claude-opus-4-6", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + out, err := normalizeClaudeRequest(mockClaudeConfig{ + m: map[string]string{ + "fast": "deepseek-chat", + "slow": "deepseek-reasoner-search", + }, + }, req) + if err != nil { + t.Fatalf("normalizeClaudeRequest error: %v", err) + } + if out.Standard.ResolvedModel != "deepseek-reasoner-search" { + t.Fatalf("resolved model mismatch: got=%q", out.Standard.ResolvedModel) + } + if !out.Standard.Thinking || !out.Standard.Search { + t.Fatalf("unexpected flags: thinking=%v search=%v", out.Standard.Thinking, out.Standard.Search) + } +} diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index bac315f..282b569 100644 --- a/internal/adapter/claude/handler.go +++ 
b/internal/adapter/claude/handler.go @@ -13,7 +13,9 @@ import ( "ds2api/internal/auth" "ds2api/internal/config" "ds2api/internal/deepseek" + claudefmt "ds2api/internal/format/claude" "ds2api/internal/sse" + streamengine "ds2api/internal/stream" "ds2api/internal/util" ) @@ -21,9 +23,9 @@ import ( var writeJSON = util.WriteJSON type Handler struct { - Store *config.Store - Auth *auth.Resolver - DS *deepseek.Client + Store ConfigReader + Auth AuthResolver + DS DeepSeekCaller } var ( @@ -98,7 +100,7 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { return } result := sse.CollectStream(resp, stdReq.Thinking, true) - respBody := util.BuildClaudeMessageResponse( + respBody := claudefmt.BuildMessageResponse( fmt.Sprintf("msg_%d", time.Now().UnixNano()), stdReq.ResponseModel, norm.NormalizedMessages, @@ -169,279 +171,38 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ if !canFlush { config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered") } - send := func(event string, v any) { - b, _ := json.Marshal(v) - _, _ = w.Write([]byte("event: ")) - _, _ = w.Write([]byte(event)) - _, _ = w.Write([]byte("\n")) - _, _ = w.Write([]byte("data: ")) - _, _ = w.Write(b) - _, _ = w.Write([]byte("\n\n")) - if canFlush { - _ = rc.Flush() - } - } - sendError := func(message string) { - msg := strings.TrimSpace(message) - if msg == "" { - msg = "upstream stream error" - } - send("error", map[string]any{ - "type": "error", - "error": map[string]any{ - "type": "api_error", - "message": msg, - "code": "internal_error", - "param": nil, - }, - }) - } - messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano()) - inputTokens := util.EstimateTokens(fmt.Sprintf("%v", messages)) - send("message_start", map[string]any{ - "type": "message_start", - "message": map[string]any{ - "id": messageID, - "type": "message", - "role": "assistant", - "model": model, - "content": []any{}, - "stop_reason": nil, - 
"stop_sequence": nil, - "usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0}, - }, - }) + streamRuntime := newClaudeStreamRuntime( + w, + rc, + canFlush, + model, + messages, + thinkingEnabled, + searchEnabled, + toolNames, + ) + streamRuntime.sendMessageStart() initialType := "text" if thinkingEnabled { initialType = "thinking" } - parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) - bufferToolContent := len(toolNames) > 0 - hasContent := false - lastContent := time.Now() - keepaliveCount := 0 - - thinking := strings.Builder{} - text := strings.Builder{} - - nextBlockIndex := 0 - thinkingBlockOpen := false - thinkingBlockIndex := -1 - textBlockOpen := false - textBlockIndex := -1 - ended := false - - closeThinkingBlock := func() { - if !thinkingBlockOpen { - return - } - send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": thinkingBlockIndex, - }) - thinkingBlockOpen = false - thinkingBlockIndex = -1 - } - closeTextBlock := func() { - if !textBlockOpen { - return - } - send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": textBlockIndex, - }) - textBlockOpen = false - textBlockIndex = -1 - } - - finalize := func(stopReason string) { - if ended { - return - } - ended = true - - closeThinkingBlock() - closeTextBlock() - - finalThinking := thinking.String() - finalText := text.String() - - if bufferToolContent { - detected := util.ParseToolCalls(finalText, toolNames) - if len(detected) > 0 { - stopReason = "tool_use" - for i, tc := range detected { - idx := nextBlockIndex + i - send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": idx, - "content_block": map[string]any{ - "type": "tool_use", - "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx), - "name": tc.Name, - "input": tc.Input, - }, - }) - send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": idx, - }) - } - 
nextBlockIndex += len(detected) - } else if finalText != "" { - idx := nextBlockIndex - nextBlockIndex++ - send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": idx, - "content_block": map[string]any{ - "type": "text", - "text": "", - }, - }) - send("content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": idx, - "delta": map[string]any{ - "type": "text_delta", - "text": finalText, - }, - }) - send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": idx, - }) - } - } - - outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText) - send("message_delta", map[string]any{ - "type": "message_delta", - "delta": map[string]any{ - "stop_reason": stopReason, - "stop_sequence": nil, - }, - "usage": map[string]any{ - "output_tokens": outputTokens, - }, - }) - send("message_stop", map[string]any{"type": "message_stop"}) - } - - pingTicker := time.NewTicker(claudeStreamPingInterval) - defer pingTicker.Stop() - - for { - select { - case <-r.Context().Done(): - return - case <-pingTicker.C: - if !hasContent { - keepaliveCount++ - if keepaliveCount >= claudeStreamMaxKeepaliveCnt { - finalize("end_turn") - return - } - } - if hasContent && time.Since(lastContent) > claudeStreamIdleTimeout { - finalize("end_turn") - return - } - send("ping", map[string]any{"type": "ping"}) - case parsed, ok := <-parsedLines: - if !ok { - if err := <-done; err != nil { - sendError(err.Error()) - return - } - finalize("end_turn") - return - } - if !parsed.Parsed { - continue - } - if parsed.ErrorMessage != "" { - sendError(parsed.ErrorMessage) - return - } - if parsed.Stop { - finalize("end_turn") - return - } - - for _, p := range parsed.Parts { - if p.Text == "" { - continue - } - if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) { - continue - } - - hasContent = true - lastContent = time.Now() - keepaliveCount = 0 - - if p.Type == "thinking" { - if !thinkingEnabled { - 
continue - } - thinking.WriteString(p.Text) - closeTextBlock() - if !thinkingBlockOpen { - thinkingBlockIndex = nextBlockIndex - nextBlockIndex++ - send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": thinkingBlockIndex, - "content_block": map[string]any{ - "type": "thinking", - "thinking": "", - }, - }) - thinkingBlockOpen = true - } - send("content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": thinkingBlockIndex, - "delta": map[string]any{ - "type": "thinking_delta", - "thinking": p.Text, - }, - }) - continue - } - - text.WriteString(p.Text) - if bufferToolContent { - continue - } - closeThinkingBlock() - if !textBlockOpen { - textBlockIndex = nextBlockIndex - nextBlockIndex++ - send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": textBlockIndex, - "content_block": map[string]any{ - "type": "text", - "text": "", - }, - }) - textBlockOpen = true - } - send("content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": textBlockIndex, - "delta": map[string]any{ - "type": "text_delta", - "text": p.Text, - }, - }) - } - } - } + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: claudeStreamPingInterval, + IdleTimeout: claudeStreamIdleTimeout, + MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt, + }, streamengine.ConsumeHooks{ + OnKeepAlive: func() { + streamRuntime.sendPing() + }, + OnParsed: streamRuntime.onParsed, + OnFinalize: streamRuntime.onFinalize, + }) } func writeClaudeError(w http.ResponseWriter, status int, message string) { diff --git a/internal/adapter/claude/standard_request.go b/internal/adapter/claude/standard_request.go index de97c6a..cdbb675 100644 --- a/internal/adapter/claude/standard_request.go +++ b/internal/adapter/claude/standard_request.go @@ -5,6 +5,7 @@ import ( "strings" "ds2api/internal/config" + 
"ds2api/internal/deepseek" "ds2api/internal/util" ) @@ -13,7 +14,7 @@ type claudeNormalizedRequest struct { NormalizedMessages []any } -func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNormalizedRequest, error) { +func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNormalizedRequest, error) { model, _ := req["model"].(string) messagesRaw, _ := req["messages"].([]any) if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { @@ -30,14 +31,14 @@ func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNorm payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...) } - dsPayload := util.ConvertClaudeToDeepSeek(payload, store) + dsPayload := convertClaudeToDeepSeek(payload, store) dsModel, _ := dsPayload["model"].(string) thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel) if !ok { thinkingEnabled = false searchEnabled = false } - finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"])) + finalPrompt := deepseek.MessagesPrepare(toMessageMaps(dsPayload["messages"])) toolNames := extractClaudeToolNames(toolsRequested) return claudeNormalizedRequest{ diff --git a/internal/adapter/claude/stream_runtime.go b/internal/adapter/claude/stream_runtime.go new file mode 100644 index 0000000..01e07a9 --- /dev/null +++ b/internal/adapter/claude/stream_runtime.go @@ -0,0 +1,308 @@ +package claude + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" + "ds2api/internal/util" +) + +type claudeStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + model string + toolNames []string + messages []any + + thinkingEnabled bool + searchEnabled bool + bufferToolContent bool + + messageID string + thinking strings.Builder + text strings.Builder + + nextBlockIndex int + 
thinkingBlockOpen bool + thinkingBlockIndex int + textBlockOpen bool + textBlockIndex int + ended bool + upstreamErr string +} + +func newClaudeStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + model string, + messages []any, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, +) *claudeStreamRuntime { + return &claudeStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + model: model, + messages: messages, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + bufferToolContent: len(toolNames) > 0, + toolNames: toolNames, + messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), + thinkingBlockIndex: -1, + textBlockIndex: -1, + } +} + +func (s *claudeStreamRuntime) send(event string, v any) { + b, _ := json.Marshal(v) + _, _ = s.w.Write([]byte("event: ")) + _, _ = s.w.Write([]byte(event)) + _, _ = s.w.Write([]byte("\n")) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *claudeStreamRuntime) sendError(message string) { + msg := strings.TrimSpace(message) + if msg == "" { + msg = "upstream stream error" + } + s.send("error", map[string]any{ + "type": "error", + "error": map[string]any{ + "type": "api_error", + "message": msg, + "code": "internal_error", + "param": nil, + }, + }) +} + +func (s *claudeStreamRuntime) sendPing() { + s.send("ping", map[string]any{"type": "ping"}) +} + +func (s *claudeStreamRuntime) sendMessageStart() { + inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages)) + s.send("message_start", map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": s.messageID, + "type": "message", + "role": "assistant", + "model": s.model, + "content": []any{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0}, + }, + }) +} + +func (s *claudeStreamRuntime) closeThinkingBlock() { + if 
!s.thinkingBlockOpen { + return + } + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": s.thinkingBlockIndex, + }) + s.thinkingBlockOpen = false + s.thinkingBlockIndex = -1 +} + +func (s *claudeStreamRuntime) closeTextBlock() { + if !s.textBlockOpen { + return + } + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": s.textBlockIndex, + }) + s.textBlockOpen = false + s.textBlockIndex = -1 +} + +func (s *claudeStreamRuntime) finalize(stopReason string) { + if s.ended { + return + } + s.ended = true + + s.closeThinkingBlock() + s.closeTextBlock() + + finalThinking := s.thinking.String() + finalText := s.text.String() + + if s.bufferToolContent { + detected := util.ParseToolCalls(finalText, s.toolNames) + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + idx := s.nextBlockIndex + i + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]any{ + "type": "tool_use", + "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx), + "name": tc.Name, + "input": tc.Input, + }, + }) + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": idx, + }) + } + s.nextBlockIndex += len(detected) + } else if finalText != "" { + idx := s.nextBlockIndex + s.nextBlockIndex++ + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": idx, + "delta": map[string]any{ + "type": "text_delta", + "text": finalText, + }, + }) + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": idx, + }) + } + } + + outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText) + s.send("message_delta", map[string]any{ + "type": "message_delta", + 
"delta": map[string]any{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": map[string]any{ + "output_tokens": outputTokens, + }, + }) + s.send("message_stop", map[string]any{"type": "message_stop"}) +} + +func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ErrorMessage != "" { + s.upstreamErr = parsed.ErrorMessage + return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")} + } + if parsed.Stop { + return streamengine.ParsedDecision{Stop: true} + } + + contentSeen := false + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + contentSeen = true + + if p.Type == "thinking" { + if !s.thinkingEnabled { + continue + } + s.thinking.WriteString(p.Text) + s.closeTextBlock() + if !s.thinkingBlockOpen { + s.thinkingBlockIndex = s.nextBlockIndex + s.nextBlockIndex++ + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": s.thinkingBlockIndex, + "content_block": map[string]any{ + "type": "thinking", + "thinking": "", + }, + }) + s.thinkingBlockOpen = true + } + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": s.thinkingBlockIndex, + "delta": map[string]any{ + "type": "thinking_delta", + "thinking": p.Text, + }, + }) + continue + } + + s.text.WriteString(p.Text) + if s.bufferToolContent { + continue + } + s.closeThinkingBlock() + if !s.textBlockOpen { + s.textBlockIndex = s.nextBlockIndex + s.nextBlockIndex++ + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": s.textBlockIndex, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + s.textBlockOpen = true + } + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": 
s.textBlockIndex, + "delta": map[string]any{ + "type": "text_delta", + "text": p.Text, + }, + }) + } + + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} + +func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) { + if string(reason) == "upstream_error" { + s.sendError(s.upstreamErr) + return + } + if scannerErr != nil { + s.sendError(scannerErr.Error()) + return + } + s.finalize("end_turn") +} diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go new file mode 100644 index 0000000..0e64bc5 --- /dev/null +++ b/internal/adapter/openai/chat_stream_runtime.go @@ -0,0 +1,237 @@ +package openai + +import ( + "encoding/json" + "net/http" + "strings" + + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" + "ds2api/internal/util" +) + +type chatStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + completionID string + created int64 + model string + finalPrompt string + toolNames []string + + thinkingEnabled bool + searchEnabled bool + + firstChunkSent bool + bufferToolContent bool + emitEarlyToolDeltas bool + toolCallsEmitted bool + + toolSieve toolStreamSieveState + streamToolCallIDs map[int]string + thinking strings.Builder + text strings.Builder +} + +func newChatStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + completionID string, + created int64, + model string, + finalPrompt string, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, + bufferToolContent bool, + emitEarlyToolDeltas bool, +) *chatStreamRuntime { + return &chatStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + completionID: completionID, + created: created, + model: model, + finalPrompt: finalPrompt, + toolNames: toolNames, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + bufferToolContent: bufferToolContent, + 
emitEarlyToolDeltas: emitEarlyToolDeltas, + streamToolCallIDs: map[int]string{}, + } +} + +func (s *chatStreamRuntime) sendKeepAlive() { + if !s.canFlush { + return + } + _, _ = s.w.Write([]byte(": keep-alive\n\n")) + _ = s.rc.Flush() +} + +func (s *chatStreamRuntime) sendChunk(v any) { + b, _ := json.Marshal(v) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *chatStreamRuntime) sendDone() { + _, _ = s.w.Write([]byte("data: [DONE]\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *chatStreamRuntime) finalize(finishReason string) { + finalThinking := s.thinking.String() + finalText := s.text.String() + detected := util.ParseToolCalls(finalText, s.toolNames) + if len(detected) > 0 && !s.toolCallsEmitted { + finishReason = "tool_calls" + delta := map[string]any{ + "tool_calls": util.FormatOpenAIStreamToolCalls(detected), + } + if !s.firstChunkSent { + delta["role"] = "assistant" + s.firstChunkSent = true + } + s.sendChunk(openaifmt.BuildChatStreamChunk( + s.completionID, + s.created, + s.model, + []map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)}, + nil, + )) + } else if s.bufferToolContent { + for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) { + if evt.Content == "" { + continue + } + delta := map[string]any{ + "content": evt.Content, + } + if !s.firstChunkSent { + delta["role"] = "assistant" + s.firstChunkSent = true + } + s.sendChunk(openaifmt.BuildChatStreamChunk( + s.completionID, + s.created, + s.model, + []map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)}, + nil, + )) + } + } + + if len(detected) > 0 || s.toolCallsEmitted { + finishReason = "tool_calls" + } + s.sendChunk(openaifmt.BuildChatStreamChunk( + s.completionID, + s.created, + s.model, + []map[string]any{openaifmt.BuildChatStreamFinishChoice(0, finishReason)}, + openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText), + )) + s.sendDone() +} + 
+func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ContentFilter || parsed.ErrorMessage != "" { + return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("content_filter")} + } + if parsed.Stop { + return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReasonHandlerRequested} + } + + newChoices := make([]map[string]any, 0, len(parsed.Parts)) + contentSeen := false + for _, p := range parsed.Parts { + if s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + if p.Text == "" { + continue + } + contentSeen = true + delta := map[string]any{} + if !s.firstChunkSent { + delta["role"] = "assistant" + s.firstChunkSent = true + } + if p.Type == "thinking" { + if s.thinkingEnabled { + s.thinking.WriteString(p.Text) + delta["reasoning_content"] = p.Text + } + } else { + s.text.WriteString(p.Text) + if !s.bufferToolContent { + delta["content"] = p.Text + } else { + events := processToolSieveChunk(&s.toolSieve, p.Text, s.toolNames) + for _, evt := range events { + if len(evt.ToolCallDeltas) > 0 { + if !s.emitEarlyToolDeltas { + continue + } + s.toolCallsEmitted = true + tcDelta := map[string]any{ + "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs), + } + if !s.firstChunkSent { + tcDelta["role"] = "assistant" + s.firstChunkSent = true + } + newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)) + continue + } + if len(evt.ToolCalls) > 0 { + s.toolCallsEmitted = true + tcDelta := map[string]any{ + "tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls), + } + if !s.firstChunkSent { + tcDelta["role"] = "assistant" + s.firstChunkSent = true + } + newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)) + continue + } + if evt.Content != "" { + contentDelta := map[string]any{ + "content": evt.Content, + } + if 
!s.firstChunkSent { + contentDelta["role"] = "assistant" + s.firstChunkSent = true + } + newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, contentDelta)) + } + } + } + } + if len(delta) > 0 { + newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, delta)) + } + } + + if len(newChoices) > 0 { + s.sendChunk(openaifmt.BuildChatStreamChunk(s.completionID, s.created, s.model, newChoices, nil)) + } + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} diff --git a/internal/adapter/openai/deps.go b/internal/adapter/openai/deps.go new file mode 100644 index 0000000..6688756 --- /dev/null +++ b/internal/adapter/openai/deps.go @@ -0,0 +1,35 @@ +package openai + +import ( + "context" + "net/http" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" +) + +type AuthResolver interface { + Determine(req *http.Request) (*auth.RequestAuth, error) + DetermineCaller(req *http.Request) (*auth.RequestAuth, error) + Release(a *auth.RequestAuth) +} + +type DeepSeekCaller interface { + CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +type ConfigReader interface { + ModelAliases() map[string]string + CompatWideInputStrictOutput() bool + ToolcallMode() string + ToolcallEarlyEmitConfidence() string + ResponsesStoreTTLSeconds() int + EmbeddingsProvider() string +} + +var _ AuthResolver = (*auth.Resolver)(nil) +var _ DeepSeekCaller = (*deepseek.Client)(nil) +var _ ConfigReader = (*config.Store)(nil) diff --git a/internal/adapter/openai/deps_injection_test.go b/internal/adapter/openai/deps_injection_test.go new file mode 100644 index 0000000..baa0c11 --- /dev/null +++ b/internal/adapter/openai/deps_injection_test.go @@ -0,0 +1,70 @@ +package openai 
+ +import "testing" + +type mockOpenAIConfig struct { + aliases map[string]string + wideInput bool + toolMode string + earlyEmit string + responsesTTL int + embedProv string +} + +func (m mockOpenAIConfig) ModelAliases() map[string]string { return m.aliases } +func (m mockOpenAIConfig) CompatWideInputStrictOutput() bool { + return m.wideInput +} +func (m mockOpenAIConfig) ToolcallMode() string { return m.toolMode } +func (m mockOpenAIConfig) ToolcallEarlyEmitConfidence() string { return m.earlyEmit } +func (m mockOpenAIConfig) ResponsesStoreTTLSeconds() int { return m.responsesTTL } +func (m mockOpenAIConfig) EmbeddingsProvider() string { return m.embedProv } + +func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) { + cfg := mockOpenAIConfig{ + aliases: map[string]string{ + "my-model": "deepseek-chat-search", + }, + wideInput: true, + } + req := map[string]any{ + "model": "my-model", + "messages": []any{map[string]any{"role": "user", "content": "hello"}}, + } + out, err := normalizeOpenAIChatRequest(cfg, req) + if err != nil { + t.Fatalf("normalizeOpenAIChatRequest error: %v", err) + } + if out.ResolvedModel != "deepseek-chat-search" { + t.Fatalf("resolved model mismatch: got=%q", out.ResolvedModel) + } + if !out.Search || out.Thinking { + t.Fatalf("unexpected model flags: thinking=%v search=%v", out.Thinking, out.Search) + } +} + +func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.T) { + req := map[string]any{ + "model": "deepseek-chat", + "input": "hi", + } + + _, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{ + aliases: map[string]string{}, + wideInput: false, + }, req) + if err == nil { + t.Fatal("expected error when wide input is disabled and only input is provided") + } + + out, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{ + aliases: map[string]string{}, + wideInput: true, + }, req) + if err != nil { + t.Fatalf("unexpected error when wide input is enabled: %v", err) + } + if out.Surface != 
"openai_responses" { + t.Fatalf("unexpected surface: %q", out.Surface) + } +} diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index 5ef6e7b..28a451c 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -16,7 +16,9 @@ import ( "ds2api/internal/auth" "ds2api/internal/config" "ds2api/internal/deepseek" + openaifmt "ds2api/internal/format/openai" "ds2api/internal/sse" + streamengine "ds2api/internal/stream" "ds2api/internal/util" ) @@ -25,9 +27,9 @@ import ( var writeJSON = util.WriteJSON type Handler struct { - Store *config.Store - Auth *auth.Resolver - DS *deepseek.Client + Store ConfigReader + Auth AuthResolver + DS DeepSeekCaller leaseMu sync.Mutex streamLeases map[string]streamLease @@ -136,7 +138,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re finalThinking := result.Thinking finalText := result.Text - respBody := util.BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames) + respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames) writeJSON(w, http.StatusOK, respBody) } @@ -158,214 +160,49 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt } created := time.Now().Unix() - firstChunkSent := false bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() - var toolSieve toolStreamSieveState - toolCallsEmitted := false - streamToolCallIDs := map[int]string{} initialType := "text" if thinkingEnabled { initialType = "thinking" } - parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) - thinking := strings.Builder{} - text := strings.Builder{} - lastContent := time.Now() - hasContent := false - keepaliveTicker := time.NewTicker(time.Duration(deepseek.KeepAliveTimeout) * time.Second) - defer 
keepaliveTicker.Stop() - keepaliveCountWithoutContent := 0 - sendChunk := func(v any) { - b, _ := json.Marshal(v) - _, _ = w.Write([]byte("data: ")) - _, _ = w.Write(b) - _, _ = w.Write([]byte("\n\n")) - if canFlush { - _ = rc.Flush() - } - } - sendDone := func() { - _, _ = w.Write([]byte("data: [DONE]\n\n")) - if canFlush { - _ = rc.Flush() - } - } + streamRuntime := newChatStreamRuntime( + w, + rc, + canFlush, + completionID, + created, + model, + finalPrompt, + thinkingEnabled, + searchEnabled, + toolNames, + bufferToolContent, + emitEarlyToolDeltas, + ) - finalize := func(finishReason string) { - finalThinking := thinking.String() - finalText := text.String() - detected := util.ParseToolCalls(finalText, toolNames) - if len(detected) > 0 && !toolCallsEmitted { - finishReason = "tool_calls" - delta := map[string]any{ - "tool_calls": util.FormatOpenAIStreamToolCalls(detected), - } - if !firstChunkSent { - delta["role"] = "assistant" - firstChunkSent = true - } - sendChunk(util.BuildOpenAIChatStreamChunk( - completionID, - created, - model, - []map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)}, - nil, - )) - } else if bufferToolContent { - for _, evt := range flushToolSieve(&toolSieve, toolNames) { - if evt.Content == "" { - continue - } - delta := map[string]any{ - "content": evt.Content, - } - if !firstChunkSent { - delta["role"] = "assistant" - firstChunkSent = true - } - sendChunk(util.BuildOpenAIChatStreamChunk( - completionID, - created, - model, - []map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)}, - nil, - )) - } - } - if len(detected) > 0 || toolCallsEmitted { - finishReason = "tool_calls" - } - sendChunk(util.BuildOpenAIChatStreamChunk( - completionID, - created, - model, - []map[string]any{util.BuildOpenAIChatStreamFinishChoice(0, finishReason)}, - util.BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText), - )) - sendDone() - } - - for { - select { - case <-r.Context().Done(): - return - case <-keepaliveTicker.C: - if 
!hasContent { - keepaliveCountWithoutContent++ - if keepaliveCountWithoutContent >= deepseek.MaxKeepaliveCount { - finalize("stop") - return - } - } - if hasContent && time.Since(lastContent) > time.Duration(deepseek.StreamIdleTimeout)*time.Second { - finalize("stop") + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second, + IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, + MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, + }, streamengine.ConsumeHooks{ + OnKeepAlive: func() { + streamRuntime.sendKeepAlive() + }, + OnParsed: streamRuntime.onParsed, + OnFinalize: func(reason streamengine.StopReason, _ error) { + if string(reason) == "content_filter" { + streamRuntime.finalize("content_filter") return } - if canFlush { - _, _ = w.Write([]byte(": keep-alive\n\n")) - _ = rc.Flush() - } - case parsed, ok := <-parsedLines: - if !ok { - // Ensure scanner completion is observed only after all queued - // SSE lines are drained, avoiding early finalize races. 
- _ = <-done - finalize("stop") - return - } - if !parsed.Parsed { - continue - } - if parsed.ContentFilter || parsed.ErrorMessage != "" { - finalize("content_filter") - return - } - if parsed.Stop { - finalize("stop") - return - } - newChoices := make([]map[string]any, 0, len(parsed.Parts)) - for _, p := range parsed.Parts { - if searchEnabled && sse.IsCitation(p.Text) { - continue - } - if p.Text == "" { - continue - } - hasContent = true - lastContent = time.Now() - keepaliveCountWithoutContent = 0 - delta := map[string]any{} - if !firstChunkSent { - delta["role"] = "assistant" - firstChunkSent = true - } - if p.Type == "thinking" { - if thinkingEnabled { - thinking.WriteString(p.Text) - delta["reasoning_content"] = p.Text - } - } else { - text.WriteString(p.Text) - if !bufferToolContent { - delta["content"] = p.Text - } else { - events := processToolSieveChunk(&toolSieve, p.Text, toolNames) - if len(events) == 0 { - // Keep thinking delta only frame. - } - for _, evt := range events { - if len(evt.ToolCallDeltas) > 0 { - if !emitEarlyToolDeltas { - continue - } - toolCallsEmitted = true - tcDelta := map[string]any{ - "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs), - } - if !firstChunkSent { - tcDelta["role"] = "assistant" - firstChunkSent = true - } - newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta)) - continue - } - if len(evt.ToolCalls) > 0 { - toolCallsEmitted = true - tcDelta := map[string]any{ - "tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls), - } - if !firstChunkSent { - tcDelta["role"] = "assistant" - firstChunkSent = true - } - newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta)) - continue - } - if evt.Content != "" { - contentDelta := map[string]any{ - "content": evt.Content, - } - if !firstChunkSent { - contentDelta["role"] = "assistant" - firstChunkSent = true - } - newChoices = append(newChoices, 
util.BuildOpenAIChatStreamDeltaChoice(0, contentDelta)) - } - } - } - } - if len(delta) > 0 { - newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, delta)) - } - } - if len(newChoices) > 0 { - sendChunk(util.BuildOpenAIChatStreamChunk(completionID, created, model, newChoices, nil)) - } - } - } + streamRuntime.finalize("stop") + }, + }) } func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) { diff --git a/internal/adapter/openai/prompt_build.go b/internal/adapter/openai/prompt_build.go index a7bbc92..f83963f 100644 --- a/internal/adapter/openai/prompt_build.go +++ b/internal/adapter/openai/prompt_build.go @@ -1,6 +1,8 @@ package openai -import "ds2api/internal/util" +import ( + "ds2api/internal/deepseek" +) func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) { messages := normalizeOpenAIMessagesForPrompt(messagesRaw) @@ -8,5 +10,5 @@ func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 { messages, toolNames = injectToolPrompt(messages, tools) } - return util.MessagesPrepare(messages), toolNames + return deepseek.MessagesPrepare(messages), toolNames } diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index e04fb5f..e767b2b 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -6,13 +6,16 @@ import ( "io" "net/http" "strings" + "time" "github.com/go-chi/chi/v5" "github.com/google/uuid" "ds2api/internal/auth" + "ds2api/internal/deepseek" + openaifmt "ds2api/internal/format/openai" "ds2api/internal/sse" - "ds2api/internal/util" + streamengine "ds2api/internal/stream" ) func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) { @@ -108,7 +111,7 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res return } result := sse.CollectStream(resp, 
thinkingEnabled, true) - responseObj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames) + responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames) h.getResponseStore().put(owner, responseID, responseObj) writeJSON(w, http.StatusOK, responseObj) } @@ -127,114 +130,45 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, rc := http.NewResponseController(w) canFlush := rc.Flush() == nil - sendEvent := func(event string, payload map[string]any) { - b, _ := json.Marshal(payload) - _, _ = w.Write([]byte("event: " + event + "\n")) - _, _ = w.Write([]byte("data: ")) - _, _ = w.Write(b) - _, _ = w.Write([]byte("\n\n")) - if canFlush { - _ = rc.Flush() - } - } - - sendEvent("response.created", util.BuildOpenAIResponsesCreatedPayload(responseID, model)) - initialType := "text" if thinkingEnabled { initialType = "thinking" } - parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() - var sieve toolStreamSieveState - thinking := strings.Builder{} - text := strings.Builder{} - toolCallsEmitted := false - streamToolCallIDs := map[int]string{} - finalize := func() { - finalThinking := thinking.String() - finalText := text.String() - if bufferToolContent { - for _, evt := range flushToolSieve(&sieve, toolNames) { - if evt.Content != "" { - sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content)) - } - if len(evt.ToolCalls) > 0 { - toolCallsEmitted = true - sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) - } - } - } - obj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText, 
toolNames) - if toolCallsEmitted { - obj["status"] = "completed" - } - h.getResponseStore().put(owner, responseID, obj) - sendEvent("response.completed", util.BuildOpenAIResponsesCompletedPayload(obj)) - _, _ = w.Write([]byte("data: [DONE]\n\n")) - if canFlush { - _ = rc.Flush() - } - } + streamRuntime := newResponsesStreamRuntime( + w, + rc, + canFlush, + responseID, + model, + finalPrompt, + thinkingEnabled, + searchEnabled, + toolNames, + bufferToolContent, + emitEarlyToolDeltas, + func(obj map[string]any) { + h.getResponseStore().put(owner, responseID, obj) + }, + ) + streamRuntime.sendCreated() - for { - select { - case <-r.Context().Done(): - return - case parsed, ok := <-parsedLines: - if !ok { - _ = <-done - finalize() - return - } - if !parsed.Parsed { - continue - } - if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { - finalize() - return - } - for _, p := range parsed.Parts { - if p.Text == "" { - continue - } - if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) { - continue - } - if p.Type == "thinking" { - if !thinkingEnabled { - continue - } - thinking.WriteString(p.Text) - sendEvent("response.reasoning.delta", util.BuildOpenAIResponsesReasoningDeltaPayload(responseID, p.Text)) - continue - } - text.WriteString(p.Text) - if !bufferToolContent { - sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, p.Text)) - continue - } - for _, evt := range processToolSieveChunk(&sieve, p.Text, toolNames) { - if evt.Content != "" { - sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content)) - } - if len(evt.ToolCallDeltas) > 0 { - if !emitEarlyToolDeltas { - continue - } - toolCallsEmitted = true - sendEvent("response.output_tool_call.delta", util.BuildOpenAIResponsesToolCallDeltaPayload(responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs))) - } - if len(evt.ToolCalls) > 0 { - toolCallsEmitted = true - 
sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) - } - } - } - } - } + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second, + IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, + MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, + }, streamengine.ConsumeHooks{ + OnParsed: streamRuntime.onParsed, + OnFinalize: func(_ streamengine.StopReason, _ error) { + streamRuntime.finalize() + }, + }) } func responsesMessagesFromRequest(req map[string]any) []any { diff --git a/internal/adapter/openai/responses_stream_runtime.go b/internal/adapter/openai/responses_stream_runtime.go new file mode 100644 index 0000000..f7e8b20 --- /dev/null +++ b/internal/adapter/openai/responses_stream_runtime.go @@ -0,0 +1,168 @@ +package openai + +import ( + "encoding/json" + "net/http" + "strings" + + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" + "ds2api/internal/util" +) + +type responsesStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + responseID string + model string + finalPrompt string + toolNames []string + + thinkingEnabled bool + searchEnabled bool + + bufferToolContent bool + emitEarlyToolDeltas bool + toolCallsEmitted bool + + sieve toolStreamSieveState + thinking strings.Builder + text strings.Builder + streamToolCallIDs map[int]string + + persistResponse func(obj map[string]any) +} + +func newResponsesStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + responseID string, + model string, + finalPrompt string, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, + bufferToolContent bool, + emitEarlyToolDeltas bool, + 
persistResponse func(obj map[string]any), +) *responsesStreamRuntime { + return &responsesStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + responseID: responseID, + model: model, + finalPrompt: finalPrompt, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + toolNames: toolNames, + bufferToolContent: bufferToolContent, + emitEarlyToolDeltas: emitEarlyToolDeltas, + streamToolCallIDs: map[int]string{}, + persistResponse: persistResponse, + } +} + +func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) { + b, _ := json.Marshal(payload) + _, _ = s.w.Write([]byte("event: " + event + "\n")) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *responsesStreamRuntime) sendCreated() { + s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model)) +} + +func (s *responsesStreamRuntime) sendDone() { + _, _ = s.w.Write([]byte("data: [DONE]\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *responsesStreamRuntime) finalize() { + finalThinking := s.thinking.String() + finalText := s.text.String() + if s.bufferToolContent { + for _, evt := range flushToolSieve(&s.sieve, s.toolNames) { + if evt.Content != "" { + s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content)) + } + if len(evt.ToolCalls) > 0 { + s.toolCallsEmitted = true + s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) + } + } + } + + obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames) + if s.toolCallsEmitted { + obj["status"] = "completed" + } + if s.persistResponse != nil { + s.persistResponse(obj) + } + s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj)) + s.sendDone() +} + +func 
(s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { + return streamengine.ParsedDecision{Stop: true} + } + + contentSeen := false + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + contentSeen = true + if p.Type == "thinking" { + if !s.thinkingEnabled { + continue + } + s.thinking.WriteString(p.Text) + s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text)) + continue + } + + s.text.WriteString(p.Text) + if !s.bufferToolContent { + s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text)) + continue + } + for _, evt := range processToolSieveChunk(&s.sieve, p.Text, s.toolNames) { + if evt.Content != "" { + s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content)) + } + if len(evt.ToolCallDeltas) > 0 { + if !s.emitEarlyToolDeltas { + continue + } + s.toolCallsEmitted = true + s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs))) + } + if len(evt.ToolCalls) > 0 { + s.toolCallsEmitted = true + s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) + } + } + } + + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} diff --git a/internal/adapter/openai/standard_request.go b/internal/adapter/openai/standard_request.go index 52344d4..5883d03 100644 --- a/internal/adapter/openai/standard_request.go +++ b/internal/adapter/openai/standard_request.go @@ -8,7 +8,7 @@ import ( "ds2api/internal/util" ) -func 
normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) { +func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) { model, _ := req["model"].(string) messagesRaw, _ := req["messages"].([]any) if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { @@ -41,7 +41,7 @@ func normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.S }, nil } -func normalizeOpenAIResponsesRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) { +func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) { model, _ := req["model"].(string) model = strings.TrimSpace(model) if model == "" { diff --git a/internal/admin/deps.go b/internal/admin/deps.go new file mode 100644 index 0000000..e92c37b --- /dev/null +++ b/internal/admin/deps.go @@ -0,0 +1,46 @@ +package admin + +import ( + "context" + "net/http" + + "ds2api/internal/account" + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" +) + +type ConfigStore interface { + Snapshot() config.Config + Keys() []string + Accounts() []config.Account + FindAccount(identifier string) (config.Account, bool) + UpdateAccountToken(identifier, token string) error + Update(mutator func(*config.Config) error) error + ExportJSONAndBase64() (string, string, error) + IsEnvBacked() bool + SetVercelSync(hash string, ts int64) error + AdminPasswordHash() string + AdminJWTExpireHours() int + AdminJWTValidAfterUnix() int64 + RuntimeAccountMaxInflight() int + RuntimeAccountMaxQueue(defaultSize int) int + RuntimeGlobalMaxInflight(defaultSize int) int +} + +type PoolController interface { + Reset() + Status() map[string]any + ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) +} + +type DeepSeekCaller interface { + Login(ctx context.Context, acc config.Account) (string, error) + CreateSession(ctx context.Context, a 
*auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +var _ ConfigStore = (*config.Store)(nil) +var _ PoolController = (*account.Pool)(nil) +var _ DeepSeekCaller = (*deepseek.Client)(nil) diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 9d6151e..829b657 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -2,16 +2,12 @@ package admin import ( "github.com/go-chi/chi/v5" - - "ds2api/internal/account" - "ds2api/internal/config" - "ds2api/internal/deepseek" ) type Handler struct { - Store *config.Store - Pool *account.Pool - DS *deepseek.Client + Store ConfigStore + Pool PoolController + DS DeepSeekCaller } func RegisterRoutes(r chi.Router, h *Handler) { @@ -22,6 +18,11 @@ func RegisterRoutes(r chi.Router, h *Handler) { pr.Get("/vercel/config", h.getVercelConfig) pr.Get("/config", h.getConfig) pr.Post("/config", h.updateConfig) + pr.Get("/settings", h.getSettings) + pr.Put("/settings", h.updateSettings) + pr.Post("/settings/password", h.updateSettingsPassword) + pr.Post("/config/import", h.configImport) + pr.Get("/config/export", h.configExport) pr.Post("/keys", h.addKey) pr.Delete("/keys/{key}", h.deleteKey) pr.Get("/accounts", h.listAccounts) diff --git a/internal/admin/handler_auth.go b/internal/admin/handler_auth.go index 0d3ec1f..9b96b2f 100644 --- a/internal/admin/handler_auth.go +++ b/internal/admin/handler_auth.go @@ -12,7 +12,7 @@ import ( func (h *Handler) requireAdmin(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := authn.VerifyAdminRequest(r); err != nil { + if err := authn.VerifyAdminRequestWithStore(r, h.Store); err != nil { writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()}) return } @@ -25,18 
+25,18 @@ func (h *Handler) login(w http.ResponseWriter, r *http.Request) { _ = json.NewDecoder(r.Body).Decode(&req) adminKey, _ := req["admin_key"].(string) expireHours := intFrom(req["expire_hours"]) - if expireHours <= 0 { - expireHours = 24 - } - if adminKey != authn.AdminKey() { + if !authn.VerifyAdminCredential(adminKey, h.Store) { writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "Invalid admin key"}) return } - token, err := authn.CreateJWT(expireHours) + token, err := authn.CreateJWTWithStore(expireHours, h.Store) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) return } + if expireHours <= 0 { + expireHours = h.Store.AdminJWTExpireHours() + } writeJSON(w, http.StatusOK, map[string]any{"success": true, "token": token, "expires_in": expireHours * 3600}) } @@ -47,7 +47,7 @@ func (h *Handler) verify(w http.ResponseWriter, r *http.Request) { return } token := strings.TrimSpace(header[7:]) - payload, err := authn.VerifyJWT(token) + payload, err := authn.VerifyJWTWithStore(token, h.Store) if err != nil { writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()}) return diff --git a/internal/admin/handler_config.go b/internal/admin/handler_config.go index 2b672c3..dfbd005 100644 --- a/internal/admin/handler_config.go +++ b/internal/admin/handler_config.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "net/http" - "sort" "strings" "github.com/go-chi/chi/v5" @@ -204,38 +203,191 @@ func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) { } func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) { + h.configExport(w, nil) +} + +func (h *Handler) configExport(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() jsonStr, b64, err := h.Store.ExportJSONAndBase64() if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) return } - writeJSON(w, http.StatusOK, map[string]any{"json": jsonStr, "base64": 
b64}) + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "config": snap, + "json": jsonStr, + "base64": b64, + }) +} + +func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + + mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode"))) + if mode == "" { + mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode"))) + } + if mode == "" { + mode = "merge" + } + if mode != "merge" && mode != "replace" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"}) + return + } + + payload := req + if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 { + payload = raw + } + rawJSON, err := json.Marshal(payload) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"}) + return + } + var incoming config.Config + if err := json.Unmarshal(rawJSON, &incoming); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + + importedKeys, importedAccounts := 0, 0 + err = h.Store.Update(func(c *config.Config) error { + next := c.Clone() + if mode == "replace" { + next = incoming.Clone() + next.VercelSyncHash = c.VercelSyncHash + next.VercelSyncTime = c.VercelSyncTime + importedKeys = len(next.Keys) + importedAccounts = len(next.Accounts) + } else { + existingKeys := map[string]struct{}{} + for _, k := range next.Keys { + existingKeys[k] = struct{}{} + } + for _, k := range incoming.Keys { + key := strings.TrimSpace(k) + if key == "" { + continue + } + if _, ok := existingKeys[key]; ok { + continue + } + existingKeys[key] = struct{}{} + next.Keys = append(next.Keys, key) + importedKeys++ + } + + existingAccounts := map[string]struct{}{} + for _, acc := range next.Accounts { + 
existingAccounts[acc.Identifier()] = struct{}{} + } + for _, acc := range incoming.Accounts { + id := acc.Identifier() + if id == "" { + continue + } + if _, ok := existingAccounts[id]; ok { + continue + } + existingAccounts[id] = struct{}{} + next.Accounts = append(next.Accounts, acc) + importedAccounts++ + } + + if len(incoming.ClaudeMapping) > 0 { + if next.ClaudeMapping == nil { + next.ClaudeMapping = map[string]string{} + } + for k, v := range incoming.ClaudeMapping { + next.ClaudeMapping[k] = v + } + } + if len(incoming.ClaudeModelMap) > 0 { + if next.ClaudeModelMap == nil { + next.ClaudeModelMap = map[string]string{} + } + for k, v := range incoming.ClaudeModelMap { + next.ClaudeModelMap[k] = v + } + } + + if len(incoming.ModelAliases) > 0 { + if next.ModelAliases == nil { + next.ModelAliases = map[string]string{} + } + for k, v := range incoming.ModelAliases { + next.ModelAliases[k] = v + } + } + if strings.TrimSpace(incoming.Toolcall.Mode) != "" { + next.Toolcall.Mode = incoming.Toolcall.Mode + } + if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" { + next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence + } + if incoming.Responses.StoreTTLSeconds > 0 { + next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds + } + if strings.TrimSpace(incoming.Embeddings.Provider) != "" { + next.Embeddings.Provider = incoming.Embeddings.Provider + } + if strings.TrimSpace(incoming.Admin.PasswordHash) != "" { + next.Admin.PasswordHash = incoming.Admin.PasswordHash + } + if incoming.Admin.JWTExpireHours > 0 { + next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours + } + if incoming.Admin.JWTValidAfterUnix > 0 { + next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix + } + if incoming.Runtime.AccountMaxInflight > 0 { + next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight + } + if incoming.Runtime.AccountMaxQueue > 0 { + next.Runtime.AccountMaxQueue = incoming.Runtime.AccountMaxQueue + } + if 
incoming.Runtime.GlobalMaxInflight > 0 { + next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight + } + } + + normalizeSettingsConfig(&next) + if err := validateSettingsConfig(next); err != nil { + return newRequestError(err.Error()) + } + + *c = next + return nil + }) + if err != nil { + if detail, ok := requestErrorDetail(err); ok { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail}) + return + } + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "mode": mode, + "imported_keys": importedKeys, + "imported_accounts": importedAccounts, + "message": "config imported", + }) } func (h *Handler) computeSyncHash() string { - snap := h.Store.Snapshot() - syncable := map[string]any{"keys": snap.Keys, "accounts": []map[string]any{}} - accounts := make([]map[string]any, 0, len(snap.Accounts)) - for _, a := range snap.Accounts { - m := map[string]any{} - if a.Email != "" { - m["email"] = a.Email - } - if a.Mobile != "" { - m["mobile"] = a.Mobile - } - if a.Password != "" { - m["password"] = a.Password - } - accounts = append(accounts, m) - } - sort.Slice(accounts, func(i, j int) bool { - ai := fmt.Sprintf("%v%v", accounts[i]["email"], accounts[i]["mobile"]) - aj := fmt.Sprintf("%v%v", accounts[j]["email"], accounts[j]["mobile"]) - return ai < aj - }) - syncable["accounts"] = accounts - b, _ := json.Marshal(syncable) + snap := h.Store.Snapshot().Clone() + snap.VercelSyncHash = "" + snap.VercelSyncTime = 0 + b, _ := json.Marshal(snap) sum := md5.Sum(b) return fmt.Sprintf("%x", sum) } diff --git a/internal/admin/handler_settings.go b/internal/admin/handler_settings.go new file mode 100644 index 0000000..06c234c --- /dev/null +++ b/internal/admin/handler_settings.go @@ -0,0 +1,321 @@ +package admin + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + authn "ds2api/internal/auth" + 
"ds2api/internal/config" +) + +func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() + recommended := defaultRuntimeRecommended(len(snap.Accounts), h.Store.RuntimeAccountMaxInflight()) + needsSync := config.IsVercel() && snap.VercelSyncHash != "" && snap.VercelSyncHash != h.computeSyncHash() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "admin": map[string]any{ + "has_password_hash": strings.TrimSpace(snap.Admin.PasswordHash) != "", + "jwt_expire_hours": h.Store.AdminJWTExpireHours(), + "jwt_valid_after_unix": snap.Admin.JWTValidAfterUnix, + "default_password_warning": authn.UsingDefaultAdminKey(h.Store), + }, + "runtime": map[string]any{ + "account_max_inflight": h.Store.RuntimeAccountMaxInflight(), + "account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended), + "global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended), + }, + "toolcall": snap.Toolcall, + "responses": snap.Responses, + "embeddings": snap.Embeddings, + "claude_mapping": settingsClaudeMapping(snap), + "model_aliases": snap.ModelAliases, + "env_backed": h.Store.IsEnvBacked(), + "needs_vercel_sync": needsSync, + }) +} + +func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + + adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + if runtimeCfg != nil { + if err := validateMergedRuntimeSettings(h.Store.Snapshot().Runtime, runtimeCfg); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + } + + if err := h.Store.Update(func(c *config.Config) error { + if adminCfg != nil { + if 
adminCfg.JWTExpireHours > 0 { + c.Admin.JWTExpireHours = adminCfg.JWTExpireHours + } + } + if runtimeCfg != nil { + if runtimeCfg.AccountMaxInflight > 0 { + c.Runtime.AccountMaxInflight = runtimeCfg.AccountMaxInflight + } + if runtimeCfg.AccountMaxQueue > 0 { + c.Runtime.AccountMaxQueue = runtimeCfg.AccountMaxQueue + } + if runtimeCfg.GlobalMaxInflight > 0 { + c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight + } + } + if toolcallCfg != nil { + if strings.TrimSpace(toolcallCfg.Mode) != "" { + c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode) + } + if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" { + c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) + } + } + if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 { + c.Responses.StoreTTLSeconds = responsesCfg.StoreTTLSeconds + } + if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" { + c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider) + } + if claudeMap != nil { + c.ClaudeMapping = claudeMap + c.ClaudeModelMap = nil + } + if aliasMap != nil { + c.ModelAliases = aliasMap + } + return nil + }); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + h.applyRuntimeSettings() + needsSync := config.IsVercel() || h.Store.IsEnvBacked() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "message": "settings updated and hot reloaded", + "env_backed": h.Store.IsEnvBacked(), + "needs_vercel_sync": needsSync, + "manual_sync_message": "配置已保存。Vercel 部署请在 Vercel Sync 页面手动同步。", + }) +} + +func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *config.RuntimeConfig) error { + merged := current + if incoming != nil { + if incoming.AccountMaxInflight > 0 { + merged.AccountMaxInflight = incoming.AccountMaxInflight + } + if incoming.AccountMaxQueue > 0 { + merged.AccountMaxQueue = incoming.AccountMaxQueue + } + if incoming.GlobalMaxInflight > 0 
{ + merged.GlobalMaxInflight = incoming.GlobalMaxInflight + } + } + return validateRuntimeSettings(merged) +} + +func (h *Handler) updateSettingsPassword(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + newPassword := strings.TrimSpace(fieldString(req, "new_password")) + if newPassword == "" { + newPassword = strings.TrimSpace(fieldString(req, "password")) + } + if len(newPassword) < 4 { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "new password must be at least 4 characters"}) + return + } + + now := time.Now().Unix() + hash := authn.HashAdminPassword(newPassword) + if err := h.Store.Update(func(c *config.Config) error { + c.Admin.PasswordHash = hash + c.Admin.JWTValidAfterUnix = now + return nil + }); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "message": "password updated", + "force_relogin": true, + "jwt_valid_after_unix": now, + }) +} + +func (h *Handler) applyRuntimeSettings() { + if h == nil || h.Store == nil || h.Pool == nil { + return + } + accountCount := len(h.Store.Accounts()) + maxPer := h.Store.RuntimeAccountMaxInflight() + recommended := defaultRuntimeRecommended(accountCount, maxPer) + maxQueue := h.Store.RuntimeAccountMaxQueue(recommended) + global := h.Store.RuntimeGlobalMaxInflight(recommended) + h.Pool.ApplyRuntimeLimits(maxPer, maxQueue, global) +} + +func defaultRuntimeRecommended(accountCount, maxPer int) int { + if maxPer <= 0 { + maxPer = 1 + } + if accountCount <= 0 { + return maxPer + } + return accountCount * maxPer +} + +func settingsClaudeMapping(c config.Config) map[string]string { + if len(c.ClaudeMapping) > 0 { + return c.ClaudeMapping + } + if len(c.ClaudeModelMap) > 0 { + return c.ClaudeModelMap + } + 
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"} +} + +func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) { + var ( + adminCfg *config.AdminConfig + runtimeCfg *config.RuntimeConfig + toolcallCfg *config.ToolcallConfig + respCfg *config.ResponsesConfig + embCfg *config.EmbeddingsConfig + claudeMap map[string]string + aliasMap map[string]string + ) + + if raw, ok := req["admin"].(map[string]any); ok { + cfg := &config.AdminConfig{} + if v, exists := raw["jwt_expire_hours"]; exists { + n := intFrom(v) + if n < 1 || n > 720 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720") + } + cfg.JWTExpireHours = n + } + adminCfg = cfg + } + + if raw, ok := req["runtime"].(map[string]any); ok { + cfg := &config.RuntimeConfig{} + if v, exists := raw["account_max_inflight"]; exists { + n := intFrom(v) + if n < 1 || n > 256 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256") + } + cfg.AccountMaxInflight = n + } + if v, exists := raw["account_max_queue"]; exists { + n := intFrom(v) + if n < 1 || n > 200000 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000") + } + cfg.AccountMaxQueue = n + } + if v, exists := raw["global_max_inflight"]; exists { + n := intFrom(v) + if n < 1 || n > 200000 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000") + } + cfg.GlobalMaxInflight = n + } + if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") + } + runtimeCfg = cfg + } + + if raw, ok 
:= req["toolcall"].(map[string]any); ok { + cfg := &config.ToolcallConfig{} + if v, exists := raw["mode"]; exists { + mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) + switch mode { + case "feature_match", "off": + cfg.Mode = mode + default: + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off") + } + } + if v, exists := raw["early_emit_confidence"]; exists { + level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) + switch level { + case "high", "low", "off": + cfg.EarlyEmitConfidence = level + default: + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off") + } + } + toolcallCfg = cfg + } + + if raw, ok := req["responses"].(map[string]any); ok { + cfg := &config.ResponsesConfig{} + if v, exists := raw["store_ttl_seconds"]; exists { + n := intFrom(v) + if n < 30 || n > 86400 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400") + } + cfg.StoreTTLSeconds = n + } + respCfg = cfg + } + + if raw, ok := req["embeddings"].(map[string]any); ok { + cfg := &config.EmbeddingsConfig{} + if v, exists := raw["provider"]; exists { + p := strings.TrimSpace(fmt.Sprintf("%v", v)) + if p == "" { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty") + } + cfg.Provider = p + } + embCfg = cfg + } + + if raw, ok := req["claude_mapping"].(map[string]any); ok { + claudeMap = map[string]string{} + for k, v := range raw { + key := strings.TrimSpace(k) + val := strings.TrimSpace(fmt.Sprintf("%v", v)) + if key == "" || val == "" { + continue + } + claudeMap[key] = val + } + } + + if raw, ok := req["model_aliases"].(map[string]any); ok { + aliasMap = map[string]string{} + for k, v := range raw { + key := strings.TrimSpace(k) + val := strings.TrimSpace(fmt.Sprintf("%v", v)) + if key == "" || val == "" { + continue + } + aliasMap[key] = val + } + } + + return 
adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil +} diff --git a/internal/admin/handler_settings_test.go b/internal/admin/handler_settings_test.go new file mode 100644 index 0000000..3eb5114 --- /dev/null +++ b/internal/admin/handler_settings_test.go @@ -0,0 +1,267 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + authn "ds2api/internal/auth" +) + +func TestGetSettingsDefaultPasswordWarning(t *testing.T) { + t.Setenv("DS2API_ADMIN_KEY", "") + h := newAdminTestHandler(t, `{"keys":["k1"]}`) + req := httptest.NewRequest(http.MethodGet, "/admin/settings", nil) + rec := httptest.NewRecorder() + h.getSettings(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + var body map[string]any + _ = json.Unmarshal(rec.Body.Bytes(), &body) + admin, _ := body["admin"].(map[string]any) + warn, _ := admin["default_password_warning"].(bool) + if !warn { + t.Fatalf("expected default password warning true, body=%v", body) + } +} + +func TestUpdateSettingsValidation(t *testing.T) { + h := newAdminTestHandler(t, `{"keys":["k1"]}`) + payload := map[string]any{ + "runtime": map[string]any{ + "account_max_inflight": 0, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } +} + +func TestUpdateSettingsValidationWithMergedRuntimeSnapshot(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "runtime":{ + "account_max_inflight":8, + "global_max_inflight":8 + } + }`) + payload := map[string]any{ + "runtime": map[string]any{ + "account_max_inflight": 16, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := 
httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.global_max_inflight")) { + t.Fatalf("expected merged runtime validation detail, got %s", rec.Body.String()) + } +} + +func TestUpdateSettingsWithoutRuntimeSkipsMergedRuntimeValidation(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "runtime":{ + "account_max_inflight":8, + "global_max_inflight":4 + } + }`) + payload := map[string]any{ + "responses": map[string]any{ + "store_ttl_seconds": 600, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if got := h.Store.Snapshot().Responses.StoreTTLSeconds; got != 600 { + t.Fatalf("store_ttl_seconds=%d want=600", got) + } +} + +func TestUpdateSettingsHotReloadRuntime(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "accounts":[{"email":"a@test.com","token":"t1"},{"email":"b@test.com","token":"t2"}] + }`) + + payload := map[string]any{ + "runtime": map[string]any{ + "account_max_inflight": 3, + "account_max_queue": 20, + "global_max_inflight": 5, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + status := h.Pool.Status() + if got := intFrom(status["max_inflight_per_account"]); got != 3 { + t.Fatalf("max_inflight_per_account=%d want=3", got) + } + if got := intFrom(status["max_queue_size"]); got != 20 { + t.Fatalf("max_queue_size=%d want=20", got) + } + if got := 
intFrom(status["global_max_inflight"]); got != 5 { + t.Fatalf("global_max_inflight=%d want=5", got) + } +} + +func TestUpdateSettingsPasswordInvalidatesOldJWT(t *testing.T) { + hash := authn.HashAdminPassword("old-password") + h := newAdminTestHandler(t, `{"admin":{"password_hash":"`+hash+`"}}`) + + token, err := authn.CreateJWTWithStore(1, h.Store) + if err != nil { + t.Fatalf("create jwt failed: %v", err) + } + if _, err := authn.VerifyJWTWithStore(token, h.Store); err != nil { + t.Fatalf("verify before update failed: %v", err) + } + + body := map[string]any{"new_password": "new-password"} + b, _ := json.Marshal(body) + req := httptest.NewRequest(http.MethodPost, "/admin/settings/password", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettingsPassword(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + + if _, err := authn.VerifyJWTWithStore(token, h.Store); err == nil { + t.Fatal("expected old token to be invalid after password update") + } + if !authn.VerifyAdminCredential("new-password", h.Store) { + t.Fatal("expected new password credential to be accepted") + } +} + +func TestConfigImportMergeAndReplace(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "accounts":[{"email":"a@test.com","password":"p1"}] + }`) + + merge := map[string]any{ + "mode": "merge", + "config": map[string]any{ + "keys": []any{"k1", "k2"}, + "accounts": []any{ + map[string]any{"email": "a@test.com", "password": "p1"}, + map[string]any{"email": "b@test.com", "password": "p2"}, + }, + }, + } + mergeBytes, _ := json.Marshal(merge) + mergeReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(mergeBytes)) + mergeRec := httptest.NewRecorder() + h.configImport(mergeRec, mergeReq) + if mergeRec.Code != http.StatusOK { + t.Fatalf("merge status=%d body=%s", mergeRec.Code, mergeRec.Body.String()) + } + if got := len(h.Store.Keys()); got != 2 { + t.Fatalf("keys after 
merge=%d want=2", got) + } + if got := len(h.Store.Accounts()); got != 2 { + t.Fatalf("accounts after merge=%d want=2", got) + } + + replace := map[string]any{ + "mode": "replace", + "config": map[string]any{ + "keys": []any{"k9"}, + }, + } + replaceBytes, _ := json.Marshal(replace) + replaceReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=replace", bytes.NewReader(replaceBytes)) + replaceRec := httptest.NewRecorder() + h.configImport(replaceRec, replaceReq) + if replaceRec.Code != http.StatusOK { + t.Fatalf("replace status=%d body=%s", replaceRec.Code, replaceRec.Body.String()) + } + keys := h.Store.Keys() + if len(keys) != 1 || keys[0] != "k9" { + t.Fatalf("unexpected keys after replace: %#v", keys) + } + if got := len(h.Store.Accounts()); got != 0 { + t.Fatalf("accounts after replace=%d want=0", got) + } +} + +func TestConfigImportRejectsInvalidRuntimeBounds(t *testing.T) { + h := newAdminTestHandler(t, `{"keys":["k1"]}`) + payload := map[string]any{ + "mode": "replace", + "config": map[string]any{ + "keys": []any{"k2"}, + "runtime": map[string]any{ + "account_max_inflight": 300, + }, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=replace", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.configImport(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.account_max_inflight")) { + t.Fatalf("expected runtime bound detail, got %s", rec.Body.String()) + } + keys := h.Store.Keys() + if len(keys) != 1 || keys[0] != "k1" { + t.Fatalf("store should remain unchanged, keys=%v", keys) + } +} + +func TestConfigImportRejectsMergedRuntimeConflict(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "runtime":{ + "account_max_inflight":8, + "global_max_inflight":8 + } + }`) + payload := map[string]any{ + "mode": "merge", + "config": 
map[string]any{ + "runtime": map[string]any{ + "account_max_inflight": 16, + }, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.configImport(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.global_max_inflight")) { + t.Fatalf("expected merged runtime validation detail, got %s", rec.Body.String()) + } + snap := h.Store.Snapshot() + if snap.Runtime.AccountMaxInflight != 8 || snap.Runtime.GlobalMaxInflight != 8 { + t.Fatalf("runtime should remain unchanged, runtime=%+v", snap.Runtime) + } +} diff --git a/internal/admin/handler_vercel.go b/internal/admin/handler_vercel.go index 189d8cc..2c6356c 100644 --- a/internal/admin/handler_vercel.go +++ b/internal/admin/handler_vercel.go @@ -3,8 +3,8 @@ package admin import ( "bytes" "context" - "encoding/base64" "encoding/json" + "fmt" "io" "net/http" "net/url" @@ -19,6 +19,62 @@ func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) return } + opts, err := parseVercelSyncOptions(req) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + validated, failed := h.validateAccountsForVercelSync(r.Context(), opts.AutoValidate) + _, cfgB64, err := h.Store.ExportJSONAndBase64() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + client := &http.Client{Timeout: 30 * time.Second} + params := buildVercelParams(opts.TeamID) + headers := map[string]string{"Authorization": "Bearer " + opts.VercelToken} + + envResp, status, err := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+opts.ProjectID+"/env", params, headers, nil) + if err 
!= nil || status != http.StatusOK { + writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "获取环境变量失败"}) + return + } + envs, _ := envResp["envs"].([]any) + status, err = upsertVercelEnv(r.Context(), client, opts.ProjectID, params, headers, envs, "DS2API_CONFIG_JSON", cfgB64) + if err != nil || (status != http.StatusOK && status != http.StatusCreated) { + writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "更新环境变量失败"}) + return + } + savedCreds := h.saveVercelProjectCredentials(r.Context(), client, opts, params, headers, envs) + manual, deployURL := triggerVercelDeployment(r.Context(), client, opts.ProjectID, params, headers) + _ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix()) + result := map[string]any{"success": true, "validated_accounts": validated} + if manual { + result["message"] = "配置已同步到 Vercel,请手动触发重新部署" + result["manual_deploy_required"] = true + } else { + result["message"] = "配置已同步,正在重新部署..." 
+ result["deployment_url"] = deployURL + } + if len(failed) > 0 { + result["failed_accounts"] = failed + } + if len(savedCreds) > 0 { + result["saved_credentials"] = savedCreds + } + writeJSON(w, http.StatusOK, result) +} + +type vercelSyncOptions struct { + VercelToken string + ProjectID string + TeamID string + AutoValidate bool + SaveCreds bool + UsePreconfig bool +} + +func parseVercelSyncOptions(req map[string]any) (vercelSyncOptions, error) { vercelToken, _ := req["vercel_token"].(string) projectID, _ := req["project_id"].(string) teamID, _ := req["team_id"].(string) @@ -40,108 +96,117 @@ func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) { if strings.TrimSpace(teamID) == "" { teamID = strings.TrimSpace(os.Getenv("VERCEL_TEAM_ID")) } + vercelToken = strings.TrimSpace(vercelToken) + projectID = strings.TrimSpace(projectID) + teamID = strings.TrimSpace(teamID) if vercelToken == "" || projectID == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 Vercel Token 和 Project ID"}) - return + return vercelSyncOptions{}, fmt.Errorf("需要 Vercel Token 和 Project ID") + } + return vercelSyncOptions{ + VercelToken: vercelToken, + ProjectID: projectID, + TeamID: teamID, + AutoValidate: autoValidate, + SaveCreds: saveCreds, + UsePreconfig: usePreconfig, + }, nil +} + +func buildVercelParams(teamID string) url.Values { + params := url.Values{} + if strings.TrimSpace(teamID) != "" { + params.Set("teamId", strings.TrimSpace(teamID)) + } + return params +} + +func (h *Handler) validateAccountsForVercelSync(ctx context.Context, enabled bool) (int, []string) { + if !enabled { + return 0, nil } validated, failed := 0, []string{} - if autoValidate { - for _, acc := range h.Store.Snapshot().Accounts { - if strings.TrimSpace(acc.Token) != "" { - continue - } - token, err := h.DS.Login(r.Context(), acc) - if err != nil { - failed = append(failed, acc.Identifier()) - } else { - validated++ - _ = h.Store.UpdateAccountToken(acc.Identifier(), token) - } 
- time.Sleep(500 * time.Millisecond) + for _, acc := range h.Store.Snapshot().Accounts { + if strings.TrimSpace(acc.Token) != "" { + continue } + token, err := h.DS.Login(ctx, acc) + if err != nil { + failed = append(failed, acc.Identifier()) + } else { + validated++ + _ = h.Store.UpdateAccountToken(acc.Identifier(), token) + } + time.Sleep(500 * time.Millisecond) } + return validated, failed +} - cfgJSON, _, err := h.Store.ExportJSONAndBase64() - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return +func upsertVercelEnv(ctx context.Context, client *http.Client, projectID string, params url.Values, headers map[string]string, envs []any, key, value string) (int, error) { + existingID := findEnvID(envs, key) + if existingID != "" { + _, status, err := vercelRequest(ctx, client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+existingID, params, headers, map[string]any{"value": value}) + return status, err } - cfgB64 := base64.StdEncoding.EncodeToString([]byte(cfgJSON)) - client := &http.Client{Timeout: 30 * time.Second} - params := url.Values{} - if teamID != "" { - params.Set("teamId", teamID) + _, status, err := vercelRequest(ctx, client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{ + "key": key, + "value": value, + "type": "encrypted", + "target": []string{"production", "preview"}, + }) + return status, err +} + +func (h *Handler) saveVercelProjectCredentials(ctx context.Context, client *http.Client, opts vercelSyncOptions, params url.Values, headers map[string]string, envs []any) []string { + if !opts.SaveCreds || opts.UsePreconfig { + return nil } - headers := map[string]string{"Authorization": "Bearer " + vercelToken} - envResp, status, err := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID+"/env", params, headers, nil) - if err != nil || status != http.StatusOK { - 
writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "获取环境变量失败"}) - return + saved := []string{} + creds := [][2]string{{"VERCEL_TOKEN", opts.VercelToken}, {"VERCEL_PROJECT_ID", opts.ProjectID}} + if opts.TeamID != "" { + creds = append(creds, [2]string{"VERCEL_TEAM_ID", opts.TeamID}) } - envs, _ := envResp["envs"].([]any) - existingEnvID := findEnvID(envs, "DS2API_CONFIG_JSON") - if existingEnvID != "" { - _, status, err = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+existingEnvID, params, headers, map[string]any{"value": cfgB64}) - } else { - _, status, err = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": "DS2API_CONFIG_JSON", "value": cfgB64, "type": "encrypted", "target": []string{"production", "preview"}}) - } - if err != nil || (status != http.StatusOK && status != http.StatusCreated) { - writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "更新环境变量失败"}) - return - } - savedCreds := []string{} - if saveCreds && !usePreconfig { - creds := [][2]string{{"VERCEL_TOKEN", vercelToken}, {"VERCEL_PROJECT_ID", projectID}} - if teamID != "" { - creds = append(creds, [2]string{"VERCEL_TEAM_ID", teamID}) - } - for _, kv := range creds { - id := findEnvID(envs, kv[0]) - if id != "" { - _, status, _ = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+id, params, headers, map[string]any{"value": kv[1]}) - } else { - _, status, _ = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": kv[0], "value": kv[1], "type": "encrypted", "target": []string{"production", "preview"}}) - } - if status == http.StatusOK || status == http.StatusCreated { - savedCreds = append(savedCreds, kv[0]) - } + for _, kv 
:= range creds { + status, _ := upsertVercelEnv(ctx, client, opts.ProjectID, params, headers, envs, kv[0], kv[1]) + if status == http.StatusOK || status == http.StatusCreated { + saved = append(saved, kv[0]) } } - projectResp, status, _ := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID, params, headers, nil) - manual := true - deployURL := "" - if status == http.StatusOK { - if link, ok := projectResp["link"].(map[string]any); ok { - if linkType, _ := link["type"].(string); linkType == "github" { - repoID := intFrom(link["repoId"]) - ref, _ := link["productionBranch"].(string) - if ref == "" { - ref = "main" - } - depResp, depStatus, _ := vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v13/deployments", params, headers, map[string]any{"name": projectID, "project": projectID, "target": "production", "gitSource": map[string]any{"type": "github", "repoId": repoID, "ref": ref}}) - if depStatus == http.StatusOK || depStatus == http.StatusCreated { - deployURL, _ = depResp["url"].(string) - manual = false - } - } - } + return saved +} + +func triggerVercelDeployment(ctx context.Context, client *http.Client, projectID string, params url.Values, headers map[string]string) (bool, string) { + projectResp, status, _ := vercelRequest(ctx, client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID, params, headers, nil) + if status != http.StatusOK { + return true, "" } - _ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix()) - result := map[string]any{"success": true, "validated_accounts": validated} - if manual { - result["message"] = "配置已同步到 Vercel,请手动触发重新部署" - result["manual_deploy_required"] = true - } else { - result["message"] = "配置已同步,正在重新部署..." 
- result["deployment_url"] = deployURL + link, ok := projectResp["link"].(map[string]any) + if !ok { + return true, "" } - if len(failed) > 0 { - result["failed_accounts"] = failed + linkType, _ := link["type"].(string) + if linkType != "github" { + return true, "" } - if len(savedCreds) > 0 { - result["saved_credentials"] = savedCreds + repoID := intFrom(link["repoId"]) + ref, _ := link["productionBranch"].(string) + if ref == "" { + ref = "main" } - writeJSON(w, http.StatusOK, result) + depResp, depStatus, _ := vercelRequest(ctx, client, http.MethodPost, "https://api.vercel.com/v13/deployments", params, headers, map[string]any{ + "name": projectID, + "project": projectID, + "target": "production", + "gitSource": map[string]any{ + "type": "github", + "repoId": repoID, + "ref": ref, + }, + }) + if depStatus != http.StatusOK && depStatus != http.StatusCreated { + return true, "" + } + deployURL, _ := depResp["url"].(string) + return false, deployURL } func (h *Handler) vercelStatus(w http.ResponseWriter, _ *http.Request) { diff --git a/internal/admin/helpers.go b/internal/admin/helpers.go index d7d1198..2e00323 100644 --- a/internal/admin/helpers.go +++ b/internal/admin/helpers.go @@ -96,7 +96,7 @@ func accountMatchesIdentifier(acc config.Account, identifier string) bool { return acc.Identifier() == id } -func findAccountByIdentifier(store *config.Store, identifier string) (config.Account, bool) { +func findAccountByIdentifier(store ConfigStore, identifier string) (config.Account, bool) { id := strings.TrimSpace(identifier) if id == "" { return config.Account{}, false diff --git a/internal/admin/request_error.go b/internal/admin/request_error.go new file mode 100644 index 0000000..5431a3d --- /dev/null +++ b/internal/admin/request_error.go @@ -0,0 +1,23 @@ +package admin + +import "errors" + +type requestError struct { + detail string +} + +func (e *requestError) Error() string { + return e.detail +} + +func newRequestError(detail string) error { + return 
&requestError{detail: detail} +} + +func requestErrorDetail(err error) (string, bool) { + var reqErr *requestError + if errors.As(err, &reqErr) { + return reqErr.detail, true + } + return "", false +} diff --git a/internal/admin/settings_validation.go b/internal/admin/settings_validation.go new file mode 100644 index 0000000..f9d4c2f --- /dev/null +++ b/internal/admin/settings_validation.go @@ -0,0 +1,64 @@ +package admin + +import ( + "fmt" + "strings" + + "ds2api/internal/config" +) + +func normalizeSettingsConfig(c *config.Config) { + if c == nil { + return + } + c.Admin.PasswordHash = strings.TrimSpace(c.Admin.PasswordHash) + c.Toolcall.Mode = strings.ToLower(strings.TrimSpace(c.Toolcall.Mode)) + c.Toolcall.EarlyEmitConfidence = strings.ToLower(strings.TrimSpace(c.Toolcall.EarlyEmitConfidence)) + c.Embeddings.Provider = strings.TrimSpace(c.Embeddings.Provider) +} + +func validateSettingsConfig(c config.Config) error { + if c.Admin.JWTExpireHours != 0 && (c.Admin.JWTExpireHours < 1 || c.Admin.JWTExpireHours > 720) { + return fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720") + } + if err := validateRuntimeSettings(c.Runtime); err != nil { + return err + } + if c.Responses.StoreTTLSeconds != 0 && (c.Responses.StoreTTLSeconds < 30 || c.Responses.StoreTTLSeconds > 86400) { + return fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400") + } + if mode := strings.TrimSpace(c.Toolcall.Mode); mode != "" { + switch mode { + case "feature_match", "off": + default: + return fmt.Errorf("toolcall.mode must be feature_match or off") + } + } + if level := strings.TrimSpace(c.Toolcall.EarlyEmitConfidence); level != "" { + switch level { + case "high", "low", "off": + default: + return fmt.Errorf("toolcall.early_emit_confidence must be high, low or off") + } + } + if c.Embeddings.Provider != "" && strings.TrimSpace(c.Embeddings.Provider) == "" { + return fmt.Errorf("embeddings.provider cannot be empty") + } + return nil +} + +func 
validateRuntimeSettings(runtime config.RuntimeConfig) error { + if runtime.AccountMaxInflight != 0 && (runtime.AccountMaxInflight < 1 || runtime.AccountMaxInflight > 256) { + return fmt.Errorf("runtime.account_max_inflight must be between 1 and 256") + } + if runtime.AccountMaxQueue != 0 && (runtime.AccountMaxQueue < 1 || runtime.AccountMaxQueue > 200000) { + return fmt.Errorf("runtime.account_max_queue must be between 1 and 200000") + } + if runtime.GlobalMaxInflight != 0 && (runtime.GlobalMaxInflight < 1 || runtime.GlobalMaxInflight > 200000) { + return fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000") + } + if runtime.AccountMaxInflight > 0 && runtime.GlobalMaxInflight > 0 && runtime.GlobalMaxInflight < runtime.AccountMaxInflight { + return fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") + } + return nil +} diff --git a/internal/auth/admin.go b/internal/auth/admin.go index 3a52f6c..8f1d276 100644 --- a/internal/auth/admin.go +++ b/internal/auth/admin.go @@ -3,7 +3,9 @@ package auth import ( "crypto/hmac" "crypto/sha256" + "crypto/subtle" "encoding/base64" + "encoding/hex" "encoding/json" "errors" "log/slog" @@ -17,7 +19,22 @@ import ( var warnOnce sync.Once +type AdminConfigReader interface { + AdminPasswordHash() string + AdminJWTExpireHours() int + AdminJWTValidAfterUnix() int64 +} + func AdminKey() string { + return effectiveAdminKey(nil) +} + +func effectiveAdminKey(store AdminConfigReader) string { + if store != nil { + if hash := strings.TrimSpace(store.AdminPasswordHash()); hash != "" { + return "" + } + } if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" { return v } @@ -27,14 +44,24 @@ func AdminKey() string { return "admin" } -func jwtSecret() string { +func jwtSecret(store AdminConfigReader) string { if v := strings.TrimSpace(os.Getenv("DS2API_JWT_SECRET")); v != "" { return v } - return AdminKey() + if store != nil { + if hash := strings.TrimSpace(store.AdminPasswordHash()); hash 
!= "" { + return hash + } + } + return effectiveAdminKey(store) } -func jwtExpireHours() int { +func jwtExpireHours(store AdminConfigReader) int { + if store != nil { + if n := store.AdminJWTExpireHours(); n > 0 { + return n + } + } if v := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); v != "" { if n, err := strconv.Atoi(v); err == nil && n > 0 { return n @@ -44,27 +71,44 @@ func jwtExpireHours() int { } func CreateJWT(expireHours int) (string, error) { + return CreateJWTWithStore(expireHours, nil) +} + +func CreateJWTWithStore(expireHours int, store AdminConfigReader) (string, error) { if expireHours <= 0 { - expireHours = jwtExpireHours() + expireHours = jwtExpireHours(store) } + issuedAt := time.Now().Unix() + // If sessions were invalidated in this same second, move iat forward by + // one second so newly minted tokens remain valid with strict cutoff checks. + if store != nil { + if validAfter := store.AdminJWTValidAfterUnix(); validAfter >= issuedAt { + issuedAt = validAfter + 1 + } + } + expireAt := time.Unix(issuedAt, 0).Add(time.Duration(expireHours) * time.Hour).Unix() header := map[string]any{"alg": "HS256", "typ": "JWT"} - payload := map[string]any{"iat": time.Now().Unix(), "exp": time.Now().Add(time.Duration(expireHours) * time.Hour).Unix(), "role": "admin"} + payload := map[string]any{"iat": issuedAt, "exp": expireAt, "role": "admin"} h, _ := json.Marshal(header) p, _ := json.Marshal(payload) headerB64 := rawB64Encode(h) payloadB64 := rawB64Encode(p) msg := headerB64 + "." + payloadB64 - sig := signHS256(msg) + sig := signHS256(msg, store) return msg + "." + rawB64Encode(sig), nil } func VerifyJWT(token string) (map[string]any, error) { + return VerifyJWTWithStore(token, nil) +} + +func VerifyJWTWithStore(token string, store AdminConfigReader) (map[string]any, error) { parts := strings.Split(token, ".") if len(parts) != 3 { return nil, errors.New("invalid token format") } msg := parts[0] + "." 
+ parts[1] - expected := signHS256(msg) + expected := signHS256(msg, store) actual, err := rawB64Decode(parts[2]) if err != nil { return nil, errors.New("invalid signature") @@ -84,10 +128,23 @@ func VerifyJWT(token string) (map[string]any, error) { if int64(exp) < time.Now().Unix() { return nil, errors.New("token expired") } + if store != nil { + validAfter := store.AdminJWTValidAfterUnix() + if validAfter > 0 { + iat, _ := payload["iat"].(float64) + if int64(iat) <= validAfter { + return nil, errors.New("token expired") + } + } + } return payload, nil } func VerifyAdminRequest(r *http.Request) error { + return VerifyAdminRequestWithStore(r, nil) +} + +func VerifyAdminRequestWithStore(r *http.Request, store AdminConfigReader) error { authHeader := strings.TrimSpace(r.Header.Get("Authorization")) if !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { return errors.New("authentication required") @@ -96,17 +153,65 @@ func VerifyAdminRequest(r *http.Request) error { if token == "" { return errors.New("authentication required") } - if token == AdminKey() { + if VerifyAdminCredential(token, store) { return nil } - if _, err := VerifyJWT(token); err == nil { + if _, err := VerifyJWTWithStore(token, store); err == nil { return nil } return errors.New("invalid credentials") } -func signHS256(msg string) []byte { - h := hmac.New(sha256.New, []byte(jwtSecret())) +func VerifyAdminCredential(candidate string, store AdminConfigReader) bool { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + return false + } + if store != nil { + hash := strings.TrimSpace(store.AdminPasswordHash()) + if hash != "" { + return verifyAdminPasswordHash(candidate, hash) + } + } + key := effectiveAdminKey(store) + if key == "" { + return false + } + return subtle.ConstantTimeCompare([]byte(candidate), []byte(key)) == 1 +} + +func UsingDefaultAdminKey(store AdminConfigReader) bool { + if store != nil && strings.TrimSpace(store.AdminPasswordHash()) != "" { + return false + } + 
return strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")) == ""
+}
+
+func HashAdminPassword(raw string) string {
+	raw = strings.TrimSpace(raw)
+	if raw == "" {
+		return ""
+	}
+	sum := sha256.Sum256([]byte(raw))
+	return "sha256:" + hex.EncodeToString(sum[:])
+}
+
+func verifyAdminPasswordHash(candidate, encoded string) bool {
+	encoded = strings.TrimSpace(encoded)
+	if encoded == "" {
+		return false
+	}
+	if strings.HasPrefix(strings.ToLower(encoded), "sha256:") {
+		want := strings.ToLower(encoded[len("sha256:"):])
+		sum := sha256.Sum256([]byte(candidate))
+		got := hex.EncodeToString(sum[:])
+		return subtle.ConstantTimeCompare([]byte(got), []byte(want)) == 1
+	}
+	return subtle.ConstantTimeCompare([]byte(candidate), []byte(encoded)) == 1
+}
+
+func signHS256(msg string, store AdminConfigReader) []byte {
+	h := hmac.New(sha256.New, []byte(jwtSecret(store)))
+	_, _ = h.Write([]byte(msg))
+	return h.Sum(nil)
+}
diff --git a/internal/auth/admin_test.go b/internal/auth/admin_test.go
index 7489074..bfbd4c3 100644
--- a/internal/auth/admin_test.go
+++ b/internal/auth/admin_test.go
@@ -3,6 +3,8 @@ package auth
 import (
 	"net/http"
 	"testing"
+
+	"ds2api/internal/config"
 )
 
 func TestJWTCreateVerify(t *testing.T) {
@@ -27,3 +29,58 @@ func TestVerifyAdminRequest(t *testing.T) {
 		t.Fatalf("expected token accepted: %v", err)
 	}
 }
+
+func TestVerifyJWTWithStoreValidAfter(t *testing.T) {
+	t.Setenv("DS2API_CONFIG_JSON", `{"admin":{"password_hash":"`+HashAdminPassword("oldpass")+`"}}`)
+	store := config.LoadStore()
+	token, err := CreateJWTWithStore(1, store)
+	if err != nil {
+		t.Fatalf("create jwt failed: %v", err)
+	}
+	if _, err := VerifyJWTWithStore(token, store); err != nil {
+		t.Fatalf("verify before invalidation failed: %v", err)
+	}
+	if err := store.Update(func(c *config.Config) error {
+		c.Admin.JWTValidAfterUnix = 1<<62 - 1
+		return nil
+	}); err != nil {
+		t.Fatalf("set valid-after failed: %v", err)
+	}
+	if _, err := VerifyJWTWithStore(token, store); err == nil {
+		
t.Fatal("expected token invalid after valid-after update") + } +} + +func TestVerifyJWTWithStoreSameSecondInvalidationAndRelogin(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"admin":{"password_hash":"`+HashAdminPassword("oldpass")+`"}}`) + store := config.LoadStore() + + oldToken, err := CreateJWTWithStore(1, store) + if err != nil { + t.Fatalf("create old jwt failed: %v", err) + } + oldPayload, err := VerifyJWTWithStore(oldToken, store) + if err != nil { + t.Fatalf("verify old jwt before invalidation failed: %v", err) + } + oldIAT, _ := oldPayload["iat"].(float64) + + if err := store.Update(func(c *config.Config) error { + c.Admin.JWTValidAfterUnix = int64(oldIAT) + return nil + }); err != nil { + t.Fatalf("set valid-after failed: %v", err) + } + + if _, err := VerifyJWTWithStore(oldToken, store); err == nil { + t.Fatal("expected old token invalid when iat == valid-after") + } + + newToken, err := CreateJWTWithStore(1, store) + if err != nil { + t.Fatalf("create new jwt failed: %v", err) + } + if _, err := VerifyJWTWithStore(newToken, store); err != nil { + t.Fatalf("expected new token valid after invalidation cutoff: %v", err) + } +} diff --git a/internal/claudeconv/convert.go b/internal/claudeconv/convert.go new file mode 100644 index 0000000..1ce1f01 --- /dev/null +++ b/internal/claudeconv/convert.go @@ -0,0 +1,48 @@ +package claudeconv + +import "strings" + +type ClaudeMappingProvider interface { + ClaudeMapping() map[string]string +} + +func ConvertClaudeToDeepSeek(claudeReq map[string]any, mappingProvider ClaudeMappingProvider, defaultClaudeModel string) map[string]any { + messages, _ := claudeReq["messages"].([]any) + model, _ := claudeReq["model"].(string) + if model == "" { + model = defaultClaudeModel + } + + mapping := map[string]string{} + if mappingProvider != nil { + mapping = mappingProvider.ClaudeMapping() + } + dsModel := mapping["fast"] + if dsModel == "" { + dsModel = "deepseek-chat" + } + + modelLower := strings.ToLower(model) + if 
strings.Contains(modelLower, "opus") || strings.Contains(modelLower, "reasoner") || strings.Contains(modelLower, "slow") { + if slow := mapping["slow"]; slow != "" { + dsModel = slow + } + } + + convertedMessages := make([]any, 0, len(messages)+1) + if system, ok := claudeReq["system"].(string); ok && system != "" { + convertedMessages = append(convertedMessages, map[string]any{"role": "system", "content": system}) + } + convertedMessages = append(convertedMessages, messages...) + + out := map[string]any{"model": dsModel, "messages": convertedMessages} + for _, k := range []string{"temperature", "top_p", "stream"} { + if v, ok := claudeReq[k]; ok { + out[k] = v + } + } + if stopSeq, ok := claudeReq["stop_sequences"]; ok { + out["stop"] = stopSeq + } + return out +} diff --git a/internal/compat/go_compat_test.go b/internal/compat/go_compat_test.go new file mode 100644 index 0000000..024e7ba --- /dev/null +++ b/internal/compat/go_compat_test.go @@ -0,0 +1,142 @@ +package compat + +import ( + "encoding/json" + "os" + "path/filepath" + "reflect" + "testing" + + "ds2api/internal/sse" + "ds2api/internal/util" +) + +func TestGoCompatSSEFixtures(t *testing.T) { + files, err := filepath.Glob(compatPath("fixtures", "sse_chunks", "*.json")) + if err != nil { + t.Fatalf("glob fixtures failed: %v", err) + } + if len(files) == 0 { + t.Fatal("no sse fixtures found") + } + for _, fixturePath := range files { + name := trimExt(filepath.Base(fixturePath)) + expectedPath := compatPath("expected", "sse_"+name+".json") + + var fixture struct { + Chunk map[string]any `json:"chunk"` + ThinkingEnable bool `json:"thinking_enabled"` + CurrentType string `json:"current_type"` + } + mustLoadJSON(t, fixturePath, &fixture) + + var expected struct { + Parts []map[string]any `json:"parts"` + Finished bool `json:"finished"` + NewType string `json:"new_type"` + } + mustLoadJSON(t, expectedPath, &expected) + + parts, finished, newType := sse.ParseSSEChunkForContent(fixture.Chunk, 
fixture.ThinkingEnable, fixture.CurrentType) + gotParts := make([]map[string]any, 0, len(parts)) + for _, p := range parts { + gotParts = append(gotParts, map[string]any{ + "text": p.Text, + "type": p.Type, + }) + } + if !reflect.DeepEqual(gotParts, expected.Parts) || finished != expected.Finished || newType != expected.NewType { + t.Fatalf("fixture %s mismatch:\n got parts=%#v finished=%v newType=%q\nwant parts=%#v finished=%v newType=%q", + name, gotParts, finished, newType, expected.Parts, expected.Finished, expected.NewType) + } + } +} + +func TestGoCompatToolcallFixtures(t *testing.T) { + files, err := filepath.Glob(compatPath("fixtures", "toolcalls", "*.json")) + if err != nil { + t.Fatalf("glob toolcall fixtures failed: %v", err) + } + if len(files) == 0 { + t.Fatal("no toolcall fixtures found") + } + for _, fixturePath := range files { + name := trimExt(filepath.Base(fixturePath)) + expectedPath := compatPath("expected", "toolcalls_"+name+".json") + + var fixture struct { + Text string `json:"text"` + ToolNames []string `json:"tool_names"` + } + mustLoadJSON(t, fixturePath, &fixture) + + var expected struct { + Calls []util.ParsedToolCall `json:"calls"` + } + mustLoadJSON(t, expectedPath, &expected) + + got := util.ParseToolCalls(fixture.Text, fixture.ToolNames) + if len(got) == 0 && len(expected.Calls) == 0 { + continue + } + if !reflect.DeepEqual(got, expected.Calls) { + t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected.Calls) + } + } +} + +func TestGoCompatTokenFixtures(t *testing.T) { + var fixture struct { + Cases []struct { + Name string `json:"name"` + Text string `json:"text"` + } `json:"cases"` + } + mustLoadJSON(t, compatPath("fixtures", "token_cases.json"), &fixture) + + var expected struct { + Cases []struct { + Name string `json:"name"` + Tokens int `json:"tokens"` + } `json:"cases"` + } + mustLoadJSON(t, compatPath("expected", "token_cases.json"), &expected) + + expectByName := map[string]int{} + for _, c := 
range expected.Cases { + expectByName[c.Name] = c.Tokens + } + for _, c := range fixture.Cases { + want, ok := expectByName[c.Name] + if !ok { + t.Fatalf("missing expected token case: %s", c.Name) + } + got := util.EstimateTokens(c.Text) + if got != want { + t.Fatalf("token fixture %s mismatch: got=%d want=%d", c.Name, got, want) + } + } +} + +func mustLoadJSON(t *testing.T, path string, out any) { + t.Helper() + b, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read %s failed: %v", path, err) + } + if err := json.Unmarshal(b, out); err != nil { + t.Fatalf("decode %s failed: %v", path, err) + } +} + +func trimExt(name string) string { + if len(name) > 5 && name[len(name)-5:] == ".json" { + return name[:len(name)-5] + } + return name +} + +func compatPath(parts ...string) string { + prefix := []string{"..", "..", "tests", "compat"} + return filepath.Join(append(prefix, parts...)...) +} diff --git a/internal/config/config.go b/internal/config/config.go index d391462..3bc0409 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,6 +11,7 @@ import ( "os" "path/filepath" "slices" + "strconv" "strings" "sync" ) @@ -63,6 +64,8 @@ type Config struct { ClaudeMapping map[string]string `json:"claude_mapping,omitempty"` ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"` ModelAliases map[string]string `json:"model_aliases,omitempty"` + Admin AdminConfig `json:"admin,omitempty"` + Runtime RuntimeConfig `json:"runtime,omitempty"` Compat CompatConfig `json:"compat,omitempty"` Toolcall ToolcallConfig `json:"toolcall,omitempty"` Responses ResponsesConfig `json:"responses,omitempty"` @@ -76,6 +79,18 @@ type CompatConfig struct { WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"` } +type AdminConfig struct { + PasswordHash string `json:"password_hash,omitempty"` + JWTExpireHours int `json:"jwt_expire_hours,omitempty"` + JWTValidAfterUnix int64 `json:"jwt_valid_after_unix,omitempty"` +} + +type RuntimeConfig 
struct { + AccountMaxInflight int `json:"account_max_inflight,omitempty"` + AccountMaxQueue int `json:"account_max_queue,omitempty"` + GlobalMaxInflight int `json:"global_max_inflight,omitempty"` +} + type ToolcallConfig struct { Mode string `json:"mode,omitempty"` EarlyEmitConfidence string `json:"early_emit_confidence,omitempty"` @@ -109,6 +124,12 @@ func (c Config) MarshalJSON() ([]byte, error) { if len(c.ModelAliases) > 0 { m["model_aliases"] = c.ModelAliases } + if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 { + m["admin"] = c.Admin + } + if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 { + m["runtime"] = c.Runtime + } if c.Compat.WideInputStrictOutput != nil { m["compat"] = c.Compat } @@ -158,6 +179,14 @@ func (c *Config) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(v, &c.ModelAliases); err != nil { return fmt.Errorf("invalid field %q: %w", k, err) } + case "admin": + if err := json.Unmarshal(v, &c.Admin); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "runtime": + if err := json.Unmarshal(v, &c.Runtime); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "compat": if err := json.Unmarshal(v, &c.Compat); err != nil { return fmt.Errorf("invalid field %q: %w", k, err) @@ -199,6 +228,8 @@ func (c Config) Clone() Config { ClaudeMapping: cloneStringMap(c.ClaudeMapping), ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), ModelAliases: cloneStringMap(c.ModelAliases), + Admin: c.Admin, + Runtime: c.Runtime, Compat: CompatConfig{ WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput), }, @@ -621,3 +652,92 @@ func (s *Store) EmbeddingsProvider() string { defer s.mu.RUnlock() return strings.TrimSpace(s.cfg.Embeddings.Provider) } + +func (s *Store) AdminPasswordHash() string { + s.mu.RLock() + defer s.mu.RUnlock() + return strings.TrimSpace(s.cfg.Admin.PasswordHash) +} + +func 
(s *Store) AdminJWTExpireHours() int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Admin.JWTExpireHours > 0 { + return s.cfg.Admin.JWTExpireHours + } + if raw := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); raw != "" { + if n, err := strconv.Atoi(raw); err == nil && n > 0 { + return n + } + } + return 24 +} + +func (s *Store) AdminJWTValidAfterUnix() int64 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.cfg.Admin.JWTValidAfterUnix +} + +func (s *Store) RuntimeAccountMaxInflight() int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Runtime.AccountMaxInflight > 0 { + return s.cfg.Runtime.AccountMaxInflight + } + for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n > 0 { + return n + } + } + return 2 +} + +func (s *Store) RuntimeAccountMaxQueue(defaultSize int) int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Runtime.AccountMaxQueue > 0 { + return s.cfg.Runtime.AccountMaxQueue + } + for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n >= 0 { + return n + } + } + if defaultSize < 0 { + return 0 + } + return defaultSize +} + +func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Runtime.GlobalMaxInflight > 0 { + return s.cfg.Runtime.GlobalMaxInflight + } + for _, key := range []string{"DS2API_GLOBAL_MAX_INFLIGHT", "DS2API_MAX_INFLIGHT"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n > 0 { + return n + } + } + if defaultSize < 0 { + return 0 + } + return defaultSize +} diff --git a/internal/config/models.go b/internal/config/models.go index 017a2ee..a2ec899 100644 --- 
a/internal/config/models.go +++ b/internal/config/models.go @@ -10,6 +10,10 @@ type ModelInfo struct { Permission []any `json:"permission,omitempty"` } +type ModelAliasReader interface { + ModelAliases() map[string]string +} + var DeepSeekModels = []ModelInfo{ {ID: "deepseek-chat", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}}, {ID: "deepseek-reasoner", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}}, @@ -104,7 +108,7 @@ func DefaultModelAliases() map[string]string { } } -func ResolveModel(store *Store, requested string) (string, bool) { +func ResolveModel(store ModelAliasReader, requested string) (string, bool) { model := lower(strings.TrimSpace(requested)) if model == "" { return "", false @@ -172,7 +176,7 @@ func OpenAIModelsResponse() map[string]any { return map[string]any{"object": "list", "data": DeepSeekModels} } -func OpenAIModelByID(store *Store, id string) (ModelInfo, bool) { +func OpenAIModelByID(store ModelAliasReader, id string) (ModelInfo, bool) { canonical, ok := ResolveModel(store, id) if !ok { return ModelInfo{}, false diff --git a/internal/deepseek/constants.go b/internal/deepseek/constants.go index 1e7d25f..042ec29 100644 --- a/internal/deepseek/constants.go +++ b/internal/deepseek/constants.go @@ -1,5 +1,10 @@ package deepseek +import ( + _ "embed" + "encoding/json" +) + const ( DeepSeekHost = "chat.deepseek.com" DeepSeekLoginURL = "https://chat.deepseek.com/api/v0/users/login" @@ -8,7 +13,7 @@ const ( DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion" ) -var BaseHeaders = map[string]string{ +var defaultBaseHeaders = map[string]string{ "Host": "chat.deepseek.com", "User-Agent": "DeepSeek/1.6.11 Android/35", "Accept": "application/json", @@ -19,6 +24,75 @@ var BaseHeaders = map[string]string{ "accept-charset": "UTF-8", } +var defaultSkipContainsPatterns = []string{ + "quasi_status", + "elapsed_secs", + "token_usage", + "pending_fragment", + 
"conversation_mode", + "fragments/-1/status", + "fragments/-2/status", + "fragments/-3/status", +} + +var defaultSkipExactPaths = []string{ + "response/search_status", +} + +var BaseHeaders = cloneStringMap(defaultBaseHeaders) +var SkipContainsPatterns = cloneStringSlice(defaultSkipContainsPatterns) +var SkipExactPathSet = toStringSet(defaultSkipExactPaths) + +type sharedConstants struct { + BaseHeaders map[string]string `json:"base_headers"` + SkipContainsPattern []string `json:"skip_contains_patterns"` + SkipExactPaths []string `json:"skip_exact_paths"` +} + +//go:embed constants_shared.json +var sharedConstantsJSON []byte + +func init() { + cfg := sharedConstants{} + if err := json.Unmarshal(sharedConstantsJSON, &cfg); err != nil { + return + } + if len(cfg.BaseHeaders) > 0 { + BaseHeaders = cloneStringMap(cfg.BaseHeaders) + } + if len(cfg.SkipContainsPattern) > 0 { + SkipContainsPatterns = cloneStringSlice(cfg.SkipContainsPattern) + } + if len(cfg.SkipExactPaths) > 0 { + SkipExactPathSet = toStringSet(cfg.SkipExactPaths) + } +} + +func cloneStringMap(in map[string]string) map[string]string { + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneStringSlice(in []string) []string { + out := make([]string, len(in)) + copy(out, in) + return out +} + +func toStringSet(in []string) map[string]struct{} { + out := make(map[string]struct{}, len(in)) + for _, v := range in { + if v == "" { + continue + } + out[v] = struct{}{} + } + return out +} + const ( KeepAliveTimeout = 5 StreamIdleTimeout = 30 diff --git a/internal/deepseek/constants_shared.json b/internal/deepseek/constants_shared.json new file mode 100644 index 0000000..a71ca02 --- /dev/null +++ b/internal/deepseek/constants_shared.json @@ -0,0 +1,25 @@ +{ + "base_headers": { + "Host": "chat.deepseek.com", + "User-Agent": "DeepSeek/1.6.11 Android/35", + "Accept": "application/json", + "Content-Type": "application/json", + "x-client-platform": "android", + 
"x-client-version": "1.6.11", + "x-client-locale": "zh_CN", + "accept-charset": "UTF-8" + }, + "skip_contains_patterns": [ + "quasi_status", + "elapsed_secs", + "token_usage", + "pending_fragment", + "conversation_mode", + "fragments/-1/status", + "fragments/-2/status", + "fragments/-3/status" + ], + "skip_exact_paths": [ + "response/search_status" + ] +} diff --git a/internal/deepseek/constants_test.go b/internal/deepseek/constants_test.go new file mode 100644 index 0000000..03c6788 --- /dev/null +++ b/internal/deepseek/constants_test.go @@ -0,0 +1,15 @@ +package deepseek + +import "testing" + +func TestSharedConstantsLoaded(t *testing.T) { + if BaseHeaders["x-client-platform"] != "android" { + t.Fatalf("unexpected base header x-client-platform=%q", BaseHeaders["x-client-platform"]) + } + if len(SkipContainsPatterns) == 0 { + t.Fatal("expected skip contains patterns to be loaded") + } + if _, ok := SkipExactPathSet["response/search_status"]; !ok { + t.Fatal("expected response/search_status in exact skip path set") + } +} diff --git a/internal/deepseek/prompt.go b/internal/deepseek/prompt.go new file mode 100644 index 0000000..2410390 --- /dev/null +++ b/internal/deepseek/prompt.go @@ -0,0 +1,7 @@ +package deepseek + +import "ds2api/internal/prompt" + +func MessagesPrepare(messages []map[string]any) string { + return prompt.MessagesPrepare(messages) +} diff --git a/internal/format/claude/render.go b/internal/format/claude/render.go new file mode 100644 index 0000000..fdba055 --- /dev/null +++ b/internal/format/claude/render.go @@ -0,0 +1,46 @@ +package claude + +import ( + "fmt" + "time" + + "ds2api/internal/util" +) + +func BuildMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + content := make([]map[string]any, 0, 4) + if finalThinking != "" { + content = append(content, map[string]any{"type": "thinking", "thinking": 
finalThinking}) + } + stopReason := "end_turn" + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + content = append(content, map[string]any{ + "type": "tool_use", + "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), + "name": tc.Name, + "input": tc.Input, + }) + } + } else { + if finalText == "" { + finalText = "抱歉,没有生成有效的响应内容。" + } + content = append(content, map[string]any{"type": "text", "text": finalText}) + } + return map[string]any{ + "id": messageID, + "type": "message", + "role": "assistant", + "model": model, + "content": content, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalizedMessages)), + "output_tokens": util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText), + }, + } +} diff --git a/internal/format/openai/render.go b/internal/format/openai/render.go new file mode 100644 index 0000000..fc7473f --- /dev/null +++ b/internal/format/openai/render.go @@ -0,0 +1,193 @@ +package openai + +import ( + "strings" + "time" + + "github.com/google/uuid" + + "ds2api/internal/util" +) + +func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + finishReason := "stop" + messageObj := map[string]any{"role": "assistant", "content": finalText} + if strings.TrimSpace(finalThinking) != "" { + messageObj["reasoning_content"] = finalThinking + } + if len(detected) > 0 { + finishReason = "tool_calls" + messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected) + messageObj["content"] = nil + } + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + + return map[string]any{ + "id": completionID, + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": 
[]map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}}, + "usage": map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + "completion_tokens_details": map[string]any{ + "reasoning_tokens": reasoningTokens, + }, + }, + } +} + +func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + exposedOutputText := finalText + output := make([]any, 0, 2) + if len(detected) > 0 { + exposedOutputText = "" + toolCalls := make([]any, 0, len(detected)) + for _, tc := range detected { + toolCalls = append(toolCalls, map[string]any{ + "type": "tool_call", + "name": tc.Name, + "arguments": tc.Input, + }) + } + output = append(output, map[string]any{ + "type": "tool_calls", + "tool_calls": toolCalls, + }) + } else { + content := []any{ + map[string]any{ + "type": "output_text", + "text": finalText, + }, + } + if finalThinking != "" { + content = append([]any{map[string]any{ + "type": "reasoning", + "text": finalThinking, + }}, content...) 
+ } + output = append(output, map[string]any{ + "type": "message", + "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "role": "assistant", + "content": content, + }) + } + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "id": responseID, + "type": "response", + "object": "response", + "created_at": time.Now().Unix(), + "status": "completed", + "model": model, + "output": output, + "output_text": exposedOutputText, + "usage": map[string]any{ + "input_tokens": promptTokens, + "output_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + }, + } +} + +func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any { + return map[string]any{ + "delta": delta, + "index": index, + } +} + +func BuildChatStreamFinishChoice(index int, finishReason string) map[string]any { + return map[string]any{ + "delta": map[string]any{}, + "index": index, + "finish_reason": finishReason, + } +} + +func BuildChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any { + out := map[string]any{ + "id": completionID, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": choices, + } + if len(usage) > 0 { + out["usage"] = usage + } + return out +} + +func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + "completion_tokens_details": map[string]any{ + "reasoning_tokens": reasoningTokens, + }, + 
} +} + +func BuildResponsesCreatedPayload(responseID, model string) map[string]any { + return map[string]any{ + "type": "response.created", + "id": responseID, + "object": "response", + "model": model, + "status": "in_progress", + } +} + +func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any { + return map[string]any{ + "type": "response.output_text.delta", + "id": responseID, + "delta": delta, + } +} + +func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { + return map[string]any{ + "type": "response.reasoning.delta", + "id": responseID, + "delta": delta, + } +} + +func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_tool_call.delta", + "id": responseID, + "tool_calls": toolCalls, + } +} + +func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_tool_call.done", + "id": responseID, + "tool_calls": toolCalls, + } +} + +func BuildResponsesCompletedPayload(response map[string]any) map[string]any { + return map[string]any{ + "type": "response.completed", + "response": response, + } +} diff --git a/internal/prompt/messages.go b/internal/prompt/messages.go new file mode 100644 index 0000000..69cfe5a --- /dev/null +++ b/internal/prompt/messages.go @@ -0,0 +1,84 @@ +package prompt + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" +) + +var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) + +func MessagesPrepare(messages []map[string]any) string { + type block struct { + Role string + Text string + } + processed := make([]block, 0, len(messages)) + for _, m := range messages { + role, _ := m["role"].(string) + text := NormalizeContent(m["content"]) + processed = append(processed, block{Role: role, Text: text}) + } + if len(processed) == 0 { + return "" + } + merged := make([]block, 0, 
len(processed)) + for _, msg := range processed { + if len(merged) > 0 && merged[len(merged)-1].Role == msg.Role { + merged[len(merged)-1].Text += "\n\n" + msg.Text + continue + } + merged = append(merged, msg) + } + parts := make([]string, 0, len(merged)) + for i, m := range merged { + switch m.Role { + case "assistant": + parts = append(parts, "<|Assistant|>"+m.Text+"<|end▁of▁sentence|>") + case "user", "system": + if i > 0 { + parts = append(parts, "<|User|>"+m.Text) + } else { + parts = append(parts, m.Text) + } + default: + parts = append(parts, m.Text) + } + } + out := strings.Join(parts, "") + return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`) +} + +func NormalizeContent(v any) string { + switch x := v.(type) { + case string: + return x + case []any: + parts := make([]string, 0, len(x)) + for _, item := range x { + m, ok := item.(map[string]any) + if !ok { + continue + } + typeStr, _ := m["type"].(string) + typeStr = strings.ToLower(strings.TrimSpace(typeStr)) + if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" { + if txt, ok := m["text"].(string); ok { + parts = append(parts, txt) + continue + } + if txt, ok := m["content"].(string); ok { + parts = append(parts, txt) + } + } + } + return strings.Join(parts, "\n") + default: + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%v", v) + } + return string(b) + } +} diff --git a/internal/sse/parser.go b/internal/sse/parser.go index 38429d9..c20bc79 100644 --- a/internal/sse/parser.go +++ b/internal/sse/parser.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/json" "strings" + + "ds2api/internal/deepseek" ) type ContentPart struct { @@ -11,11 +13,6 @@ type ContentPart struct { Type string } -var skipPatterns = []string{ - "quasi_status", "elapsed_secs", "token_usage", "pending_fragment", "conversation_mode", - "fragments/-1/status", "fragments/-2/status", "fragments/-3/status", -} - func ParseDeepSeekSSELine(raw []byte) (map[string]any, bool, bool) { line := 
strings.TrimSpace(string(raw)) if line == "" || !strings.HasPrefix(line, "data:") { @@ -33,10 +30,10 @@ func ParseDeepSeekSSELine(raw []byte) (map[string]any, bool, bool) { } func shouldSkipPath(path string) bool { - if path == "response/search_status" { + if _, ok := deepseek.SkipExactPathSet[path]; ok { return true } - for _, p := range skipPatterns { + for _, p := range deepseek.SkipContainsPatterns { if strings.Contains(path, p) { return true } @@ -60,126 +57,159 @@ func ParseSSEChunkForContent(chunk map[string]any, thinkingEnabled bool, current } newType := currentFragmentType parts := make([]ContentPart, 0, 8) + collectDirectFragments(path, chunk, v, &newType, &parts) + updateTypeFromNestedResponse(path, v, &newType) + partType := resolvePartType(path, thinkingEnabled, newType) + finished := appendChunkValueContent(v, partType, &newType, &parts, path) + if finished { + return nil, true, newType + } + return parts, false, newType +} - // Newer DeepSeek responses may emit fragment APPEND directly on - // path "response/fragments" instead of wrapping it in path "response". 
- if path == "response/fragments" { - if op, _ := chunk["o"].(string); strings.EqualFold(op, "APPEND") { - if frags, ok := v.([]any); ok { - for _, frag := range frags { - fm, ok := frag.(map[string]any) - if !ok { - continue - } - t, _ := fm["type"].(string) - content, _ := fm["content"].(string) - t = strings.ToUpper(t) - switch t { - case "THINK", "THINKING": - newType = "thinking" - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "thinking"}) - } - case "RESPONSE": - newType = "text" - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "text"}) - } - default: - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "text"}) - } - } - } +func collectDirectFragments(path string, chunk map[string]any, v any, newType *string, parts *[]ContentPart) { + if path != "response/fragments" { + return + } + op, _ := chunk["o"].(string) + if !strings.EqualFold(op, "APPEND") { + return + } + frags, ok := v.([]any) + if !ok { + return + } + for _, frag := range frags { + m, ok := frag.(map[string]any) + if !ok { + continue + } + typeName, content, fragType := parseFragmentTypeContent(m) + if typeName == "" { + typeName = fragType + } + switch typeName { + case "THINK", "THINKING": + *newType = "thinking" + appendContentPart(parts, content, "thinking") + case "RESPONSE": + *newType = "text" + appendContentPart(parts, content, "text") + default: + appendContentPart(parts, content, "text") + } + } +} + +func updateTypeFromNestedResponse(path string, v any, newType *string) { + if path != "response" { + return + } + arr, ok := v.([]any) + if !ok { + return + } + for _, it := range arr { + m, ok := it.(map[string]any) + if !ok || m["p"] != "fragments" || m["o"] != "APPEND" { + continue + } + frags, ok := m["v"].([]any) + if !ok { + continue + } + for _, frag := range frags { + fm, ok := frag.(map[string]any) + if !ok { + continue + } + typeName, _, _ := parseFragmentTypeContent(fm) + switch typeName { + case 
"THINK", "THINKING": + *newType = "thinking" + case "RESPONSE": + *newType = "text" } } } +} - if path == "response" { - if arr, ok := v.([]any); ok { - for _, it := range arr { - m, ok := it.(map[string]any) - if !ok { - continue - } - if m["p"] == "fragments" && m["o"] == "APPEND" { - if frags, ok := m["v"].([]any); ok { - for _, frag := range frags { - fm, ok := frag.(map[string]any) - if !ok { - continue - } - t, _ := fm["type"].(string) - t = strings.ToUpper(t) - if t == "THINK" || t == "THINKING" { - newType = "thinking" - } else if t == "RESPONSE" { - newType = "text" - } - } - } - } - } - } - } - partType := "text" +func resolvePartType(path string, thinkingEnabled bool, newType string) string { switch { case path == "response/thinking_content": - partType = "thinking" + return "thinking" case path == "response/content": - partType = "text" + return "text" case strings.Contains(path, "response/fragments") && strings.Contains(path, "/content"): - partType = newType - case path == "": - if thinkingEnabled { - partType = newType - } + return newType + case path == "" && thinkingEnabled: + return newType + default: + return "text" } +} + +func appendChunkValueContent(v any, partType string, newType *string, parts *[]ContentPart, path string) bool { switch val := v.(type) { case string: if val == "FINISHED" && (path == "" || path == "status") { - return nil, true, newType - } - if val != "" { - parts = append(parts, ContentPart{Text: val, Type: partType}) + return true } + appendContentPart(parts, val, partType) case []any: pp, finished := extractContentRecursive(val, partType) if finished { - return nil, true, newType + return true } - parts = append(parts, pp...) + *parts = append(*parts, pp...) 
case map[string]any: - resp := val - if wrapped, ok := val["response"].(map[string]any); ok { - resp = wrapped + appendWrappedFragments(val, partType, newType, parts) + } + return false +} + +func appendWrappedFragments(val map[string]any, partType string, newType *string, parts *[]ContentPart) { + resp := val + if wrapped, ok := val["response"].(map[string]any); ok { + resp = wrapped + } + frags, ok := resp["fragments"].([]any) + if !ok { + return + } + for _, item := range frags { + m, ok := item.(map[string]any) + if !ok { + continue } - if frags, ok := resp["fragments"].([]any); ok { - for _, item := range frags { - m, ok := item.(map[string]any) - if !ok { - continue - } - t, _ := m["type"].(string) - content, _ := m["content"].(string) - t = strings.ToUpper(t) - if t == "THINK" || t == "THINKING" { - newType = "thinking" - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "thinking"}) - } - } else if t == "RESPONSE" { - newType = "text" - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "text"}) - } - } else if content != "" { - parts = append(parts, ContentPart{Text: content, Type: partType}) - } - } + typeName, content, fragType := parseFragmentTypeContent(m) + if typeName == "" { + typeName = fragType + } + switch typeName { + case "THINK", "THINKING": + *newType = "thinking" + appendContentPart(parts, content, "thinking") + case "RESPONSE": + *newType = "text" + appendContentPart(parts, content, "text") + default: + appendContentPart(parts, content, partType) } } - return parts, false, newType +} + +func parseFragmentTypeContent(m map[string]any) (string, string, string) { + typeName, _ := m["type"].(string) + content, _ := m["content"].(string) + return strings.ToUpper(typeName), content, strings.ToUpper(typeName) +} + +func appendContentPart(parts *[]ContentPart, content, kind string) { + if content == "" { + return + } + *parts = append(*parts, ContentPart{Text: content, Type: kind}) } func 
extractContentRecursive(items []any, defaultType string) ([]ContentPart, bool) { diff --git a/internal/stream/engine.go b/internal/stream/engine.go new file mode 100644 index 0000000..c63cd7b --- /dev/null +++ b/internal/stream/engine.go @@ -0,0 +1,128 @@ +package stream + +import ( + "context" + "io" + "time" + + "ds2api/internal/sse" +) + +type StopReason string + +const ( + StopReasonNone StopReason = "" + StopReasonContextCancelled StopReason = "context_cancelled" + StopReasonNoContentTimeout StopReason = "no_content_timeout" + StopReasonIdleTimeout StopReason = "idle_timeout" + StopReasonUpstreamCompleted StopReason = "upstream_completed" + StopReasonHandlerRequested StopReason = "handler_requested" +) + +type ConsumeConfig struct { + Context context.Context + Body io.Reader + ThinkingEnabled bool + InitialType string + KeepAliveInterval time.Duration + IdleTimeout time.Duration + MaxKeepAliveNoInput int +} + +type ParsedDecision struct { + Stop bool + StopReason StopReason + ContentSeen bool +} + +type ConsumeHooks struct { + OnParsed func(parsed sse.LineResult) ParsedDecision + OnKeepAlive func() + OnFinalize func(reason StopReason, scannerErr error) + OnContextDone func() +} + +func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) { + if cfg.Context == nil { + cfg.Context = context.Background() + } + initialType := cfg.InitialType + if initialType == "" { + if cfg.ThinkingEnabled { + initialType = "thinking" + } else { + initialType = "text" + } + } + parsedLines, done := sse.StartParsedLinePump(cfg.Context, cfg.Body, cfg.ThinkingEnabled, initialType) + + var ticker *time.Ticker + if cfg.KeepAliveInterval > 0 { + ticker = time.NewTicker(cfg.KeepAliveInterval) + defer ticker.Stop() + } + + hasContent := false + lastContent := time.Now() + keepaliveCount := 0 + + finalize := func(reason StopReason, scannerErr error) { + if hooks.OnFinalize != nil { + hooks.OnFinalize(reason, scannerErr) + } + } + + for { + select { + case <-cfg.Context.Done(): + if 
hooks.OnContextDone != nil { + hooks.OnContextDone() + } + return + case <-tickCh(ticker): + if !hasContent { + keepaliveCount++ + if cfg.MaxKeepAliveNoInput > 0 && keepaliveCount >= cfg.MaxKeepAliveNoInput { + finalize(StopReasonNoContentTimeout, nil) + return + } + } + if hasContent && cfg.IdleTimeout > 0 && time.Since(lastContent) > cfg.IdleTimeout { + finalize(StopReasonIdleTimeout, nil) + return + } + if hooks.OnKeepAlive != nil { + hooks.OnKeepAlive() + } + case parsed, ok := <-parsedLines: + if !ok { + finalize(StopReasonUpstreamCompleted, <-done) + return + } + if hooks.OnParsed == nil { + continue + } + decision := hooks.OnParsed(parsed) + if decision.ContentSeen { + hasContent = true + lastContent = time.Now() + keepaliveCount = 0 + } + if decision.Stop { + reason := decision.StopReason + if reason == StopReasonNone { + reason = StopReasonHandlerRequested + } + finalize(reason, nil) + return + } + } + } +} + +func tickCh(ticker *time.Ticker) <-chan time.Time { + if ticker == nil { + return nil + } + return ticker.C +} diff --git a/internal/testsuite/runner.go b/internal/testsuite/runner.go index e6ae9a6..33e7580 100644 --- a/internal/testsuite/runner.go +++ b/internal/testsuite/runner.go @@ -327,7 +327,7 @@ func (r *Runner) runPreflight(ctx context.Context) error { {"go", "test", "./...", "-count=1"}, {"node", "--check", "api/chat-stream.js"}, {"node", "--check", "api/helpers/stream-tool-sieve.js"}, - {"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js"}, + {"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js", "api/compat/js_compat_test.js"}, {"npm", "run", "build", "--prefix", "webui"}, } f, err := os.OpenFile(r.preflightLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) diff --git a/internal/util/messages.go b/internal/util/messages.go index fcc9484..b6920c0 100644 --- a/internal/util/messages.go +++ b/internal/util/messages.go @@ -1,16 +1,11 @@ package util import ( - "encoding/json" - 
"fmt" - "regexp" - "strings" - + "ds2api/internal/claudeconv" "ds2api/internal/config" + "ds2api/internal/prompt" ) -var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) - const ClaudeDefaultModel = "claude-sonnet-4-5" type Message struct { @@ -19,112 +14,15 @@ type Message struct { } func MessagesPrepare(messages []map[string]any) string { - type block struct { - Role string - Text string - } - processed := make([]block, 0, len(messages)) - for _, m := range messages { - role, _ := m["role"].(string) - text := normalizeContent(m["content"]) - processed = append(processed, block{Role: role, Text: text}) - } - if len(processed) == 0 { - return "" - } - merged := make([]block, 0, len(processed)) - for _, msg := range processed { - if len(merged) > 0 && merged[len(merged)-1].Role == msg.Role { - merged[len(merged)-1].Text += "\n\n" + msg.Text - continue - } - merged = append(merged, msg) - } - parts := make([]string, 0, len(merged)) - for i, m := range merged { - switch m.Role { - case "assistant": - parts = append(parts, "<|Assistant|>"+m.Text+"<|end▁of▁sentence|>") - case "user", "system": - if i > 0 { - parts = append(parts, "<|User|>"+m.Text) - } else { - parts = append(parts, m.Text) - } - default: - parts = append(parts, m.Text) - } - } - out := strings.Join(parts, "") - return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`) + return prompt.MessagesPrepare(messages) } func normalizeContent(v any) string { - switch x := v.(type) { - case string: - return x - case []any: - parts := make([]string, 0, len(x)) - for _, item := range x { - m, ok := item.(map[string]any) - if !ok { - continue - } - typeStr, _ := m["type"].(string) - typeStr = strings.ToLower(strings.TrimSpace(typeStr)) - if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" { - if txt, ok := m["text"].(string); ok { - parts = append(parts, txt) - continue - } - if txt, ok := m["content"].(string); ok { - parts = append(parts, txt) - } - } - } - return 
strings.Join(parts, "\n") - default: - b, err := json.Marshal(v) - if err != nil { - return fmt.Sprintf("%v", v) - } - return string(b) - } + return prompt.NormalizeContent(v) } func ConvertClaudeToDeepSeek(claudeReq map[string]any, store *config.Store) map[string]any { - messages, _ := claudeReq["messages"].([]any) - model, _ := claudeReq["model"].(string) - if model == "" { - model = ClaudeDefaultModel - } - mapping := store.ClaudeMapping() - dsModel := mapping["fast"] - if dsModel == "" { - dsModel = "deepseek-chat" - } - modelLower := strings.ToLower(model) - if strings.Contains(modelLower, "opus") || strings.Contains(modelLower, "reasoner") || strings.Contains(modelLower, "slow") { - if slow := mapping["slow"]; slow != "" { - dsModel = slow - } - } - convertedMessages := make([]any, 0, len(messages)+1) - if system, ok := claudeReq["system"].(string); ok && system != "" { - convertedMessages = append(convertedMessages, map[string]any{"role": "system", "content": system}) - } - convertedMessages = append(convertedMessages, messages...) - - out := map[string]any{"model": dsModel, "messages": convertedMessages} - for _, k := range []string{"temperature", "top_p", "stream"} { - if v, ok := claudeReq[k]; ok { - out[k] = v - } - } - if stopSeq, ok := claudeReq["stop_sequences"]; ok { - out["stop"] = stopSeq - } - return out + return claudeconv.ConvertClaudeToDeepSeek(claudeReq, store, ClaudeDefaultModel) } // EstimateTokens provides a rough token count approximation. diff --git a/internal/util/render.go b/internal/util/render.go index b5e0a79..fff8501 100644 --- a/internal/util/render.go +++ b/internal/util/render.go @@ -8,6 +8,8 @@ import ( "github.com/google/uuid" ) +// BuildOpenAIChatCompletion is kept for backward compatibility. +// Prefer internal/format/openai.BuildChatCompletion for new code. 
func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { detected := ParseToolCalls(finalText, toolNames) finishReason := "stop" @@ -41,6 +43,8 @@ func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, } } +// BuildOpenAIResponseObject is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponseObject for new code. func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { detected := ParseToolCalls(finalText, toolNames) exposedOutputText := finalText @@ -101,6 +105,8 @@ func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, fi } } +// BuildClaudeMessageResponse is kept for backward compatibility. +// Prefer internal/format/claude.BuildMessageResponse for new code. func BuildClaudeMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any { detected := ParseToolCalls(finalText, toolNames) content := make([]map[string]any, 0, 4) diff --git a/internal/util/render_stream.go b/internal/util/render_stream.go index 716c158..b5699ba 100644 --- a/internal/util/render_stream.go +++ b/internal/util/render_stream.go @@ -1,5 +1,7 @@ package util +// BuildOpenAIChatStreamDeltaChoice is kept for backward compatibility. +// Prefer internal/format/openai.BuildChatStreamDeltaChoice for new code. func BuildOpenAIChatStreamDeltaChoice(index int, delta map[string]any) map[string]any { return map[string]any{ "delta": delta, @@ -7,6 +9,8 @@ func BuildOpenAIChatStreamDeltaChoice(index int, delta map[string]any) map[strin } } +// BuildOpenAIChatStreamFinishChoice is kept for backward compatibility. +// Prefer internal/format/openai.BuildChatStreamFinishChoice for new code. 
func BuildOpenAIChatStreamFinishChoice(index int, finishReason string) map[string]any { return map[string]any{ "delta": map[string]any{}, @@ -15,6 +19,8 @@ func BuildOpenAIChatStreamFinishChoice(index int, finishReason string) map[strin } } +// BuildOpenAIChatStreamChunk is kept for backward compatibility. +// Prefer internal/format/openai.BuildChatStreamChunk for new code. func BuildOpenAIChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any { out := map[string]any{ "id": completionID, @@ -29,6 +35,8 @@ func BuildOpenAIChatStreamChunk(completionID string, created int64, model string return out } +// BuildOpenAIChatUsage is kept for backward compatibility. +// Prefer internal/format/openai.BuildChatUsage for new code. func BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText string) map[string]any { promptTokens := EstimateTokens(finalPrompt) reasoningTokens := EstimateTokens(finalThinking) @@ -43,6 +51,8 @@ func BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText string) map[stri } } +// BuildOpenAIResponsesCreatedPayload is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponsesCreatedPayload for new code. func BuildOpenAIResponsesCreatedPayload(responseID, model string) map[string]any { return map[string]any{ "type": "response.created", @@ -53,6 +63,8 @@ func BuildOpenAIResponsesCreatedPayload(responseID, model string) map[string]any } } +// BuildOpenAIResponsesTextDeltaPayload is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponsesTextDeltaPayload for new code. func BuildOpenAIResponsesTextDeltaPayload(responseID, delta string) map[string]any { return map[string]any{ "type": "response.output_text.delta", @@ -61,6 +73,8 @@ func BuildOpenAIResponsesTextDeltaPayload(responseID, delta string) map[string]a } } +// BuildOpenAIResponsesReasoningDeltaPayload is kept for backward compatibility. 
+// Prefer internal/format/openai.BuildResponsesReasoningDeltaPayload for new code. func BuildOpenAIResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { return map[string]any{ "type": "response.reasoning.delta", @@ -69,6 +83,8 @@ func BuildOpenAIResponsesReasoningDeltaPayload(responseID, delta string) map[str } } +// BuildOpenAIResponsesToolCallDeltaPayload is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponsesToolCallDeltaPayload for new code. func BuildOpenAIResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any { return map[string]any{ "type": "response.output_tool_call.delta", @@ -77,6 +93,8 @@ func BuildOpenAIResponsesToolCallDeltaPayload(responseID string, toolCalls []map } } +// BuildOpenAIResponsesToolCallDonePayload is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponsesToolCallDonePayload for new code. func BuildOpenAIResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any { return map[string]any{ "type": "response.output_tool_call.done", @@ -85,6 +103,8 @@ func BuildOpenAIResponsesToolCallDonePayload(responseID string, toolCalls []map[ } } +// BuildOpenAIResponsesCompletedPayload is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponsesCompletedPayload for new code. 
func BuildOpenAIResponsesCompletedPayload(response map[string]any) map[string]any { return map[string]any{ "type": "response.completed", diff --git a/tests/compat/expected/sse_fragments_append.json b/tests/compat/expected/sse_fragments_append.json new file mode 100644 index 0000000..8647f3a --- /dev/null +++ b/tests/compat/expected/sse_fragments_append.json @@ -0,0 +1,8 @@ +{ + "parts": [ + {"text": "思考中", "type": "thinking"}, + {"text": "结论", "type": "text"} + ], + "finished": false, + "new_type": "text" +} diff --git a/tests/compat/expected/sse_nested_finished.json b/tests/compat/expected/sse_nested_finished.json new file mode 100644 index 0000000..7d588f7 --- /dev/null +++ b/tests/compat/expected/sse_nested_finished.json @@ -0,0 +1,5 @@ +{ + "parts": [], + "finished": true, + "new_type": "text" +} diff --git a/tests/compat/expected/sse_split_tool_json.json b/tests/compat/expected/sse_split_tool_json.json new file mode 100644 index 0000000..2afed2a --- /dev/null +++ b/tests/compat/expected/sse_split_tool_json.json @@ -0,0 +1,8 @@ +{ + "parts": [ + {"text": "{\"", "type": "text"}, + {"text": "tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}", "type": "text"} + ], + "finished": false, + "new_type": "text" +} diff --git a/tests/compat/expected/token_cases.json b/tests/compat/expected/token_cases.json new file mode 100644 index 0000000..69694eb --- /dev/null +++ b/tests/compat/expected/token_cases.json @@ -0,0 +1,7 @@ +{ + "cases": [ + {"name": "ascii_short", "tokens": 1}, + {"name": "cjk", "tokens": 3}, + {"name": "mixed", "tokens": 4} + ] +} diff --git a/tests/compat/expected/toolcalls_fenced_json.json b/tests/compat/expected/toolcalls_fenced_json.json new file mode 100644 index 0000000..97646bf --- /dev/null +++ b/tests/compat/expected/toolcalls_fenced_json.json @@ -0,0 +1,3 @@ +{ + "calls": [] +} diff --git a/tests/compat/expected/toolcalls_unknown_name.json b/tests/compat/expected/toolcalls_unknown_name.json new file mode 100644 index 
0000000..8f79875 --- /dev/null +++ b/tests/compat/expected/toolcalls_unknown_name.json @@ -0,0 +1,5 @@ +{ + "calls": [ + {"name": "unknown_tool", "input": {"x": 1}} + ] +} diff --git a/tests/compat/fixtures/sse_chunks/fragments_append.json b/tests/compat/fixtures/sse_chunks/fragments_append.json new file mode 100644 index 0000000..c6f8ae6 --- /dev/null +++ b/tests/compat/fixtures/sse_chunks/fragments_append.json @@ -0,0 +1,12 @@ +{ + "chunk": { + "p": "response/fragments", + "o": "APPEND", + "v": [ + {"type": "THINK", "content": "思考中"}, + {"type": "RESPONSE", "content": "结论"} + ] + }, + "thinking_enabled": true, + "current_type": "thinking" +} diff --git a/tests/compat/fixtures/sse_chunks/nested_finished.json b/tests/compat/fixtures/sse_chunks/nested_finished.json new file mode 100644 index 0000000..da76280 --- /dev/null +++ b/tests/compat/fixtures/sse_chunks/nested_finished.json @@ -0,0 +1,10 @@ +{ + "chunk": { + "p": "response", + "v": [ + {"p": "status", "v": "FINISHED"} + ] + }, + "thinking_enabled": false, + "current_type": "text" +} diff --git a/tests/compat/fixtures/sse_chunks/split_tool_json.json b/tests/compat/fixtures/sse_chunks/split_tool_json.json new file mode 100644 index 0000000..e915fbb --- /dev/null +++ b/tests/compat/fixtures/sse_chunks/split_tool_json.json @@ -0,0 +1,11 @@ +{ + "chunk": { + "p": "response", + "v": [ + {"p": "response/content", "v": "{\""}, + {"p": "response/content", "v": "tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"} + ] + }, + "thinking_enabled": false, + "current_type": "text" +} diff --git a/tests/compat/fixtures/token_cases.json b/tests/compat/fixtures/token_cases.json new file mode 100644 index 0000000..3887356 --- /dev/null +++ b/tests/compat/fixtures/token_cases.json @@ -0,0 +1,7 @@ +{ + "cases": [ + {"name": "ascii_short", "text": "abcd"}, + {"name": "cjk", "text": "你好世界"}, + {"name": "mixed", "text": "Hello 你好世界"} + ] +} diff --git a/tests/compat/fixtures/toolcalls/fenced_json.json 
b/tests/compat/fixtures/toolcalls/fenced_json.json new file mode 100644 index 0000000..8d75cc1 --- /dev/null +++ b/tests/compat/fixtures/toolcalls/fenced_json.json @@ -0,0 +1,4 @@ +{ + "text": "```json\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}\n```", + "tool_names": ["read_file"] +} diff --git a/tests/compat/fixtures/toolcalls/unknown_name.json b/tests/compat/fixtures/toolcalls/unknown_name.json new file mode 100644 index 0000000..0ba9e76 --- /dev/null +++ b/tests/compat/fixtures/toolcalls/unknown_name.json @@ -0,0 +1,4 @@ +{ + "text": "{\"tool_calls\":[{\"name\":\"unknown_tool\",\"input\":{\"x\":1}}]}", + "tool_names": ["read_file"] +} diff --git a/webui/src/App.jsx b/webui/src/App.jsx index 53d0b4a..3f6ad27 100644 --- a/webui/src/App.jsx +++ b/webui/src/App.jsx @@ -11,6 +11,7 @@ import { Key, Upload, Cloud, + Settings as SettingsIcon, LogOut, Menu, X, @@ -23,12 +24,13 @@ import AccountManager from './components/AccountManager' import ApiTester from './components/ApiTester' import BatchImport from './components/BatchImport' import VercelSync from './components/VercelSync' +import Settings from './components/Settings' import Login from './components/Login' import LandingPage from './components/LandingPage' import LanguageToggle from './components/LanguageToggle' import { useI18n } from './i18n' -function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message }) { +function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout }) { const { t } = useI18n() const [activeTab, setActiveTab] = useState('accounts') const [sidebarOpen, setSidebarOpen] = useState(false) @@ -39,6 +41,7 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message { id: 'test', label: t('nav.test.label'), icon: Server, description: t('nav.test.desc') }, { id: 'import', label: t('nav.import.label'), icon: Upload, description: t('nav.import.desc') }, { id: 'vercel', label: 
t('nav.vercel.label'), icon: Cloud, description: t('nav.vercel.desc') }, + { id: 'settings', label: t('nav.settings.label'), icon: SettingsIcon, description: t('nav.settings.desc') }, ] const authFetch = async (url, options = {}) => { @@ -65,6 +68,8 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message return case 'vercel': return + case 'settings': + return default: return null } @@ -314,6 +319,7 @@ export default function App() { fetchConfig={fetchConfig} showMessage={showMessage} message={message} + onForceLogout={handleLogout} /> ) : (
diff --git a/webui/src/components/Settings.jsx b/webui/src/components/Settings.jsx new file mode 100644 index 0000000..b257ed5 --- /dev/null +++ b/webui/src/components/Settings.jsx @@ -0,0 +1,376 @@ +import { useCallback, useEffect, useMemo, useState } from 'react' +import { AlertTriangle, Download, Lock, Save, Upload } from 'lucide-react' +import { useI18n } from '../i18n' + +export default function Settings({ onRefresh, onMessage, authFetch, onForceLogout }) { + const { t } = useI18n() + const apiFetch = authFetch || fetch + + const [loading, setLoading] = useState(false) + const [saving, setSaving] = useState(false) + const [changingPassword, setChangingPassword] = useState(false) + const [importing, setImporting] = useState(false) + const [exportData, setExportData] = useState(null) + const [importMode, setImportMode] = useState('merge') + const [importText, setImportText] = useState('') + const [newPassword, setNewPassword] = useState('') + const [settingsMeta, setSettingsMeta] = useState({ default_password_warning: false, env_backed: false, needs_vercel_sync: false }) + + const [form, setForm] = useState({ + admin: { jwt_expire_hours: 24 }, + runtime: { account_max_inflight: 2, account_max_queue: 10, global_max_inflight: 10 }, + toolcall: { mode: 'feature_match', early_emit_confidence: 'high' }, + responses: { store_ttl_seconds: 900 }, + embeddings: { provider: '' }, + claude_mapping_text: '{\n "fast": "deepseek-chat",\n "slow": "deepseek-reasoner"\n}', + model_aliases_text: '{}', + }) + + const parseJSONMap = (raw, fieldName) => { + const text = String(raw || '').trim() + if (!text) { + return {} + } + let parsed + try { + parsed = JSON.parse(text) + } catch (_e) { + throw new Error(t('settings.invalidJsonField', { field: fieldName })) + } + if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) { + throw new Error(t('settings.invalidJsonField', { field: fieldName })) + } + return parsed + } + + const loadSettings = useCallback(async () => { + 
setLoading(true) + try { + const res = await apiFetch('/admin/settings') + const data = await res.json() + if (!res.ok) { + onMessage('error', data.detail || t('settings.loadFailed')) + return + } + setSettingsMeta({ + default_password_warning: Boolean(data.admin?.default_password_warning), + env_backed: Boolean(data.env_backed), + needs_vercel_sync: Boolean(data.needs_vercel_sync), + }) + setForm({ + admin: { jwt_expire_hours: Number(data.admin?.jwt_expire_hours || 24) }, + runtime: { + account_max_inflight: Number(data.runtime?.account_max_inflight || 2), + account_max_queue: Number(data.runtime?.account_max_queue || 10), + global_max_inflight: Number(data.runtime?.global_max_inflight || 10), + }, + toolcall: { + mode: data.toolcall?.mode || 'feature_match', + early_emit_confidence: data.toolcall?.early_emit_confidence || 'high', + }, + responses: { + store_ttl_seconds: Number(data.responses?.store_ttl_seconds || 900), + }, + embeddings: { + provider: data.embeddings?.provider || '', + }, + claude_mapping_text: JSON.stringify(data.claude_mapping || {}, null, 2), + model_aliases_text: JSON.stringify(data.model_aliases || {}, null, 2), + }) + } catch (e) { + onMessage('error', t('settings.loadFailed')) + // eslint-disable-next-line no-console + console.error(e) + } finally { + setLoading(false) + } + }, [apiFetch, onMessage, t]) + + useEffect(() => { + loadSettings() + }, [loadSettings]) + + const saveSettings = async () => { + let claudeMapping = {} + let modelAliases = {} + try { + claudeMapping = parseJSONMap(form.claude_mapping_text, 'claude_mapping') + modelAliases = parseJSONMap(form.model_aliases_text, 'model_aliases') + } catch (e) { + onMessage('error', e.message) + return + } + + const payload = { + admin: { jwt_expire_hours: Number(form.admin.jwt_expire_hours) }, + runtime: { + account_max_inflight: Number(form.runtime.account_max_inflight), + account_max_queue: Number(form.runtime.account_max_queue), + global_max_inflight: 
Number(form.runtime.global_max_inflight), + }, + toolcall: { + mode: String(form.toolcall.mode || '').trim(), + early_emit_confidence: String(form.toolcall.early_emit_confidence || '').trim(), + }, + responses: { store_ttl_seconds: Number(form.responses.store_ttl_seconds) }, + embeddings: { provider: String(form.embeddings.provider || '').trim() }, + claude_mapping: claudeMapping, + model_aliases: modelAliases, + } + + setSaving(true) + try { + const res = await apiFetch('/admin/settings', { + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(payload), + }) + const data = await res.json() + if (!res.ok) { + onMessage('error', data.detail || t('settings.saveFailed')) + return + } + onMessage('success', t('settings.saveSuccess')) + if (typeof onRefresh === 'function') { + onRefresh() + } + await loadSettings() + } catch (e) { + onMessage('error', t('settings.saveFailed')) + // eslint-disable-next-line no-console + console.error(e) + } finally { + setSaving(false) + } + } + + const updatePassword = async () => { + if (String(newPassword || '').trim().length < 4) { + onMessage('error', t('settings.passwordTooShort')) + return + } + setChangingPassword(true) + try { + const res = await apiFetch('/admin/settings/password', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ new_password: newPassword.trim() }), + }) + const data = await res.json() + if (!res.ok) { + onMessage('error', data.detail || t('settings.passwordUpdateFailed')) + return + } + onMessage('success', t('settings.passwordUpdated')) + setNewPassword('') + if (typeof onForceLogout === 'function') { + onForceLogout() + } + } catch (e) { + onMessage('error', t('settings.passwordUpdateFailed')) + } finally { + setChangingPassword(false) + } + } + + const loadExportData = async () => { + try { + const res = await apiFetch('/admin/config/export') + const data = await res.json() + if (!res.ok) { + onMessage('error', data.detail 
|| t('settings.exportFailed')) + return + } + setExportData(data) + onMessage('success', t('settings.exportLoaded')) + } catch (e) { + onMessage('error', t('settings.exportFailed')) + } + } + + const doImport = async () => { + if (!String(importText || '').trim()) { + onMessage('error', t('settings.importEmpty')) + return + } + let parsed + try { + parsed = JSON.parse(importText) + } catch (_e) { + onMessage('error', t('settings.importInvalidJson')) + return + } + setImporting(true) + try { + const res = await apiFetch(`/admin/config/import?mode=${encodeURIComponent(importMode)}`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ config: parsed, mode: importMode }), + }) + const data = await res.json() + if (!res.ok) { + onMessage('error', data.detail || t('settings.importFailed')) + return + } + onMessage('success', t('settings.importSuccess', { mode: importMode })) + if (typeof onRefresh === 'function') { + onRefresh() + } + await loadSettings() + } catch (e) { + onMessage('error', t('settings.importFailed')) + } finally { + setImporting(false) + } + } + + const syncHintVisible = useMemo(() => settingsMeta.env_backed || settingsMeta.needs_vercel_sync, [settingsMeta.env_backed, settingsMeta.needs_vercel_sync]) + + return ( +
+ {settingsMeta.default_password_warning && ( +
+ + {t('settings.defaultPasswordWarning')} +
+ )} + {syncHintVisible && ( +
+ + {t('settings.vercelSyncHint')} +
+ )} + +
+

{t('settings.securityTitle')}

+
+ + +
+
+ +
+

{t('settings.runtimeTitle')}

+
+ + + +
+
+ +
+

{t('settings.behaviorTitle')}

+
+ + + + +
+
+ +
+

{t('settings.modelTitle')}

+
+