diff --git a/.env.example b/.env.example index 21a4d2a..d63f133 100644 --- a/.env.example +++ b/.env.example @@ -52,6 +52,9 @@ DS2API_ADMIN_KEY=admin # Option C: Base64 encoded JSON (recommended for Vercel env var) # DS2API_CONFIG_JSON=eyJrZXlzIjpbInlvdXItYXBpLWtleSJdLCJhY2NvdW50cyI6W3siZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwicGFzc3dvcmQiOiJ4eHgiLCJ0b2tlbiI6IiJ9XX0= +# +# Generate from local config.json: +# DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" # --------------------------------------------------------------- # Paths (optional) diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index 4ed0cd0..67689cc 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -73,16 +73,6 @@ jobs: rm -rf "${STAGE}" done - (cd dist && sha256sum *.tar.gz *.zip > sha256sums.txt) - - - name: Upload Release Assets - uses: softprops/action-gh-release@v2 - with: - files: | - dist/*.tar.gz - dist/*.zip - dist/sha256sums.txt - - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -96,11 +86,19 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Extract Docker metadata - id: meta + id: meta_release uses: docker/metadata-action@v5 with: - images: ghcr.io/${{ github.repository }} + images: | + ghcr.io/${{ github.repository }} + cjackhwang/ds2api tags: | type=raw,value=${{ github.event.release.tag_name }} type=raw,value=latest @@ -112,5 +110,36 @@ jobs: file: ./Dockerfile push: true platforms: linux/amd64,linux/arm64 - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} + tags: ${{ steps.meta_release.outputs.tags }} + labels: ${{ steps.meta_release.outputs.labels }} + + - name: Export Docker image archives for release assets + run: | + set -euo pipefail + TAG="${{ 
github.event.release.tag_name }}" + + docker buildx build \ + --platform linux/amd64 \ + --output type=docker,dest="dist/ds2api_${TAG}_docker_linux_amd64.tar" \ + . + + docker buildx build \ + --platform linux/arm64 \ + --output type=docker,dest="dist/ds2api_${TAG}_docker_linux_arm64.tar" \ + . + + gzip -f "dist/ds2api_${TAG}_docker_linux_amd64.tar" + gzip -f "dist/ds2api_${TAG}_docker_linux_arm64.tar" + + - name: Generate checksums + run: | + set -euo pipefail + (cd dist && sha256sum *.tar.gz *.zip > sha256sums.txt) + + - name: Upload Release Assets + uses: softprops/action-gh-release@v2 + with: + files: | + dist/*.tar.gz + dist/*.zip + dist/sha256sums.txt diff --git a/.gitignore b/.gitignore index 5f776e2..422c203 100644 --- a/.gitignore +++ b/.gitignore @@ -81,6 +81,9 @@ ds2api-tests htmlcov/ .pytest_cache/ .tox/ +*.coverprofile +coverage*.out +cover/ # Misc *.pyc diff --git a/API.en.md b/API.en.md index e570dee..ef1a6f3 100644 --- a/API.en.md +++ b/API.en.md @@ -9,6 +9,7 @@ This document describes the actual behavior of the current Go codebase. ## Table of Contents - [Basics](#basics) +- [Configuration Best Practice](#configuration-best-practice) - [Authentication](#authentication) - [Route Index](#route-index) - [Health Endpoints](#health-endpoints) @@ -27,7 +28,29 @@ This document describes the actual behavior of the current Go codebase. 
| Base URL | `http://localhost:5001` or your deployment domain | | Default Content-Type | `application/json` | | Health probes | `GET /healthz`, `GET /readyz` | -| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`) | +| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) | + +--- + +## Configuration Best Practice + +Use `config.json` as the single source of truth: + +```bash +cp config.example.json config.json +# Edit config.json (keys/accounts) +``` + +Use it per deployment mode: + +- Local run: read `config.json` directly +- Docker / Vercel: generate Base64 from `config.json`, then set `DS2API_CONFIG_JSON` + +```bash +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" +``` + +For Vercel one-click bootstrap, you can set only `DS2API_ADMIN_KEY` first, then import config at `/admin` and sync env vars from the "Vercel Sync" page. --- @@ -66,7 +89,11 @@ Two header formats accepted: | GET | `/healthz` | None | Liveness probe | | GET | `/readyz` | None | Readiness probe | | GET | `/v1/models` | None | OpenAI model list | +| GET | `/v1/models/{id}` | None | OpenAI single-model query (alias accepted) | | POST | `/v1/chat/completions` | Business | OpenAI chat completions | +| POST | `/v1/responses` | Business | OpenAI Responses API (stream/non-stream) | +| GET | `/v1/responses/{response_id}` | Business | Query stored response (in-memory TTL) | +| POST | `/v1/embeddings` | Business | OpenAI Embeddings API | | GET | `/anthropic/v1/models` | None | Claude model list | | POST | `/anthropic/v1/messages` | Business | Claude messages | | POST | `/anthropic/v1/messages/count_tokens` | Business | Claude token counting | @@ -127,6 +154,15 @@ No auth required. Returns supported models. } ``` +### Model Alias Resolution + +For `chat` / `responses` / `embeddings`, DS2API follows a wide-input/strict-output policy: + +1. 
Match DeepSeek native model IDs first. +2. Then match exact keys in `model_aliases`. +3. If still unmatched, fall back by known family heuristics (`o*`, `gpt-*`, `claude-*`, etc.). +4. If still unmatched, return `invalid_request_error`. + ### `POST /v1/chat/completions` **Headers**: @@ -140,7 +176,7 @@ Content-Type: application/json | Field | Type | Required | Notes | | --- | --- | --- | --- | -| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` | +| `model` | string | ✅ | DeepSeek native models + common aliases (`gpt-4o`, `gpt-5-codex`, `o3`, `claude-sonnet-4-5`, etc.) | | `messages` | array | ✅ | OpenAI-style messages | | `stream` | boolean | ❌ | Default `false` | | `tools` | array | ❌ | Function calling schema | @@ -230,7 +266,63 @@ When `tools` is present, DS2API performs anti-leak handling: } ``` -**Stream**: DS2API buffers text first. If tool call detected → only structured `delta.tool_calls` (each with `index`); otherwise emits buffered text at once. +**Stream**: Once high-confidence toolcall features are matched, DS2API emits `delta.tool_calls` immediately (without waiting for full JSON closure), then keeps sending argument deltas; confirmed raw tool JSON is never forwarded as `delta.content`. + +--- + +### `GET /v1/models/{id}` + +No auth required. Alias values are accepted as path params (for example `gpt-4o`), and the returned object is the mapped DeepSeek model. + +### `POST /v1/responses` + +OpenAI Responses-style endpoint, accepting either `input` or `messages`. 
+ +| Field | Type | Required | Notes | +| --- | --- | --- | --- | +| `model` | string | ✅ | Supports native models + alias mapping | +| `input` | string/array/object | ❌ | One of `input` or `messages` is required | +| `messages` | array | ❌ | One of `input` or `messages` is required | +| `instructions` | string | ❌ | Prepended as a system message | +| `stream` | boolean | ❌ | Default `false` | +| `tools` | array | ❌ | Same tool detection/translation policy as chat | + +**Non-stream**: Returns a standard `response` object with an ID like `resp_xxx`, and stores it in in-memory TTL cache. + +**Stream (SSE)**: minimal event sequence: + +```text +event: response.created +data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...} + +event: response.output_text.delta +data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."} + +event: response.output_tool_call.delta +data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]} + +event: response.completed +data: {"type":"response.completed","response":{...}} + +data: [DONE] +``` + +### `GET /v1/responses/{response_id}` + +Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read). + +> Backed by in-memory TTL store. Default TTL is `900s` (configurable via `responses.store_ttl_seconds`). + +### `POST /v1/embeddings` + +Business auth required. Returns OpenAI-compatible embeddings shape. + +| Field | Type | Required | Notes | +| --- | --- | --- | --- | +| `model` | string | ✅ | Supports native models + alias mapping | +| `input` | string/array | ✅ | Supports string, string array, token array | + +> Requires `embeddings.provider`. Current supported values: `mock` / `deterministic` / `builtin`. If missing/unsupported, returns standard error shape with HTTP 501. --- @@ -249,7 +341,10 @@ No auth required. 
{"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"}, {"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"}, {"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"} - ] + ], + "first_id": "claude-opus-4-6", + "last_id": "claude-instant-1.0", + "has_more": false } ``` @@ -265,13 +360,15 @@ Content-Type: application/json anthropic-version: 2023-06-01 ``` +> `anthropic-version` is optional; DS2API auto-fills `2023-06-01` when absent. + **Request body**: | Field | Type | Required | Notes | | --- | --- | --- | --- | | `model` | string | ✅ | For example `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5` (compatible with `claude-3-5-haiku-latest`), plus historical Claude model IDs | | `messages` | array | ✅ | Claude-style messages | -| `max_tokens` | number | ❌ | Not strictly enforced by upstream bridge | +| `max_tokens` | number | ❌ | Auto-filled to `8192` when omitted; not strictly enforced by upstream bridge | | `stream` | boolean | ❌ | Default `false` | | `system` | string | ❌ | Optional system prompt | | `tools` | array | ❌ | Claude tool schema | @@ -416,6 +513,7 @@ Returns sanitized config. "keys": ["k1", "k2"], "accounts": [ { + "identifier": "user@example.com", "email": "user@example.com", "mobile": "", "has_password": true, @@ -476,6 +574,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`. { "items": [ { + "identifier": "user@example.com", "email": "user@example.com", "mobile": "", "has_password": true, @@ -500,7 +599,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`. ### `DELETE /admin/accounts/{identifier}` -`identifier` is email or mobile. +`identifier` can be email, mobile, or the synthetic id for token-only accounts (`token:`). **Response**: `{"success": true, "total_accounts": 5}` @@ -530,7 +629,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`. 
| Field | Required | Notes | | --- | --- | --- | -| `identifier` | ✅ | email or mobile | +| `identifier` | ✅ | email / mobile / token-only synthetic id | | `model` | ❌ | default `deepseek-chat` | | `message` | ❌ | if empty, only session creation is tested | @@ -659,13 +758,20 @@ Or manual deploy required: ## Error Payloads -Error formats vary by module: +Compatible routes (`/v1/*`, `/anthropic/*`) use the same error envelope: -| Module | Format | -| --- | --- | -| OpenAI routes | `{"error": {"message": "...", "type": "..."}}` | -| Claude routes | `{"error": {"type": "...", "message": "..."}}` | -| Admin routes | `{"detail": "..."}` | +```json +{ + "error": { + "message": "...", + "type": "invalid_request_error", + "code": "invalid_request", + "param": null + } +} +``` + +Admin routes keep `{"detail":"..."}`. Clients should handle HTTP status code plus `error` / `detail` fields. @@ -707,6 +813,31 @@ curl http://localhost:5001/v1/chat/completions \ }' ``` +### OpenAI Responses (Stream) + +```bash +curl http://localhost:5001/v1/responses \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-5-codex", + "input": "Write a hello world in golang", + "stream": true + }' +``` + +### OpenAI Embeddings + +```bash +curl http://localhost:5001/v1/embeddings \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4o", + "input": ["first text", "second text"] + }' +``` + ### OpenAI with Search ```bash diff --git a/API.md b/API.md index 6be7f65..3770924 100644 --- a/API.md +++ b/API.md @@ -9,6 +9,7 @@ ## 目录 - [基础信息](#基础信息) +- [配置最佳实践](#配置最佳实践) - [鉴权规则](#鉴权规则) - [路由总览](#路由总览) - [健康检查](#健康检查) @@ -27,7 +28,29 @@ | Base URL | `http://localhost:5001` 或你的部署域名 | | 默认 Content-Type | `application/json` | | 健康检查 | `GET /healthz`、`GET /readyz` | -| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`) | +| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 
`Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) | + +--- + +## 配置最佳实践 + +推荐把 `config.json` 作为唯一配置源: + +```bash +cp config.example.json config.json +# 编辑 config.json(keys/accounts) +``` + +按部署方式使用: + +- 本地运行:直接读取 `config.json` +- Docker / Vercel:从 `config.json` 生成 Base64,填入 `DS2API_CONFIG_JSON` + +```bash +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" +``` + +Vercel 一键部署可先只填 `DS2API_ADMIN_KEY`,部署后在 `/admin` 导入配置,再通过 “Vercel 同步” 写回环境变量。 --- @@ -66,7 +89,11 @@ | GET | `/healthz` | 无 | 存活探针 | | GET | `/readyz` | 无 | 就绪探针 | | GET | `/v1/models` | 无 | OpenAI 模型列表 | +| GET | `/v1/models/{id}` | 无 | OpenAI 单模型查询(支持 alias 入参) | | POST | `/v1/chat/completions` | 业务 | OpenAI 对话补全 | +| POST | `/v1/responses` | 业务 | OpenAI Responses 接口(流式/非流式) | +| GET | `/v1/responses/{response_id}` | 业务 | 查询已生成 response(内存 TTL) | +| POST | `/v1/embeddings` | 业务 | OpenAI Embeddings 接口 | | GET | `/anthropic/v1/models` | 无 | Claude 模型列表 | | POST | `/anthropic/v1/messages` | 业务 | Claude 消息接口 | | POST | `/anthropic/v1/messages/count_tokens` | 业务 | Claude token 计数 | @@ -127,6 +154,15 @@ } ``` +### 模型 alias 解析策略 + +对 `chat` / `responses` / `embeddings` 的 `model` 字段采用“宽进严出”: + +1. 先匹配 DeepSeek 原生模型。 +2. 再匹配 `model_aliases` 精确映射。 +3. 未命中时按模型家族规则回退(如 `o*`、`gpt-*`、`claude-*`)。 +4. 
仍未命中则返回 `invalid_request_error`。 + ### `POST /v1/chat/completions` **请求头**: @@ -140,7 +176,7 @@ Content-Type: application/json | 字段 | 类型 | 必填 | 说明 | | --- | --- | --- | --- | -| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` | +| `model` | string | ✅ | 支持 DeepSeek 原生模型 + 常见 alias(如 `gpt-4o`、`gpt-5-codex`、`o3`、`claude-sonnet-4-5`) | | `messages` | array | ✅ | OpenAI 风格消息数组 | | `stream` | boolean | ❌ | 默认 `false` | | `tools` | array | ❌ | Function Calling 定义 | @@ -230,7 +266,63 @@ data: [DONE] } ``` -**流式**:先缓冲正文片段。识别到工具调用 → 仅输出结构化 `delta.tool_calls`(每个 tool call 带 `index`);否则一次性输出普通文本。 +**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。 + +--- + +### `GET /v1/models/{id}` + +无需鉴权。入参支持 alias(例如 `gpt-4o`),返回的是映射后的 DeepSeek 模型对象。 + +### `POST /v1/responses` + +OpenAI Responses 风格接口,兼容 `input` 或 `messages`。 + +| 字段 | 类型 | 必填 | 说明 | +| --- | --- | --- | --- | +| `model` | string | ✅ | 支持原生模型 + alias 自动映射 | +| `input` | string/array/object | ❌ | 与 `messages` 二选一 | +| `messages` | array | ❌ | 与 `input` 二选一 | +| `instructions` | string | ❌ | 自动前置为 system 消息 | +| `stream` | boolean | ❌ | 默认 `false` | +| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 | + +**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。 + +**流式响应(SSE)**:最小事件序列如下。 + +```text +event: response.created +data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...} + +event: response.output_text.delta +data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."} + +event: response.output_tool_call.delta +data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]} + +event: response.completed +data: {"type":"response.completed","response":{...}} + +data: [DONE] +``` + +### `GET /v1/responses/{response_id}` + +需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)。 + +> 当前为内存 TTL 存储,默认过期时间 `900s`(可用 
`responses.store_ttl_seconds` 调整)。 + +### `POST /v1/embeddings` + +需要业务鉴权。返回 OpenAI Embeddings 兼容结构。 + +| 字段 | 类型 | 必填 | 说明 | +| --- | --- | --- | --- | +| `model` | string | ✅ | 支持原生模型 + alias 自动映射 | +| `input` | string/array | ✅ | 支持字符串、字符串数组、token 数组 | + +> 需配置 `embeddings.provider`。当前支持:`mock` / `deterministic` / `builtin`。未配置或不支持时返回标准错误结构(HTTP 501)。 --- @@ -249,7 +341,10 @@ data: [DONE] {"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"}, {"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"}, {"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"} - ] + ], + "first_id": "claude-opus-4-6", + "last_id": "claude-instant-1.0", + "has_more": false } ``` @@ -265,13 +360,15 @@ Content-Type: application/json anthropic-version: 2023-06-01 ``` +> `anthropic-version` 可省略,服务端会自动补为 `2023-06-01`。 + **请求体**: | 字段 | 类型 | 必填 | 说明 | | --- | --- | --- | --- | | `model` | string | ✅ | 例如 `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5`(兼容 `claude-3-5-haiku-latest`),并支持历史 Claude 模型 ID | | `messages` | array | ✅ | Claude 风格消息数组 | -| `max_tokens` | number | ❌ | 当前实现不会硬性截断上游输出 | +| `max_tokens` | number | ❌ | 缺省自动补 `8192`;当前实现不会硬性截断上游输出 | | `stream` | boolean | ❌ | 默认 `false` | | `system` | string | ❌ | 可选系统提示 | | `tools` | array | ❌ | Claude tool 定义 | @@ -416,6 +513,7 @@ data: {"type":"message_stop"} "keys": ["k1", "k2"], "accounts": [ { + "identifier": "user@example.com", "email": "user@example.com", "mobile": "", "has_password": true, @@ -476,6 +574,7 @@ data: {"type":"message_stop"} { "items": [ { + "identifier": "user@example.com", "email": "user@example.com", "mobile": "", "has_password": true, @@ -500,7 +599,7 @@ data: {"type":"message_stop"} ### `DELETE /admin/accounts/{identifier}` -`identifier` 为 email 或 mobile。 +`identifier` 可为 email、mobile,或 token-only 账号的合成标识(`token:`)。 **响应**:`{"success": true, "total_accounts": 5}` @@ -530,7 +629,7 @@ 
data: {"type":"message_stop"} | 字段 | 必填 | 说明 | | --- | --- | --- | -| `identifier` | ✅ | email 或 mobile | +| `identifier` | ✅ | email / mobile / token-only 合成标识 | | `model` | ❌ | 默认 `deepseek-chat` | | `message` | ❌ | 空字符串时仅测试会话创建 | @@ -659,13 +758,20 @@ data: {"type":"message_stop"} ## 错误响应格式 -不同模块的错误格式略有差异: +兼容路由(`/v1/*`、`/anthropic/*`)统一使用以下结构: -| 模块 | 格式 | -| --- | --- | -| OpenAI 接口 | `{"error": {"message": "...", "type": "..."}}` | -| Claude 接口 | `{"error": {"type": "...", "message": "..."}}` | -| Admin 接口 | `{"detail": "..."}` | +```json +{ + "error": { + "message": "...", + "type": "invalid_request_error", + "code": "invalid_request", + "param": null + } +} +``` + +Admin 接口保持 `{"detail":"..."}`。 建议客户端处理逻辑:检查 HTTP 状态码 + 解析 `error` 或 `detail` 字段。 @@ -707,6 +813,31 @@ curl http://localhost:5001/v1/chat/completions \ }' ``` +### OpenAI Responses(流式) + +```bash +curl http://localhost:5001/v1/responses \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-5-codex", + "input": "写一个 golang 的 hello world", + "stream": true + }' +``` + +### OpenAI Embeddings + +```bash +curl http://localhost:5001/v1/embeddings \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4o", + "input": ["第一段文本", "第二段文本"] + }' +``` + ### OpenAI 带搜索 ```bash diff --git a/DEPLOY.en.md b/DEPLOY.en.md index b7caf8c..8a62c98 100644 --- a/DEPLOY.en.md +++ b/DEPLOY.en.md @@ -33,6 +33,17 @@ Config source (choose one): - **File**: `config.json` (recommended for local/Docker) - **Environment variable**: `DS2API_CONFIG_JSON` (recommended for Vercel; supports raw JSON or Base64) +Unified recommendation (best practice): + +```bash +cp config.example.json config.json +# Edit config.json +``` + +Use `config.json` as the single source of truth: +- Local run: read `config.json` directly +- Docker / Vercel: generate `DS2API_CONFIG_JSON` (Base64) from `config.json` and inject it + --- ## 1. 
Local Run @@ -99,11 +110,15 @@ go build -o ds2api ./cmd/ds2api ### 2.1 Basic Steps ```bash -# Copy and edit environment +# Copy env template cp .env.example .env -# Edit .env, at minimum set: + +# Generate single-line Base64 from config.json +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" + +# Edit .env and set: # DS2API_ADMIN_KEY=your-admin-key -# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]} +# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON} # Start docker-compose up -d @@ -167,15 +182,49 @@ If container logs look normal but the admin panel is unreachable, check these fi 1. **Fork** the repo to your GitHub account 2. **Import** the project on Vercel -3. **Set environment variables** (at minimum): +3. **Set environment variables** (minimum required: one variable): | Variable | Description | | --- | --- | | `DS2API_ADMIN_KEY` | Admin key (required) | - | `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (required) | + | `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (optional, recommended) | 4. **Deploy** +### 3.1.1 Recommended Input (avoid `DS2API_CONFIG_JSON` mistakes) + +If you prefer faster one-click bootstrap, you can leave `DS2API_CONFIG_JSON` empty first, then open `/admin` after deployment, import config, and sync it back to Vercel env vars from the "Vercel Sync" page. + +Recommended: in repo root, copy the template first and fill your real accounts: + +```bash +cp config.example.json config.json +# Edit config.json +``` + +Do not hand-edit large JSON directly in Vercel. 
Generate Base64 locally and paste it: + +```bash +# Run in repo root +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" +echo "$DS2API_CONFIG_JSON" +``` + +If you choose to preconfigure before first deploy, set these vars in Vercel Project Settings -> Environment Variables: + +```text +DS2API_ADMIN_KEY=replace-with-a-strong-secret +DS2API_CONFIG_JSON= +``` + +Optional but recommended (for WebUI one-click Vercel sync): + +```text +VERCEL_TOKEN=your-vercel-token +VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx +VERCEL_TEAM_ID=team_xxxxxxxxxxxx # optional for personal accounts +``` + ### 3.2 Optional Environment Variables | Variable | Description | Default | diff --git a/DEPLOY.md b/DEPLOY.md index b7fbf9a..e5b0630 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -33,6 +33,17 @@ - **文件方式**:`config.json`(推荐本地/Docker 使用) - **环境变量方式**:`DS2API_CONFIG_JSON`(推荐 Vercel 使用,支持 JSON 字符串或 Base64 编码) +统一建议(最优实践): + +```bash +cp config.example.json config.json +# 编辑 config.json +``` + +建议把 `config.json` 作为唯一配置源: +- 本地运行:直接读 `config.json` +- Docker / Vercel:从 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量 + --- ## 一、本地运行 @@ -99,11 +110,15 @@ go build -o ds2api ./cmd/ds2api ### 2.1 基本步骤 ```bash -# 复制并编辑环境变量 +# 复制环境变量模板 cp .env.example .env -# 编辑 .env,至少设置: + +# 从 config.json 生成单行 Base64 +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" + +# 编辑 .env(请改成你的强密码),设置: # DS2API_ADMIN_KEY=your-admin-key -# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]} +# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON} # 启动 docker-compose up -d @@ -167,15 +182,49 @@ healthcheck: 1. **Fork 仓库**到你的 GitHub 账号 2. **在 Vercel 上导入项目** -3. **配置环境变量**(至少设置以下两项): +3. **配置环境变量**(最少只需设置以下一项): | 变量 | 说明 | | --- | --- | | `DS2API_ADMIN_KEY` | 管理密钥(必填) | - | `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(必填) | + | `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(可选,建议) | 4. 
**部署** +### 3.1.1 推荐填写方式(避免 `DS2API_CONFIG_JSON` 填错) + +如果你想先完成一键部署,也可以先不填 `DS2API_CONFIG_JSON`,部署后进入 `/admin` 导入配置,再在「Vercel 同步」里写回环境变量。 + +建议先在仓库目录复制示例配置,再按实际账号填写: + +```bash +cp config.example.json config.json +# 编辑 config.json +``` + +不要在 Vercel 面板里手写复杂 JSON,建议本地生成 Base64 后粘贴: + +```bash +# 在仓库根目录执行 +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" +echo "$DS2API_CONFIG_JSON" +``` + +如果你选择在部署前就预置配置,请在 Vercel Project Settings -> Environment Variables 配置: + +```text +DS2API_ADMIN_KEY=请替换为强密码 +DS2API_CONFIG_JSON=上一步生成的一整行 Base64 +``` + +可选但推荐(用于 WebUI 一键同步 Vercel 配置): + +```text +VERCEL_TOKEN=你的 Vercel Token +VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx +VERCEL_TEAM_ID=team_xxxxxxxxxxxx # 个人账号可留空 +``` + ### 3.2 可选环境变量 | 变量 | 说明 | 默认值 | diff --git a/README.MD b/README.MD index b438b75..261e34a 100644 --- a/README.MD +++ b/README.MD @@ -54,16 +54,27 @@ flowchart LR | 能力 | 说明 | | --- | --- | -| OpenAI 兼容 | `GET /v1/models`、`POST /v1/chat/completions`(流式/非流式) | +| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings` | | Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens` | | 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 | | 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 | | DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 | -| Tool Calling | 防泄漏处理:自动缓冲、识别、结构化输出 | +| Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 | | Admin API | 配置管理、账号测试 / 批量测试、导入导出、Vercel 同步 | | WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) | | 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) | +## 平台兼容矩阵 + +| 级别 | 平台 | 当前状态 | +| --- | --- | --- | +| P0 | Codex CLI/SDK(`wire_api=chat` / `wire_api=responses`) | ✅ | +| P0 | OpenAI SDK(JS/Python,chat + responses) | ✅ | +| P0 | Vercel AI SDK(openai-compatible) | ✅ | +| P0 | Anthropic SDK(messages) | ✅ | +| P1 | LangChain / LlamaIndex / OpenWebUI(OpenAI 兼容接入) | ✅ | +| P2 | MCP 独立桥接层 | 规划中 | + ## 模型支持 ### OpenAI 
接口 @@ -88,6 +99,19 @@ flowchart LR ## 快速开始 +### 通用第一步(所有部署方式) + +把 `config.json` 作为唯一配置源(推荐做法): + +```bash +cp config.example.json config.json +# 编辑 config.json +``` + +后续部署建议: +- 本地运行:直接读取 `config.json` +- Docker / Vercel:由 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量 + ### 方式一:本地运行 **前置要求**:Go 1.24+,Node.js 20+(仅在需要构建 WebUI 时) @@ -112,14 +136,20 @@ go run ./cmd/ds2api ### 方式二:Docker 运行 ```bash -# 1. 配置环境变量 +# 1. 准备环境变量文件 cp .env.example .env -# 编辑 .env -# 2. 启动 +# 2. 从 config.json 生成 DS2API_CONFIG_JSON(单行 Base64) +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" + +# 3. 编辑 .env,设置: +# DS2API_ADMIN_KEY=请替换为强密码 +# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON} + +# 4. 启动 docker-compose up -d -# 3. 查看日志 +# 5. 查看日志 docker-compose logs -f ``` @@ -129,9 +159,22 @@ docker-compose logs -f 1. Fork 仓库到自己的 GitHub 2. 在 Vercel 上导入项目 -3. 配置环境变量(至少设置 `DS2API_ADMIN_KEY` 和 `DS2API_CONFIG_JSON`) +3. 配置环境变量(最少设置 `DS2API_ADMIN_KEY`;推荐同时设置 `DS2API_CONFIG_JSON`) 4. 部署 +建议先在仓库目录复制模板并填写: + +```bash +cp config.example.json config.json +# 编辑 config.json +``` + +推荐:先本地把 `config.json` 转成 Base64,再粘贴到 `DS2API_CONFIG_JSON`,避免 JSON 格式错误: + +```bash +base64 < config.json | tr -d '\n' +``` + > **流式说明**:`/v1/chat/completions` 在 Vercel 上默认走 `api/chat-stream.js`(Node Runtime)以保证实时 SSE。鉴权、账号选择、会话/PoW 准备仍由 Go 内部 prepare 接口完成;流式响应(含 `tools`)在 Node 侧执行与 Go 对齐的输出组装与防泄漏处理。 详细部署说明请参阅 [部署指南](DEPLOY.md)。 @@ -164,6 +207,7 @@ cp opencode.json.example opencode.json 3. 
在项目目录启动 OpenCode CLI(按你的安装方式运行 `opencode`)。 > 建议优先使用 OpenAI 兼容路径(`/v1/*`),即示例里的 `@ai-sdk/openai-compatible` provider。 +> 若客户端支持 `wire_api`,可分别测试 `responses` 与 `chat`,DS2API 两条链路都兼容。 ## 配置说明 @@ -184,6 +228,24 @@ cp opencode.json.example opencode.json "token": "" } ], + "model_aliases": { + "gpt-4o": "deepseek-chat", + "gpt-5-codex": "deepseek-reasoner", + "o3": "deepseek-reasoner" + }, + "compat": { + "wide_input_strict_output": true + }, + "toolcall": { + "mode": "feature_match", + "early_emit_confidence": "high" + }, + "responses": { + "store_ttl_seconds": 900 + }, + "embeddings": { + "provider": "deterministic" + }, "claude_model_mapping": { "fast": "deepseek-chat", "slow": "deepseek-reasoner" @@ -194,6 +256,11 @@ cp opencode.json.example opencode.json - `keys`:API 访问密钥列表,客户端通过 `Authorization: Bearer ` 鉴权 - `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录 - `token`:留空则首次请求时自动登录获取;也可预填已有 token +- `model_aliases`:常见模型名(如 GPT/Codex/Claude)到 DeepSeek 模型的映射 +- `compat.wide_input_strict_output`:建议保持 `true`(当前实现默认宽进严出) +- `toolcall`:固定采用特征匹配 + 高置信早发策略 +- `responses.store_ttl_seconds`:`/v1/responses/{id}` 的内存缓存 TTL +- `embeddings.provider`:embedding 提供方(当前内置 `deterministic/mock/builtin`) - `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型 ### 环境变量 @@ -249,10 +316,10 @@ cp opencode.json.example opencode.json 当请求中带 `tools` 时,DS2API 会做防泄漏处理: -1. `stream=true` 时先**缓冲**正文片段 -2. 若识别到工具调用 → 仅输出结构化 `tool_calls`,不透传原始 JSON 文本 -3. 若最终不是工具调用 → 一次性输出普通文本 -4. 解析器支持混合文本、fenced JSON、`function.arguments` 字符串等格式 +1. 只在**非代码块上下文**启用 toolcall 特征识别(代码块示例不会触发) +2. 一旦命中高置信特征(`tool_calls` + `name` + `arguments/input` 起始)就立即输出 `delta.tool_calls` +3. 已确认的 toolcall JSON 片段不会泄漏到 `delta.content` +4. 
前文/后文自然语言保持顺序透传,支持混合文本与增量参数输出 ## 项目结构 diff --git a/README.en.md b/README.en.md index bbad73b..5d2f326 100644 --- a/README.en.md +++ b/README.en.md @@ -54,16 +54,27 @@ flowchart LR | Capability | Details | | --- | --- | -| OpenAI compatible | `GET /v1/models`, `POST /v1/chat/completions` (stream/non-stream) | +| OpenAI compatible | `GET /v1/models`, `GET /v1/models/{id}`, `POST /v1/chat/completions`, `POST /v1/responses`, `GET /v1/responses/{response_id}`, `POST /v1/embeddings` | | Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` | | Multi-account rotation | Auto token refresh, email/mobile dual login | | Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency | | DeepSeek PoW | WASM solving via `wazero`, no external Node.js dependency | -| Tool Calling | Anti-leak handling: auto buffer, detect, structured output | +| Tool Calling | Anti-leak handling: non-code-block feature match, early `delta.tool_calls`, structured incremental output | | Admin API | Config management, account testing/batch test, import/export, Vercel sync | | WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) | | Health Probes | `GET /healthz` (liveness), `GET /readyz` (readiness) | +## Platform Compatibility Matrix + +| Tier | Platform | Status | +| --- | --- | --- | +| P0 | Codex CLI/SDK (`wire_api=chat` / `wire_api=responses`) | ✅ | +| P0 | OpenAI SDK (JS/Python, chat + responses) | ✅ | +| P0 | Vercel AI SDK (openai-compatible) | ✅ | +| P0 | Anthropic SDK (messages) | ✅ | +| P1 | LangChain / LlamaIndex / OpenWebUI (OpenAI-compatible integration) | ✅ | +| P2 | MCP standalone bridge | Planned | + ## Model Support ### OpenAI Endpoint @@ -88,6 +99,19 @@ In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4 ## Quick Start +### Universal First Step (all deployment modes) + +Use `config.json` as the single source of truth (recommended): 
+ +```bash +cp config.example.json config.json +# Edit config.json +``` + +Recommended per deployment mode: +- Local run: read `config.json` directly +- Docker / Vercel: generate Base64 from `config.json` and inject as `DS2API_CONFIG_JSON` + ### Option 1: Local Run **Prerequisites**: Go 1.24+, Node.js 20+ (only if building WebUI locally) @@ -112,14 +136,20 @@ Default URL: `http://localhost:5001` ### Option 2: Docker ```bash -# 1. Configure environment +# 1. Prepare env file cp .env.example .env -# Edit .env -# 2. Start +# 2. Generate DS2API_CONFIG_JSON from config.json (single-line Base64) +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" + +# 3. Edit .env and set: +# DS2API_ADMIN_KEY=replace-with-a-strong-secret +# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON} + +# 4. Start docker-compose up -d -# 3. View logs +# 5. View logs docker-compose logs -f ``` @@ -129,9 +159,22 @@ Rebuild after updates: `docker-compose up -d --build` 1. Fork this repo to your GitHub account 2. Import the project on Vercel -3. Set environment variables (minimum: `DS2API_ADMIN_KEY` and `DS2API_CONFIG_JSON`) +3. Set environment variables (minimum: `DS2API_ADMIN_KEY`; recommended to also set `DS2API_CONFIG_JSON`) 4. Deploy +Recommended first step in repo root: + +```bash +cp config.example.json config.json +# Edit config.json +``` + +Recommended: convert `config.json` to Base64 locally, then paste into `DS2API_CONFIG_JSON` to avoid JSON formatting mistakes: + +```bash +base64 < config.json | tr -d '\n' +``` + > **Streaming note**: `/v1/chat/completions` on Vercel is routed to `api/chat-stream.js` (Node Runtime) for real-time SSE. Auth, account selection, and session/PoW preparation are still handled by the Go internal prepare endpoint; streaming output (including `tools`) is assembled on Node with Go-aligned anti-leak handling. For detailed deployment instructions, see the [Deployment Guide](DEPLOY.en.md). @@ -164,6 +207,7 @@ cp opencode.json.example opencode.json 3. 
Start OpenCode CLI in the project directory (run `opencode` using your installed method). > Recommended: use the OpenAI-compatible path (`/v1/*`) via `@ai-sdk/openai-compatible` as shown in the example. +> If your client supports `wire_api`, test both `responses` and `chat`; DS2API supports both paths. ## Configuration @@ -184,6 +228,24 @@ cp opencode.json.example opencode.json "token": "" } ], + "model_aliases": { + "gpt-4o": "deepseek-chat", + "gpt-5-codex": "deepseek-reasoner", + "o3": "deepseek-reasoner" + }, + "compat": { + "wide_input_strict_output": true + }, + "toolcall": { + "mode": "feature_match", + "early_emit_confidence": "high" + }, + "responses": { + "store_ttl_seconds": 900 + }, + "embeddings": { + "provider": "deterministic" + }, "claude_model_mapping": { "fast": "deepseek-chat", "slow": "deepseek-reasoner" @@ -194,6 +256,11 @@ cp opencode.json.example opencode.json - `keys`: API access keys; clients authenticate via `Authorization: Bearer ` - `accounts`: DeepSeek account list, supports `email` or `mobile` login - `token`: Leave empty for auto-login on first request; or pre-fill an existing token +- `model_aliases`: Map common model names (GPT/Codex/Claude) to DeepSeek models +- `compat.wide_input_strict_output`: Keep `true` (current default policy) +- `toolcall`: Fixed to feature matching + high-confidence early emit +- `responses.store_ttl_seconds`: In-memory TTL for `/v1/responses/{id}` +- `embeddings.provider`: Embeddings provider (`deterministic/mock/builtin` built-in) - `claude_model_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models ### Environment Variables @@ -249,10 +316,10 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency) When `tools` is present in the request, DS2API performs anti-leak handling: -1. With `stream=true`, DS2API **buffers** text deltas first -2. If a tool call is detected → only structured `tool_calls` are emitted, raw JSON is not leaked -3. 
If no tool call → buffered text is emitted at once -4. Parser supports mixed text, fenced JSON, and `function.arguments` payloads +1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored) +2. Once high-confidence features are matched (`tool_calls` + `name` + `arguments/input` start), `delta.tool_calls` is emitted immediately +3. Confirmed toolcall JSON fragments are never leaked into `delta.content` +4. Natural language before/after toolcalls keeps original order, with incremental argument output supported ## Project Structure diff --git a/api/chat-stream.js b/api/chat-stream.js index aa92b17..1a8e896 100644 --- a/api/chat-stream.js +++ b/api/chat-stream.js @@ -1,5 +1,7 @@ 'use strict'; +const crypto = require('crypto'); + const { extractToolNames, createToolSieveState, @@ -83,23 +85,57 @@ module.exports = async function handler(req, res) { const finalPrompt = asString(prep.body.final_prompt); const thinkingEnabled = toBool(prep.body.thinking_enabled); const searchEnabled = toBool(prep.body.search_enabled); - const toolNames = extractToolNames(payload.tools); + const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools); + const toolNames = toolPolicy.toolNames; if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) { writeOpenAIError(res, 500, 'invalid vercel prepare response'); return; } const releaseLease = createLeaseReleaser(req, leaseID); + const upstreamController = new AbortController(); + let clientClosed = false; + let reader = null; + const markClientClosed = () => { + if (clientClosed) { + return; + } + clientClosed = true; + upstreamController.abort(); + if (reader && typeof reader.cancel === 'function') { + Promise.resolve(reader.cancel()).catch(() => {}); + } + }; + const onReqAborted = () => markClientClosed(); + const onResClose = () => { + if (!res.writableEnded) { + markClientClosed(); + } + }; + req.on('aborted', onReqAborted); + res.on('close', onResClose); try { - 
const completionRes = await fetch(DEEPSEEK_COMPLETION_URL, { - method: 'POST', - headers: { - ...BASE_HEADERS, - authorization: `Bearer ${deepseekToken}`, - 'x-ds-pow-response': powHeader, - }, - body: JSON.stringify(completionPayload), - }); + let completionRes; + try { + completionRes = await fetch(DEEPSEEK_COMPLETION_URL, { + method: 'POST', + headers: { + ...BASE_HEADERS, + authorization: `Bearer ${deepseekToken}`, + 'x-ds-pow-response': powHeader, + }, + body: JSON.stringify(completionPayload), + signal: upstreamController.signal, + }); + } catch (err) { + if (clientClosed || isAbortError(err)) { + return; + } + throw err; + } + if (clientClosed) { + return; + } if (!completionRes.ok || !completionRes.body) { const detail = await safeReadText(completionRes); @@ -121,15 +157,20 @@ module.exports = async function handler(req, res) { let currentType = thinkingEnabled ? 'thinking' : 'text'; let thinkingText = ''; let outputText = ''; - const toolSieveEnabled = toolNames.length > 0; + const toolSieveEnabled = toolPolicy.toolSieveEnabled; + const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas; const toolSieveState = createToolSieveState(); let toolCallsEmitted = false; + const streamToolCallIDs = new Map(); const decoder = new TextDecoder(); - const reader = completionRes.body.getReader(); + reader = completionRes.body.getReader(); let buffered = ''; let ended = false; const sendFrame = (obj) => { + if (clientClosed || res.writableEnded || res.destroyed) { + return; + } res.write(`data: ${JSON.stringify(obj)}\n\n`); if (typeof res.flush === 'function') { res.flush(); @@ -156,6 +197,10 @@ module.exports = async function handler(req, res) { return; } ended = true; + if (clientClosed || res.writableEnded || res.destroyed) { + await releaseLease(); + return; + } const detected = parseToolCalls(outputText, toolNames); if (detected.length > 0 && !toolCallsEmitted) { toolCallsEmitted = true; @@ -179,14 +224,22 @@ module.exports = async function handler(req, res) { 
choices: [{ delta: {}, index: 0, finish_reason: reason }], usage: buildUsage(finalPrompt, thinkingText, outputText), }); - res.write('data: [DONE]\n\n'); + if (!res.writableEnded && !res.destroyed) { + res.write('data: [DONE]\n\n'); + } await releaseLease(); - res.end(); + if (!res.writableEnded && !res.destroyed) { + res.end(); + } }; try { // eslint-disable-next-line no-constant-condition while (true) { + if (clientClosed) { + await finish('stop'); + return; + } const { value, done } = await reader.read(); if (done) { break; @@ -245,6 +298,14 @@ module.exports = async function handler(req, res) { } const events = processToolSieveChunk(toolSieveState, p.text, toolNames); for (const evt of events) { + if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) { + if (!emitEarlyToolDeltas) { + continue; + } + toolCallsEmitted = true; + sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) }); + continue; + } if (evt.type === 'tool_calls') { toolCallsEmitted = true; sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) }); @@ -259,10 +320,16 @@ module.exports = async function handler(req, res) { } } await finish('stop'); - } catch (_err) { + } catch (err) { + if (clientClosed || isAbortError(err)) { + await finish('stop'); + return; + } await finish('stop'); } } finally { + req.removeListener('aborted', onReqAborted); + res.removeListener('close', onResClose); await releaseLease(); } }; @@ -345,6 +412,37 @@ function relayPreparedFailure(res, prep) { writeOpenAIError(res, prep.status || 500, 'vercel prepare failed'); } +function resolveToolcallPolicy(prepBody, payloadTools) { + const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names); + const toolNames = preparedToolNames.length > 0 ? 
preparedToolNames : extractToolNames(payloadTools); + const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match); + const emitEarlyToolDeltas = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high); + return { + toolNames, + toolSieveEnabled: toolNames.length > 0 && featureMatchEnabled, + emitEarlyToolDeltas, + }; +} + +function normalizePreparedToolNames(v) { + if (!Array.isArray(v) || v.length === 0) { + return []; + } + const out = []; + for (const item of v) { + const name = asString(item); + if (!name) { + continue; + } + out.push(name); + } + return out; +} + +function boolDefaultTrue(v) { + return v !== false; +} + async function safeReadText(resp) { if (!resp) { return ''; @@ -656,6 +754,55 @@ function buildUsage(prompt, thinking, output) { }; } +function formatIncrementalToolCallDeltas(deltas, idStore) { + if (!Array.isArray(deltas) || deltas.length === 0) { + return []; + } + const out = []; + for (const d of deltas) { + if (!d || typeof d !== 'object') { + continue; + } + const index = Number.isInteger(d.index) ? d.index : 0; + const id = ensureStreamToolCallID(idStore, index); + const item = { + index, + id, + type: 'function', + }; + const fn = {}; + if (asString(d.name)) { + fn.name = asString(d.name); + } + if (typeof d.arguments === 'string' && d.arguments !== '') { + fn.arguments = d.arguments; + } + if (Object.keys(fn).length > 0) { + item.function = fn; + } + out.push(item); + } + return out; +} + +function ensureStreamToolCallID(idStore, index) { + const key = Number.isInteger(index) ? 
index : 0; + const existing = idStore.get(key); + if (existing) { + return existing; + } + const next = `call_${newCallID()}`; + idStore.set(key, next); + return next; +} + +function newCallID() { + if (typeof crypto.randomUUID === 'function') { + return crypto.randomUUID().replace(/-/g, ''); + } + return `${Date.now()}${Math.floor(Math.random() * 1e9)}`; +} + function estimateTokens(text) { const t = asString(text); if (!t) { @@ -667,44 +814,92 @@ function estimateTokens(text) { async function proxyToGo(req, res, rawBody) { const url = buildInternalGoURL(req); - - const upstream = await fetch(url.toString(), { - method: 'POST', - headers: buildInternalGoHeaders(req, { withContentType: true }), - body: rawBody, - }); - - res.statusCode = upstream.status; - upstream.headers.forEach((value, key) => { - if (key.toLowerCase() === 'content-length') { + const controller = new AbortController(); + let clientClosed = false; + const markClientClosed = () => { + if (clientClosed) { return; } - res.setHeader(key, value); - }); + clientClosed = true; + controller.abort(); + }; + const onReqAborted = () => markClientClosed(); + const onResClose = () => { + if (!res.writableEnded) { + markClientClosed(); + } + }; + req.on('aborted', onReqAborted); + res.on('close', onResClose); - if (!upstream.body || typeof upstream.body.getReader !== 'function') { - const bytes = Buffer.from(await upstream.arrayBuffer()); - res.end(bytes); - return; - } - - const reader = upstream.body.getReader(); try { - // eslint-disable-next-line no-constant-condition - while (true) { - const { value, done } = await reader.read(); - if (done) { - break; + let upstream; + try { + upstream = await fetch(url.toString(), { + method: 'POST', + headers: buildInternalGoHeaders(req, { withContentType: true }), + body: rawBody, + signal: controller.signal, + }); + } catch (err) { + if (clientClosed || isAbortError(err)) { + if (!res.writableEnded) { + res.end(); + } + return; } - if (value && value.length > 0) { - 
res.write(Buffer.from(value)); - if (typeof res.flush === 'function') { - res.flush(); + throw err; + } + if (clientClosed) { + if (!res.writableEnded) { + res.end(); + } + return; + } + + res.statusCode = upstream.status; + upstream.headers.forEach((value, key) => { + if (key.toLowerCase() === 'content-length') { + return; + } + res.setHeader(key, value); + }); + + if (!upstream.body || typeof upstream.body.getReader !== 'function') { + const bytes = Buffer.from(await upstream.arrayBuffer()); + res.end(bytes); + return; + } + + const reader = upstream.body.getReader(); + try { + // eslint-disable-next-line no-constant-condition + while (true) { + if (clientClosed) { + break; + } + const { value, done } = await reader.read(); + if (done) { + break; + } + if (value && value.length > 0) { + res.write(Buffer.from(value)); + if (typeof res.flush === 'function') { + res.flush(); + } } } + if (!res.writableEnded) { + res.end(); + } + } catch (err) { + if (!isAbortError(err) && !res.writableEnded) { + res.end(); + } } - res.end(); - } catch (_err) { + } finally { + req.removeListener('aborted', onReqAborted); + res.removeListener('close', onResClose); if (!res.writableEnded) { res.end(); } @@ -762,9 +957,19 @@ function asString(v) { return String(v).trim(); } +function isAbortError(err) { + if (!err || typeof err !== 'object') { + return false; + } + return err.name === 'AbortError' || err.code === 'ABORT_ERR'; +} + module.exports.__test = { parseChunkForContent, extractContentRecursive, shouldSkipPath, asString, + resolveToolcallPolicy, + normalizePreparedToolNames, + boolDefaultTrue, }; diff --git a/api/chat-stream.test.js b/api/chat-stream.test.js index b347342..7424df2 100644 --- a/api/chat-stream.test.js +++ b/api/chat-stream.test.js @@ -10,10 +10,50 @@ const { flushToolSieve, } = require('./helpers/stream-tool-sieve'); -const { parseChunkForContent } = handler.__test; +const { + parseChunkForContent, + resolveToolcallPolicy, + normalizePreparedToolNames, + 
boolDefaultTrue, +} = handler.__test; test('chat-stream exposes parser test hooks', () => { assert.equal(typeof parseChunkForContent, 'function'); + assert.equal(typeof resolveToolcallPolicy, 'function'); +}); + +test('resolveToolcallPolicy defaults to feature-match + early emit when prepare flags missing', () => { + const policy = resolveToolcallPolicy( + {}, + [{ type: 'function', function: { name: 'read_file', parameters: { type: 'object' } } }], + ); + assert.deepEqual(policy.toolNames, ['read_file']); + assert.equal(policy.toolSieveEnabled, true); + assert.equal(policy.emitEarlyToolDeltas, true); +}); + +test('resolveToolcallPolicy respects prepare flags and prepared tool names', () => { + const policy = resolveToolcallPolicy( + { + tool_names: [' prepped_tool ', '', null], + toolcall_feature_match: false, + toolcall_early_emit_high: false, + }, + [{ type: 'function', function: { name: 'fallback_tool', parameters: { type: 'object' } } }], + ); + assert.deepEqual(policy.toolNames, ['prepped_tool']); + assert.equal(policy.toolSieveEnabled, false); + assert.equal(policy.emitEarlyToolDeltas, false); +}); + +test('normalizePreparedToolNames filters empty values', () => { + assert.deepEqual(normalizePreparedToolNames([' a ', '', null, 'b']), ['a', 'b']); +}); + +test('boolDefaultTrue keeps false only when explicitly false', () => { + assert.equal(boolDefaultTrue(false), false); + assert.equal(boolDefaultTrue(true), true); + assert.equal(boolDefaultTrue(undefined), true); }); test('parseChunkForContent keeps split response/content fragments inside response array', () => { @@ -49,12 +89,13 @@ test('parseChunkForContent + sieve does not leak suspicious prefix in split tool events.push(...flushToolSieve(state, ['read_file'])); const hasToolCalls = events.some((evt) => evt.type === 'tool_calls' && evt.calls && evt.calls.length > 0); + const hasToolDeltas = events.some((evt) => evt.type === 'tool_call_deltas' && evt.deltas && evt.deltas.length > 0); const leakedText = 
events .filter((evt) => evt.type === 'text' && evt.text) .map((evt) => evt.text) .join(''); - assert.equal(hasToolCalls, true); + assert.equal(hasToolCalls || hasToolDeltas, true); assert.equal(leakedText.includes('{'), false); assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); }); diff --git a/api/helpers/stream-tool-sieve.js b/api/helpers/stream-tool-sieve.js index 3ced63d..44e31cd 100644 --- a/api/helpers/stream-tool-sieve.js +++ b/api/helpers/stream-tool-sieve.js @@ -2,6 +2,8 @@ const crypto = require('crypto'); const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s; +const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024; +const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256; function extractToolNames(tools) { if (!Array.isArray(tools) || tools.length === 0) { @@ -26,9 +28,25 @@ function createToolSieveState() { pending: '', capture: '', capturing: false, + recentTextTail: '', + toolNameSent: false, + toolName: '', + toolArgsStart: -1, + toolArgsSent: -1, + toolArgsString: false, + toolArgsDone: false, }; } +function resetIncrementalToolState(state) { + state.toolNameSent = false; + state.toolName = ''; + state.toolArgsStart = -1; + state.toolArgsSent = -1; + state.toolArgsString = false; + state.toolArgsDone = false; +} + function processToolSieveChunk(state, chunk, toolNames) { if (!state) { return []; @@ -44,13 +62,27 @@ function processToolSieveChunk(state, chunk, toolNames) { state.capture += state.pending; state.pending = ''; } - const consumed = consumeToolCapture(state.capture, toolNames); + const deltas = buildIncrementalToolDeltas(state); + if (deltas.length > 0) { + events.push({ type: 'tool_call_deltas', deltas }); + } + const consumed = consumeToolCapture(state, toolNames); if (!consumed.ready) { + if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) { + noteText(state, state.capture); + events.push({ type: 'text', text: state.capture }); + state.capture = ''; + state.capturing = false; + resetIncrementalToolState(state); + continue; 
+ } break; } state.capture = ''; state.capturing = false; + resetIncrementalToolState(state); if (consumed.prefix) { + noteText(state, consumed.prefix); events.push({ type: 'text', text: consumed.prefix }); } if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { @@ -70,11 +102,13 @@ function processToolSieveChunk(state, chunk, toolNames) { if (start >= 0) { const prefix = state.pending.slice(0, start); if (prefix) { + noteText(state, prefix); events.push({ type: 'text', text: prefix }); } state.capture = state.pending.slice(start); state.pending = ''; state.capturing = true; + resetIncrementalToolState(state); continue; } @@ -83,6 +117,7 @@ function processToolSieveChunk(state, chunk, toolNames) { break; } state.pending = hold; + noteText(state, safe); events.push({ type: 'text', text: safe }); } return events; @@ -94,24 +129,29 @@ function flushToolSieve(state, toolNames) { } const events = processToolSieveChunk(state, '', toolNames); if (state.capturing) { - const consumed = consumeToolCapture(state.capture, toolNames); + const consumed = consumeToolCapture(state, toolNames); if (consumed.ready) { if (consumed.prefix) { + noteText(state, consumed.prefix); events.push({ type: 'text', text: consumed.prefix }); } if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { events.push({ type: 'tool_calls', calls: consumed.calls }); } if (consumed.suffix) { + noteText(state, consumed.suffix); events.push({ type: 'text', text: consumed.suffix }); } } else if (state.capture) { - // Incomplete captured tool JSON at stream end: suppress raw capture. 
+ noteText(state, state.capture); + events.push({ type: 'text', text: state.capture }); } state.capture = ''; state.capturing = false; + resetIncrementalToolState(state); } if (state.pending) { + noteText(state, state.pending); events.push({ type: 'text', text: state.pending }); state.pending = ''; } @@ -151,15 +191,25 @@ function findToolSegmentStart(s) { return -1; } const lower = s.toLowerCase(); - const keyIdx = lower.indexOf('tool_calls'); - if (keyIdx < 0) { - return -1; + let offset = 0; + // eslint-disable-next-line no-constant-condition + while (true) { + const keyRel = lower.indexOf('tool_calls', offset); + if (keyRel < 0) { + return -1; + } + const keyIdx = keyRel; + const start = s.slice(0, keyIdx).lastIndexOf('{'); + const candidateStart = start >= 0 ? start : keyIdx; + if (!insideCodeFence(s.slice(0, candidateStart))) { + return candidateStart; + } + offset = keyIdx + 'tool_calls'.length; } - const start = s.slice(0, keyIdx).lastIndexOf('{'); - return start >= 0 ? start : keyIdx; } -function consumeToolCapture(captured, toolNames) { +function consumeToolCapture(state, toolNames) { + const captured = state.capture; if (!captured) { return { ready: false, prefix: '', calls: [], suffix: '' }; } @@ -176,25 +226,367 @@ function consumeToolCapture(captured, toolNames) { if (!obj.ok) { return { ready: false, prefix: '', calls: [], suffix: '' }; } - const parsed = parseToolCalls(captured.slice(start, obj.end), toolNames); - if (parsed.length === 0) { - // `tool_calls` key exists but strict JSON parse failed. - // Drop the captured object body to avoid leaking raw tool JSON. 
+ const prefixPart = captured.slice(0, start); + const suffixPart = captured.slice(obj.end); + if (insideCodeFence((state.recentTextTail || '') + prefixPart)) { return { ready: true, - prefix: captured.slice(0, start), + prefix: captured, calls: [], - suffix: captured.slice(obj.end), + suffix: '', + }; + } + const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames); + if (parsed.length === 0) { + if (state.toolNameSent) { + return { + ready: true, + prefix: prefixPart, + calls: [], + suffix: suffixPart, + }; + } + return { + ready: true, + prefix: captured, + calls: [], + suffix: '', + }; + } + if (state.toolNameSent) { + if (parsed.length > 1) { + return { + ready: true, + prefix: prefixPart, + calls: parsed.slice(1), + suffix: suffixPart, + }; + } + return { + ready: true, + prefix: prefixPart, + calls: [], + suffix: suffixPart, }; } return { ready: true, - prefix: captured.slice(0, start), + prefix: prefixPart, calls: parsed, - suffix: captured.slice(obj.end), + suffix: suffixPart, }; } +function buildIncrementalToolDeltas(state) { + const captured = state.capture || ''; + if (!captured) { + return []; + } + if (looksLikeToolExampleContext(state.recentTextTail)) { + return []; + } + const lower = captured.toLowerCase(); + const keyIdx = lower.indexOf('tool_calls'); + if (keyIdx < 0) { + return []; + } + const start = captured.slice(0, keyIdx).lastIndexOf('{'); + if (start < 0) { + return []; + } + if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) { + return []; + } + const callStart = findFirstToolCallObjectStart(captured, keyIdx); + if (callStart < 0) { + return []; + } + + const deltas = []; + if (!state.toolName) { + const name = extractToolCallName(captured, callStart); + if (!name) { + return []; + } + state.toolName = name; + } + + if (state.toolArgsStart < 0) { + const args = findToolCallArgsStart(captured, callStart); + if (args) { + state.toolArgsString = Boolean(args.stringMode); + 
state.toolArgsStart = state.toolArgsString ? args.start + 1 : args.start; + state.toolArgsSent = state.toolArgsStart; + } + } + if (!state.toolNameSent) { + if (state.toolArgsStart < 0) { + return []; + } + state.toolNameSent = true; + deltas.push({ index: 0, name: state.toolName }); + } + if (state.toolArgsStart < 0 || state.toolArgsDone) { + return deltas; + } + const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString); + if (!progress) { + return deltas; + } + if (progress.end > state.toolArgsSent) { + deltas.push({ + index: 0, + arguments: captured.slice(state.toolArgsSent, progress.end), + }); + state.toolArgsSent = progress.end; + } + if (progress.complete) { + state.toolArgsDone = true; + } + return deltas; +} + +function findFirstToolCallObjectStart(text, keyIdx) { + const arrStart = findToolCallsArrayStart(text, keyIdx); + if (arrStart < 0) { + return -1; + } + const i = skipSpaces(text, arrStart + 1); + if (i >= text.length || text[i] !== '{') { + return -1; + } + return i; +} + +function findToolCallsArrayStart(text, keyIdx) { + let i = keyIdx + 'tool_calls'.length; + while (i < text.length && text[i] !== ':') { + i += 1; + } + if (i >= text.length) { + return -1; + } + i = skipSpaces(text, i + 1); + if (i >= text.length || text[i] !== '[') { + return -1; + } + return i; +} + +function extractToolCallName(text, callStart) { + let valueStart = findObjectFieldValueStart(text, callStart, ['name']); + if (valueStart < 0 || text[valueStart] !== '"') { + const fnStart = findFunctionObjectStart(text, callStart); + if (fnStart < 0) { + return ''; + } + valueStart = findObjectFieldValueStart(text, fnStart, ['name']); + if (valueStart < 0 || text[valueStart] !== '"') { + return ''; + } + } + const parsed = parseJSONStringLiteral(text, valueStart); + if (!parsed) { + return ''; + } + return parsed.value; +} + +function findToolCallArgsStart(text, callStart) { + const keys = ['input', 'arguments', 'args', 'parameters', 'params']; 
+ let valueStart = findObjectFieldValueStart(text, callStart, keys); + if (valueStart < 0) { + const fnStart = findFunctionObjectStart(text, callStart); + if (fnStart < 0) { + return null; + } + valueStart = findObjectFieldValueStart(text, fnStart, keys); + if (valueStart < 0) { + return null; + } + } + if (valueStart >= text.length) { + return null; + } + const ch = text[valueStart]; + if (ch === '{' || ch === '[') { + return { start: valueStart, stringMode: false }; + } + if (ch === '"') { + return { start: valueStart, stringMode: true }; + } + return null; +} + +function scanToolCallArgsProgress(text, start, stringMode) { + if (start < 0 || start > text.length) { + return null; + } + if (stringMode) { + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === '"') { + return { end: i, complete: true }; + } + } + return { end: text.length, complete: false }; + } + if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) { + return null; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + quote = ch; + continue; + } + if (ch === '{' || ch === '[') { + depth += 1; + continue; + } + if (ch === '}' || ch === ']') { + depth -= 1; + if (depth === 0) { + return { end: i + 1, complete: true }; + } + } + } + return { end: text.length, complete: false }; +} + +function findObjectFieldValueStart(text, objStart, keys) { + if (!text || objStart < 0 || objStart >= text.length || text[objStart] !== '{') { + return -1; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = objStart; i < text.length; 
i += 1) { + const ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + if (depth === 1) { + const parsed = parseJSONStringLiteral(text, i); + if (!parsed) { + return -1; + } + let j = skipSpaces(text, parsed.end); + if (j >= text.length || text[j] !== ':') { + i = parsed.end - 1; + continue; + } + j = skipSpaces(text, j + 1); + if (j >= text.length) { + return -1; + } + if (keys.includes(parsed.value)) { + return j; + } + i = j - 1; + continue; + } + quote = ch; + continue; + } + if (ch === '{') { + depth += 1; + continue; + } + if (ch === '}') { + depth -= 1; + if (depth === 0) { + break; + } + } + } + return -1; +} + +function findFunctionObjectStart(text, callStart) { + const valueStart = findObjectFieldValueStart(text, callStart, ['function']); + if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') { + return -1; + } + return valueStart; +} + +function parseJSONStringLiteral(text, start) { + if (!text || start < 0 || start >= text.length || text[start] !== '"') { + return null; + } + let out = ''; + let escaped = false; + for (let i = start + 1; i < text.length; i += 1) { + const ch = text[i]; + if (escaped) { + out += ch; + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === '"') { + return { value: out, end: i + 1 }; + } + out += ch; + } + return null; +} + +function skipSpaces(text, i) { + let idx = i; + while (idx < text.length) { + const ch = text[idx]; + if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r') { + idx += 1; + continue; + } + break; + } + return idx; +} + function extractJSONObjectFrom(text, start) { if (!text || start < 0 || start >= text.length || text[start] !== '{') { return { ok: false, end: 0 }; @@ -240,7 +632,11 @@ function parseToolCalls(text, toolNames) { if (!toStringSafe(text)) { 
return []; } - const candidates = buildToolCallCandidates(text); + const sanitized = stripFencedCodeBlocks(text); + if (!toStringSafe(sanitized)) { + return []; + } + const candidates = buildToolCallCandidates(sanitized); let parsed = []; for (const c of candidates) { parsed = parseToolCallsPayload(c); @@ -251,26 +647,49 @@ function parseToolCalls(text, toolNames) { if (parsed.length === 0) { return []; } - const allowed = new Set((toolNames || []).filter(Boolean)); - const out = []; - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - if (allowed.size > 0 && !allowed.has(tc.name)) { - continue; - } - out.push({ name: tc.name, input: tc.input || {} }); + return filterToolCalls(parsed, toolNames); +} + +function stripFencedCodeBlocks(text) { + const t = typeof text === 'string' ? text : ''; + if (!t) { + return ''; } - if (out.length === 0 && parsed.length > 0) { - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - out.push({ name: tc.name, input: tc.input || {} }); + return t.replace(/```[\s\S]*?```/g, ' '); +} + +function parseStandaloneToolCalls(text, toolNames) { + const trimmed = toStringSafe(text); + if (!trimmed) { + return []; + } + if ((trimmed.startsWith('```') && trimmed.endsWith('```')) || trimmed.includes('```')) { + return []; + } + if (looksLikeToolExampleContext(trimmed)) { + return []; + } + const candidates = [trimmed]; + if (trimmed.startsWith('```') && trimmed.endsWith('```')) { + const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); + if (m && m[1]) { + candidates.push(toStringSafe(m[1])); } } - return out; + for (const candidate of candidates) { + const c = toStringSafe(candidate); + if (!c) { + continue; + } + if (!c.startsWith('{') && !c.startsWith('[')) { + continue; + } + const parsed = parseToolCallsPayload(c); + if (parsed.length > 0) { + return filterToolCalls(parsed, toolNames); + } + } + return []; } function buildToolCallCandidates(text) { @@ -432,6 +851,66 @@ function parseToolCallInput(v) 
{ return {}; } +function filterToolCalls(parsed, toolNames) { + const allowed = new Set((toolNames || []).filter(Boolean)); + const out = []; + for (const tc of parsed) { + if (!tc || !tc.name) { + continue; + } + if (allowed.size > 0 && !allowed.has(tc.name)) { + continue; + } + out.push({ name: tc.name, input: tc.input || {} }); + } + if (out.length === 0 && parsed.length > 0) { + for (const tc of parsed) { + if (!tc || !tc.name) { + continue; + } + out.push({ name: tc.name, input: tc.input || {} }); + } + } + return out; +} + +function noteText(state, text) { + if (!state || !hasMeaningfulText(text)) { + return; + } + state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT); +} + +function appendTail(prev, next, max) { + const left = typeof prev === 'string' ? prev : ''; + const right = typeof next === 'string' ? next : ''; + if (!Number.isFinite(max) || max <= 0) { + return ''; + } + const combined = left + right; + if (combined.length <= max) { + return combined; + } + return combined.slice(combined.length - max); +} + +function looksLikeToolExampleContext(text) { + return insideCodeFence(text); +} + +function insideCodeFence(text) { + const t = typeof text === 'string' ? 
text : ''; + if (!t) { + return false; + } + const ticks = (t.match(/```/g) || []).length; + return ticks % 2 === 1; +} + +function hasMeaningfulText(text) { + return toStringSafe(text) !== ''; +} + function formatOpenAIStreamToolCalls(calls) { if (!Array.isArray(calls) || calls.length === 0) { return []; @@ -473,5 +952,6 @@ module.exports = { processToolSieveChunk, flushToolSieve, parseToolCalls, + parseStandaloneToolCalls, formatOpenAIStreamToolCalls, }; diff --git a/api/helpers/stream-tool-sieve.test.js b/api/helpers/stream-tool-sieve.test.js index 47b3100..7f532f1 100644 --- a/api/helpers/stream-tool-sieve.test.js +++ b/api/helpers/stream-tool-sieve.test.js @@ -9,6 +9,7 @@ const { processToolSieveChunk, flushToolSieve, parseToolCalls, + parseStandaloneToolCalls, } = require('./stream-tool-sieve'); function runSieve(chunks, toolNames) { @@ -68,9 +69,22 @@ test('parseToolCalls supports fenced json and function.arguments string payload' '```', ].join('\n'); const calls = parseToolCalls(text, ['read_file']); - assert.equal(calls.length, 1); - assert.equal(calls[0].name, 'read_file'); - assert.deepEqual(calls[0].input, { path: 'README.md' }); + assert.equal(calls.length, 0); +}); + +test('parseStandaloneToolCalls only matches standalone payload and ignores mixed prose', () => { + const mixed = '这里是示例:{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]},请勿执行。'; + const standalone = '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'; + const mixedCalls = parseStandaloneToolCalls(mixed, ['read_file']); + const standaloneCalls = parseStandaloneToolCalls(standalone, ['read_file']); + assert.equal(mixedCalls.length, 0); + assert.equal(standaloneCalls.length, 1); +}); + +test('parseStandaloneToolCalls ignores fenced code block tool_call examples', () => { + const fenced = ['```json', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}', '```'].join('\n'); + const calls = parseStandaloneToolCalls(fenced, ['read_file']); + 
assert.equal(calls.length, 0); }); test('sieve emits tool_calls and does not leak suspicious prefix on late key convergence', () => { @@ -84,13 +98,14 @@ test('sieve emits tool_calls and does not leak suspicious prefix on late key con ); const leakedText = collectText(events); const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0); - assert.equal(hasToolCall, true); + const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0); + assert.equal(hasToolCall || hasToolDelta, true); assert.equal(leakedText.includes('{'), false); assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); assert.equal(leakedText.includes('后置正文C。'), true); }); -test('sieve drops invalid tool json body while preserving surrounding text', () => { +test('sieve keeps embedded invalid tool-like json as normal text to avoid stream stalls', () => { const events = runSieve( [ '前置正文D。', @@ -104,18 +119,18 @@ test('sieve drops invalid tool json body while preserving surrounding text', () assert.equal(hasToolCall, false); assert.equal(leakedText.includes('前置正文D。'), true); assert.equal(leakedText.includes('后置正文E。'), true); - assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), true); }); -test('sieve suppresses incomplete captured tool json on stream finalize', () => { +test('sieve flushes incomplete captured tool json as text on stream finalize', () => { const events = runSieve( ['前置正文F。', '{"tool_calls":[{"name":"read_file"'], ['read_file'], ); const leakedText = collectText(events); assert.equal(leakedText.includes('前置正文F。'), true); - assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); - assert.equal(leakedText.includes('{'), false); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), true); + assert.equal(leakedText.includes('{'), true); }); 
test('sieve keeps plain text intact in tool mode when no tool call appears', () => { @@ -128,3 +143,53 @@ test('sieve keeps plain text intact in tool mode when no tool call appears', () assert.equal(hasToolCall, false); assert.equal(leakedText, '你好,这是普通文本回复。请继续。'); }); + +test('sieve emits incremental tool_call_deltas for split arguments payload', () => { + const state = createToolSieveState(); + const first = processToolSieveChunk( + state, + '{"tool_calls":[{"name":"read_file","input":{"path":"READ', + ['read_file'], + ); + const second = processToolSieveChunk( + state, + 'ME.MD","mode":"head"}}]}', + ['read_file'], + ); + const tail = flushToolSieve(state, ['read_file']); + const events = [...first, ...second, ...tail]; + const deltaEvents = events.filter((evt) => evt.type === 'tool_call_deltas'); + assert.equal(deltaEvents.length > 0, true); + const merged = deltaEvents.flatMap((evt) => evt.deltas || []); + const hasName = merged.some((d) => d.name === 'read_file'); + const argsJoined = merged + .map((d) => d.arguments || '') + .join(''); + assert.equal(hasName, true); + assert.equal(argsJoined.includes('"path":"README.MD"'), true); + assert.equal(argsJoined.includes('"mode":"head"'), true); +}); + +test('sieve still intercepts tool call after leading plain text without suffix', () => { + const events = runSieve( + ['我将调用工具。', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'], + ['read_file'], + ); + const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0)); + const leakedText = collectText(events); + assert.equal(hasTool, true); + assert.equal(leakedText.includes('我将调用工具。'), true); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); +}); + +test('sieve intercepts tool call and preserves trailing same-chunk text', () => { + const events = runSieve( + ['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}然后继续解释。'], 
+ ['read_file'], + ); + const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0)); + const leakedText = collectText(events); + assert.equal(hasTool, true); + assert.equal(leakedText.includes('然后继续解释。'), true); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); +}); diff --git a/config.example.json b/config.example.json index 7614e77..97161f7 100644 --- a/config.example.json +++ b/config.example.json @@ -24,5 +24,27 @@ "password": "your-password-3", "token": "" } - ] -} \ No newline at end of file + ], + "model_aliases": { + "gpt-4o": "deepseek-chat", + "gpt-5-codex": "deepseek-reasoner", + "o3": "deepseek-reasoner" + }, + "compat": { + "wide_input_strict_output": true + }, + "toolcall": { + "mode": "feature_match", + "early_emit_confidence": "high" + }, + "responses": { + "store_ttl_seconds": 900 + }, + "embeddings": { + "provider": "deterministic" + }, + "claude_model_mapping": { + "fast": "deepseek-chat", + "slow": "deepseek-reasoner" + } +} diff --git a/internal/account/pool_edge_test.go b/internal/account/pool_edge_test.go new file mode 100644 index 0000000..6e90823 --- /dev/null +++ b/internal/account/pool_edge_test.go @@ -0,0 +1,249 @@ +package account + +import ( + "context" + "sync" + "testing" + "time" + + "ds2api/internal/config" +) + +// ─── Pool edge cases ───────────────────────────────────────────────── + +func TestPoolEmptyNoAccounts(t *testing.T) { + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "2") + t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "") + t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "") + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + pool := NewPool(config.LoadStore()) + if _, ok := pool.Acquire("", nil); ok { + t.Fatal("expected acquire to fail with no accounts") + } + status := pool.Status() + if total, ok := status["total"].(int); !ok || total != 0 { + t.Fatalf("unexpected 
total: %#v", status["total"]) + } +} + +func TestPoolReleaseNonExistentAccount(t *testing.T) { + pool := newPoolForTest(t, "2") + pool.Release("nonexistent@example.com") // should not panic +} + +func TestPoolReleaseAlreadyReleased(t *testing.T) { + pool := newPoolForTest(t, "2") + acc, ok := pool.Acquire("", nil) + if !ok { + t.Fatal("expected acquire success") + } + pool.Release(acc.Identifier()) + pool.Release(acc.Identifier()) // double release should not panic +} + +func TestPoolAcquireTargetNotFound(t *testing.T) { + pool := newPoolForTest(t, "2") + if _, ok := pool.Acquire("nonexistent@example.com", nil); ok { + t.Fatal("expected acquire to fail for non-existent target") + } +} + +func TestPoolAcquireWithExclusionList(t *testing.T) { + pool := newPoolForTest(t, "2") + acc, ok := pool.Acquire("", map[string]bool{"acc1@example.com": true}) + if !ok { + t.Fatal("expected acquire success with exclusion") + } + if acc.Identifier() != "acc2@example.com" { + t.Fatalf("expected acc2 when acc1 excluded, got %q", acc.Identifier()) + } + pool.Release(acc.Identifier()) +} + +func TestPoolAcquireAllExcluded(t *testing.T) { + pool := newPoolForTest(t, "2") + if _, ok := pool.Acquire("", map[string]bool{ + "acc1@example.com": true, + "acc2@example.com": true, + }); ok { + t.Fatal("expected acquire to fail when all accounts excluded") + } +} + +func TestPoolStatusFields(t *testing.T) { + pool := newPoolForTest(t, "2") + status := pool.Status() + + // Check all expected fields are present + for _, key := range []string{"total", "available", "max_inflight_per_account", "recommended_concurrency", "available_accounts", "in_use_accounts", "waiting", "max_queue_size"} { + if _, ok := status[key]; !ok { + t.Fatalf("missing status field: %s", key) + } + } +} + +func TestPoolStatusAccountDetails(t *testing.T) { + pool := newPoolForTest(t, "2") + acc, _ := pool.Acquire("acc1@example.com", nil) + + status := pool.Status() + inUseAccounts, ok := status["in_use_accounts"].([]string) + 
if !ok { + t.Fatalf("unexpected in_use_accounts type: %T", status["in_use_accounts"]) + } + found := false + for _, id := range inUseAccounts { + if id == "acc1@example.com" { + found = true + break + } + } + if !found { + t.Fatalf("expected acc1 in in_use_accounts, got %v", inUseAccounts) + } + if status["in_use"] != 1 { + t.Fatalf("expected 1 in_use, got %v", status["in_use"]) + } + + pool.Release(acc.Identifier()) +} + +func TestPoolAcquireWaitContextCancelled(t *testing.T) { + pool := newSingleAccountPoolForTest(t, "1") + // Exhaust the pool + first, ok := pool.Acquire("", nil) + if !ok { + t.Fatal("expected first acquire to succeed") + } + + ctx, cancel := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + wg.Add(1) + var waitOK bool + go func() { + defer wg.Done() + _, waitOK = pool.AcquireWait(ctx, "", nil) + }() + + // Wait until queued + waitForWaitingCount(t, pool, 1) + + // Cancel context + cancel() + + wg.Wait() + if waitOK { + t.Fatal("expected acquire to fail after context cancellation") + } + + pool.Release(first.Identifier()) +} + +func TestPoolAcquireWaitTargetAccount(t *testing.T) { + pool := newPoolForTest(t, "1") + // Exhaust acc1 + acc1, ok := pool.Acquire("acc1@example.com", nil) + if !ok { + t.Fatal("expected acquire acc1 success") + } + + // Acquire acc2 directly (should succeed since acc2 is free) + ctx := context.Background() + acc2, ok := pool.AcquireWait(ctx, "acc2@example.com", nil) + if !ok { + t.Fatal("expected acquire acc2 success via AcquireWait") + } + if acc2.Identifier() != "acc2@example.com" { + t.Fatalf("expected acc2, got %q", acc2.Identifier()) + } + + pool.Release(acc1.Identifier()) + pool.Release(acc2.Identifier()) +} + +func TestPoolMaxQueueSizeOverride(t *testing.T) { + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1") + t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "5") + t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "") + t.Setenv("DS2API_CONFIG_JSON", 
`{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`) + pool := NewPool(config.LoadStore()) + status := pool.Status() + if got, ok := status["max_queue_size"].(int); !ok || got != 5 { + t.Fatalf("expected max_queue_size=5, got %#v", status["max_queue_size"]) + } +} + +func TestPoolQueueSizeAliasEnv(t *testing.T) { + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1") + t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "") + t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "7") + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`) + pool := NewPool(config.LoadStore()) + status := pool.Status() + if got, ok := status["max_queue_size"].(int); !ok || got != 7 { + t.Fatalf("expected max_queue_size=7, got %#v", status["max_queue_size"]) + } +} + +func TestPoolMultipleAcquireReleaseCycles(t *testing.T) { + pool := newSingleAccountPoolForTest(t, "1") + for i := 0; i < 10; i++ { + acc, ok := pool.Acquire("", nil) + if !ok { + t.Fatalf("acquire failed at cycle %d", i) + } + pool.Release(acc.Identifier()) + } +} + +func TestPoolConcurrentAcquireWait(t *testing.T) { + pool := newSingleAccountPoolForTest(t, "1") + first, ok := pool.Acquire("", nil) + if !ok { + t.Fatal("expected first acquire success") + } + + const waiters = 3 + results := make(chan bool, waiters) + + for i := 0; i < waiters; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, ok := pool.AcquireWait(ctx, "", nil) + results <- ok + }() + } + + // Wait for all to be queued (only 1 can queue) + time.Sleep(50 * time.Millisecond) + + // Release and allow queued requests to proceed + pool.Release(first.Identifier()) + + successCount := 0 + timeoutCount := 0 + for i := 0; i < waiters; i++ { + select { + case ok := <-results: + if ok { + successCount++ + // Release for next waiter + pool.Release("acc1@example.com") + } else { + timeoutCount++ + } + case <-time.After(3 * 
time.Second): + t.Fatal("timed out waiting for results") + } + } + + // At least 1 should succeed; 2 may fail due to queue limit + if successCount < 1 { + t.Fatalf("expected at least 1 success, got success=%d timeout=%d", successCount, timeoutCount) + } +} diff --git a/internal/adapter/claude/error_shape_test.go b/internal/adapter/claude/error_shape_test.go new file mode 100644 index 0000000..910fce8 --- /dev/null +++ b/internal/adapter/claude/error_shape_test.go @@ -0,0 +1,35 @@ +package claude + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) { + rec := httptest.NewRecorder() + writeClaudeError(rec, http.StatusUnauthorized, "bad token") + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", rec.Code) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode body: %v", err) + } + errObj, _ := body["error"].(map[string]any) + if errObj["message"] != "bad token" { + t.Fatalf("unexpected message: %v", errObj["message"]) + } + if errObj["type"] != "invalid_request_error" { + t.Fatalf("unexpected type: %v", errObj["type"]) + } + if errObj["code"] != "authentication_failed" { + t.Fatalf("unexpected code: %v", errObj["code"]) + } + if _, ok := errObj["param"]; !ok { + t.Fatal("expected param field") + } +} + diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index b9ecd27..bac315f 100644 --- a/internal/adapter/claude/handler.go +++ b/internal/adapter/claude/handler.go @@ -43,6 +43,9 @@ func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { } func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { + if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" { + r.Header.Set("anthropic-version", "2023-06-01") + } a, err := h.Auth.Determine(r) if err != nil { status := http.StatusUnauthorized @@ -50,132 +53,79 @@ func (h *Handler) 
Messages(w http.ResponseWriter, r *http.Request) { if err == auth.ErrNoAccount { status = http.StatusTooManyRequests } - writeJSON(w, status, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": detail}}) + writeClaudeError(w, status, detail) return } defer h.Auth.Release(a) var req map[string]any if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "invalid json"}}) + writeClaudeError(w, http.StatusBadRequest, "invalid json") return } - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == "" || len(messagesRaw) == 0 { - writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "Request must include 'model' and 'messages'."}}) + norm, err := normalizeClaudeRequest(h.Store, req) + if err != nil { + writeClaudeError(w, http.StatusBadRequest, err.Error()) return } - - normalized := normalizeClaudeMessages(messagesRaw) - payload := cloneMap(req) - payload["messages"] = normalized - toolsRequested, _ := req["tools"].([]any) - if len(toolsRequested) > 0 && !hasSystemMessage(normalized) { - payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...) 
- } - - dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store) - dsModel, _ := dsPayload["model"].(string) - thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel) - if !ok { - thinkingEnabled = false - searchEnabled = false - } - finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"])) + stdReq := norm.Standard sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { - writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "invalid token."}}) + writeClaudeError(w, http.StatusUnauthorized, "invalid token.") return } pow, err := h.DS.GetPow(r.Context(), a, 3) if err != nil { - writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get PoW"}}) + writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW") return } - requestPayload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } + requestPayload := stdReq.CompletionPayload(sessionID) resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3) if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get Claude response."}}) + writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.") return } if resp.StatusCode != http.StatusOK { defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) - writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}}) + writeClaudeError(w, http.StatusInternalServerError, string(body)) return } - toolNames := extractClaudeToolNames(toolsRequested) - if util.ToBool(req["stream"]) { - h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, 
searchEnabled, toolNames) + if stdReq.Stream { + h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) return } - result := sse.CollectStream(resp, thinkingEnabled, true) - fullText := result.Text - fullThinking := result.Thinking - detected := util.ParseToolCalls(fullText, toolNames) - content := make([]map[string]any, 0, 4) - if fullThinking != "" { - content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking}) - } - stopReason := "end_turn" - if len(detected) > 0 { - stopReason = "tool_use" - for i, tc := range detected { - content = append(content, map[string]any{ - "type": "tool_use", - "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), - "name": tc.Name, - "input": tc.Input, - }) - } - } else { - if fullText == "" { - fullText = "抱歉,没有生成有效的响应内容。" - } - content = append(content, map[string]any{"type": "text", "text": fullText}) - } - writeJSON(w, http.StatusOK, map[string]any{ - "id": fmt.Sprintf("msg_%d", time.Now().UnixNano()), - "type": "message", - "role": "assistant", - "model": model, - "content": content, - "stop_reason": stopReason, - "stop_sequence": nil, - "usage": map[string]any{ - "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)), - "output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText), - }, - }) + result := sse.CollectStream(resp, stdReq.Thinking, true) + respBody := util.BuildClaudeMessageResponse( + fmt.Sprintf("msg_%d", time.Now().UnixNano()), + stdReq.ResponseModel, + norm.NormalizedMessages, + result.Thinking, + result.Text, + stdReq.ToolNames, + ) + writeJSON(w, http.StatusOK, respBody) } func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { a, err := h.Auth.Determine(r) if err != nil { - writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()}) + writeClaudeError(w, http.StatusUnauthorized, err.Error()) return } defer h.Auth.Release(a) var req 
map[string]any if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"}) + writeClaudeError(w, http.StatusBadRequest, "invalid json") return } model, _ := req["model"].(string) messages, _ := req["messages"].([]any) if model == "" || len(messages) == 0 { - writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."}) + writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") return } inputTokens := 0 @@ -206,7 +156,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}}) + writeClaudeError(w, http.StatusInternalServerError, string(body)) return } @@ -241,6 +191,8 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ "error": map[string]any{ "type": "api_error", "message": msg, + "code": "internal_error", + "param": nil, }, }) } @@ -492,6 +444,28 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ } } +func writeClaudeError(w http.ResponseWriter, status int, message string) { + code := "invalid_request" + switch status { + case http.StatusUnauthorized: + code = "authentication_failed" + case http.StatusTooManyRequests: + code = "rate_limit_exceeded" + case http.StatusNotFound: + code = "not_found" + case http.StatusInternalServerError: + code = "internal_error" + } + writeJSON(w, status, map[string]any{ + "error": map[string]any{ + "type": "invalid_request_error", + "message": message, + "code": code, + "param": nil, + }, + }) +} + func normalizeClaudeMessages(messages []any) []any { out := make([]any, 0, len(messages)) for _, m := range messages { diff --git 
a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go new file mode 100644 index 0000000..73d2fab --- /dev/null +++ b/internal/adapter/claude/handler_util_test.go @@ -0,0 +1,348 @@ +package claude + +import ( + "testing" +) + +// ─── normalizeClaudeMessages ───────────────────────────────────────── + +func TestNormalizeClaudeMessagesSimpleString(t *testing.T) { + msgs := []any{ + map[string]any{"role": "user", "content": "Hello"}, + } + got := normalizeClaudeMessages(msgs) + if len(got) != 1 { + t.Fatalf("expected 1 message, got %d", len(got)) + } + m := got[0].(map[string]any) + if m["content"] != "Hello" { + t.Fatalf("expected 'Hello', got %v", m["content"]) + } +} + +func TestNormalizeClaudeMessagesArrayContent(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "line1"}, + map[string]any{"type": "text", "text": "line2"}, + }, + }, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["content"] != "line1\nline2" { + t.Fatalf("expected joined text, got %q", m["content"]) + } +} + +func TestNormalizeClaudeMessagesToolResult(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "tool_result", "content": "tool output"}, + }, + }, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["content"] != "tool output" { + t.Fatalf("expected 'tool output', got %q", m["content"]) + } +} + +func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) { + msgs := []any{"not a map", 42} + got := normalizeClaudeMessages(msgs) + if len(got) != 0 { + t.Fatalf("expected 0 messages for non-map items, got %d", len(got)) + } +} + +func TestNormalizeClaudeMessagesEmpty(t *testing.T) { + got := normalizeClaudeMessages(nil) + if len(got) != 0 { + t.Fatalf("expected 0, got %d", len(got)) + } +} + +func TestNormalizeClaudeMessagesPreservesRole(t *testing.T) { + 
msgs := []any{ + map[string]any{"role": "assistant", "content": "response"}, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["role"] != "assistant" { + t.Fatalf("expected 'assistant', got %q", m["role"]) + } +} + +func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "Hello"}, + map[string]any{"type": "image", "source": "data:..."}, + map[string]any{"type": "text", "text": "World"}, + }, + }, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["content"] != "Hello\nWorld" { + t.Fatalf("expected only text parts joined, got %q", m["content"]) + } +} + +// ─── buildClaudeToolPrompt ─────────────────────────────────────────── + +func TestBuildClaudeToolPromptSingleTool(t *testing.T) { + tools := []any{ + map[string]any{ + "name": "search", + "description": "Search the web", + "input_schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + }, + }, + }, + } + prompt := buildClaudeToolPrompt(tools) + if prompt == "" { + t.Fatal("expected non-empty prompt") + } + // Should contain tool name and description + if !containsStr(prompt, "search") { + t.Fatalf("expected 'search' in prompt") + } + if !containsStr(prompt, "Search the web") { + t.Fatalf("expected description in prompt") + } + if !containsStr(prompt, "tool_calls") { + t.Fatalf("expected tool_calls instruction in prompt") + } +} + +func TestBuildClaudeToolPromptMultipleTools(t *testing.T) { + tools := []any{ + map[string]any{"name": "tool1", "description": "desc1"}, + map[string]any{"name": "tool2", "description": "desc2"}, + } + prompt := buildClaudeToolPrompt(tools) + if !containsStr(prompt, "tool1") || !containsStr(prompt, "tool2") { + t.Fatalf("expected both tools in prompt") + } +} + +func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) { + tools := 
[]any{"not a map"} + prompt := buildClaudeToolPrompt(tools) + if prompt == "" { + t.Fatal("expected non-empty prompt even with invalid tools") + } + // Should still contain the intro and instruction + if !containsStr(prompt, "You are Claude") { + t.Fatalf("expected intro in prompt") + } +} + +// ─── hasSystemMessage ──────────────────────────────────────────────── + +func TestHasSystemMessageTrue(t *testing.T) { + msgs := []any{ + map[string]any{"role": "system", "content": "You are a helper"}, + map[string]any{"role": "user", "content": "Hi"}, + } + if !hasSystemMessage(msgs) { + t.Fatal("expected true") + } +} + +func TestHasSystemMessageFalse(t *testing.T) { + msgs := []any{ + map[string]any{"role": "user", "content": "Hi"}, + map[string]any{"role": "assistant", "content": "Hello"}, + } + if hasSystemMessage(msgs) { + t.Fatal("expected false") + } +} + +func TestHasSystemMessageEmpty(t *testing.T) { + if hasSystemMessage(nil) { + t.Fatal("expected false for nil") + } +} + +func TestHasSystemMessageNonMap(t *testing.T) { + msgs := []any{"not a map"} + if hasSystemMessage(msgs) { + t.Fatal("expected false for non-map") + } +} + +// ─── extractClaudeToolNames ────────────────────────────────────────── + +func TestExtractClaudeToolNamesSingle(t *testing.T) { + tools := []any{ + map[string]any{"name": "search"}, + } + names := extractClaudeToolNames(tools) + if len(names) != 1 || names[0] != "search" { + t.Fatalf("expected [search], got %v", names) + } +} + +func TestExtractClaudeToolNamesMultiple(t *testing.T) { + tools := []any{ + map[string]any{"name": "search"}, + map[string]any{"name": "calculate"}, + } + names := extractClaudeToolNames(tools) + if len(names) != 2 { + t.Fatalf("expected 2 names, got %v", names) + } +} + +func TestExtractClaudeToolNamesSkipsEmptyName(t *testing.T) { + tools := []any{ + map[string]any{"name": ""}, + map[string]any{"name": "valid"}, + } + names := extractClaudeToolNames(tools) + if len(names) != 1 || names[0] != "valid" { + 
t.Fatalf("expected [valid], got %v", names) + } +} + +func TestExtractClaudeToolNamesSkipsNonMap(t *testing.T) { + tools := []any{"not a map", 42} + names := extractClaudeToolNames(tools) + if len(names) != 0 { + t.Fatalf("expected 0, got %v", names) + } +} + +func TestExtractClaudeToolNamesNil(t *testing.T) { + names := extractClaudeToolNames(nil) + if len(names) != 0 { + t.Fatalf("expected 0, got %v", names) + } +} + +// ─── toMessageMaps ─────────────────────────────────────────────────── + +func TestToMessageMapsNormal(t *testing.T) { + input := []any{ + map[string]any{"role": "user", "content": "Hello"}, + } + got := toMessageMaps(input) + if len(got) != 1 { + t.Fatalf("expected 1, got %d", len(got)) + } +} + +func TestToMessageMapsNonSlice(t *testing.T) { + got := toMessageMaps("not a slice") + if got != nil { + t.Fatalf("expected nil, got %v", got) + } +} + +func TestToMessageMapsSkipsNonMap(t *testing.T) { + input := []any{"string", map[string]any{"role": "user"}, 42} + got := toMessageMaps(input) + if len(got) != 1 { + t.Fatalf("expected 1 map, got %d", len(got)) + } +} + +func TestToMessageMapsNil(t *testing.T) { + got := toMessageMaps(nil) + if got != nil { + t.Fatalf("expected nil, got %v", got) + } +} + +// ─── extractMessageContent ────────────────────────────────────────── + +func TestExtractMessageContentString(t *testing.T) { + if got := extractMessageContent("hello"); got != "hello" { + t.Fatalf("expected 'hello', got %q", got) + } +} + +func TestExtractMessageContentArray(t *testing.T) { + input := []any{"part1", "part2"} + got := extractMessageContent(input) + if got != "part1\npart2" { + t.Fatalf("expected joined, got %q", got) + } +} + +func TestExtractMessageContentOther(t *testing.T) { + got := extractMessageContent(42) + if got != "42" { + t.Fatalf("expected '42', got %q", got) + } +} + +func TestExtractMessageContentNil(t *testing.T) { + got := extractMessageContent(nil) + if got != "" { + t.Fatalf("expected '', got %q", got) + } +} + +// 
─── cloneMap ──────────────────────────────────────────────────────── + +func TestCloneMapBasic(t *testing.T) { + original := map[string]any{"a": 1, "b": "hello"} + clone := cloneMap(original) + original["a"] = 999 + if clone["a"] != 1 { + t.Fatalf("expected 1, got %v", clone["a"]) + } + if clone["b"] != "hello" { + t.Fatalf("expected 'hello', got %v", clone["b"]) + } +} + +func TestCloneMapEmpty(t *testing.T) { + clone := cloneMap(map[string]any{}) + if len(clone) != 0 { + t.Fatalf("expected empty, got %v", clone) + } +} + +func TestCloneMapNested(t *testing.T) { + // cloneMap is shallow, so nested maps share references + inner := map[string]any{"key": "value"} + original := map[string]any{"nested": inner} + clone := cloneMap(original) + // Shallow clone means inner is shared + inner["key"] = "modified" + cloneNested := clone["nested"].(map[string]any) + if cloneNested["key"] != "modified" { + t.Fatal("expected shallow clone to share nested references") + } +} + +// helper +func containsStr(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(s) > 0 && findSubstring(s, sub)) +} + +func findSubstring(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/internal/adapter/claude/standard_request.go b/internal/adapter/claude/standard_request.go new file mode 100644 index 0000000..de97c6a --- /dev/null +++ b/internal/adapter/claude/standard_request.go @@ -0,0 +1,58 @@ +package claude + +import ( + "fmt" + "strings" + + "ds2api/internal/config" + "ds2api/internal/util" +) + +type claudeNormalizedRequest struct { + Standard util.StandardRequest + NormalizedMessages []any +} + +func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNormalizedRequest, error) { + model, _ := req["model"].(string) + messagesRaw, _ := req["messages"].([]any) + if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { + return claudeNormalizedRequest{}, 
fmt.Errorf("Request must include 'model' and 'messages'.") + } + if _, ok := req["max_tokens"]; !ok { + req["max_tokens"] = 8192 + } + normalizedMessages := normalizeClaudeMessages(messagesRaw) + payload := cloneMap(req) + payload["messages"] = normalizedMessages + toolsRequested, _ := req["tools"].([]any) + if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) { + payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...) + } + + dsPayload := util.ConvertClaudeToDeepSeek(payload, store) + dsModel, _ := dsPayload["model"].(string) + thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel) + if !ok { + thinkingEnabled = false + searchEnabled = false + } + finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"])) + toolNames := extractClaudeToolNames(toolsRequested) + + return claudeNormalizedRequest{ + Standard: util.StandardRequest{ + Surface: "anthropic_messages", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: dsModel, + ResponseModel: strings.TrimSpace(model), + Messages: payload["messages"].([]any), + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + }, + NormalizedMessages: normalizedMessages, + }, nil +} diff --git a/internal/adapter/claude/standard_request_test.go b/internal/adapter/claude/standard_request_test.go new file mode 100644 index 0000000..7ffdfb8 --- /dev/null +++ b/internal/adapter/claude/standard_request_test.go @@ -0,0 +1,38 @@ +package claude + +import ( + "testing" + + "ds2api/internal/config" +) + +func TestNormalizeClaudeRequest(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-opus-4-6", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "stream": true, + "tools": []any{ + map[string]any{"name": 
"search", "description": "Search"}, + }, + } + norm, err := normalizeClaudeRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if norm.Standard.ResolvedModel == "" { + t.Fatalf("expected resolved model") + } + if !norm.Standard.Stream { + t.Fatalf("expected stream=true") + } + if len(norm.Standard.ToolNames) == 0 { + t.Fatalf("expected tool names") + } + if norm.Standard.FinalPrompt == "" { + t.Fatalf("expected non-empty final prompt") + } +} diff --git a/internal/adapter/openai/embeddings_handler.go b/internal/adapter/openai/embeddings_handler.go new file mode 100644 index 0000000..ff61be0 --- /dev/null +++ b/internal/adapter/openai/embeddings_handler.go @@ -0,0 +1,138 @@ +package openai + +import ( + "crypto/sha256" + "encoding/binary" + "encoding/json" + "fmt" + "net/http" + "strings" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/util" +) + +func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeOpenAIError(w, status, detail) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeOpenAIError(w, http.StatusBadRequest, "invalid json") + return + } + model, _ := req["model"].(string) + model = strings.TrimSpace(model) + if model == "" { + writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.") + return + } + if _, ok := config.ResolveModel(h.Store, model); !ok { + writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model)) + return + } + + inputs := extractEmbeddingInputs(req["input"]) + if len(inputs) == 0 { + writeOpenAIError(w, http.StatusBadRequest, "Request must include non-empty 'input'.") + return + } + + provider := "" + if h.Store != nil { + provider = 
strings.ToLower(strings.TrimSpace(h.Store.EmbeddingsProvider())) + } + if provider == "" { + writeOpenAIError(w, http.StatusNotImplemented, "Embeddings provider is not configured. Set embeddings.provider in config.") + return + } + switch provider { + case "mock", "deterministic", "builtin": + // supported local deterministic provider + default: + writeOpenAIError(w, http.StatusNotImplemented, fmt.Sprintf("Embeddings provider '%s' is not supported.", provider)) + return + } + + data := make([]map[string]any, 0, len(inputs)) + totalTokens := 0 + for i, input := range inputs { + totalTokens += util.EstimateTokens(input) + data = append(data, map[string]any{ + "object": "embedding", + "index": i, + "embedding": deterministicEmbedding(input), + }) + } + writeJSON(w, http.StatusOK, map[string]any{ + "object": "list", + "data": data, + "model": model, + "usage": map[string]any{ + "prompt_tokens": totalTokens, + "total_tokens": totalTokens, + }, + }) +} + +func extractEmbeddingInputs(raw any) []string { + switch v := raw.(type) { + case string: + s := strings.TrimSpace(v) + if s == "" { + return nil + } + return []string{s} + case []any: + out := make([]string, 0, len(v)) + for _, item := range v { + switch iv := item.(type) { + case string: + s := strings.TrimSpace(iv) + if s != "" { + out = append(out, s) + } + case []any: + // Token array input support: convert to stable string form. + out = append(out, fmt.Sprintf("%v", iv)) + default: + s := strings.TrimSpace(fmt.Sprintf("%v", iv)) + if s != "" { + out = append(out, s) + } + } + } + return out + default: + return nil + } +} + +func deterministicEmbedding(input string) []float64 { + // Keep response shape stable without external dependencies. 
+ const dims = 64 + out := make([]float64, dims) + seed := sha256.Sum256([]byte(input)) + buf := seed[:] + for i := 0; i < dims; i++ { + if len(buf) < 4 { + next := sha256.Sum256(buf) + buf = next[:] + } + v := binary.BigEndian.Uint32(buf[:4]) + buf = buf[4:] + // map [0, 2^32) -> [-1, 1] + out[i] = (float64(v)/2147483647.5 - 1.0) + } + return out +} diff --git a/internal/adapter/openai/embeddings_route_test.go b/internal/adapter/openai/embeddings_route_test.go new file mode 100644 index 0000000..4395d16 --- /dev/null +++ b/internal/adapter/openai/embeddings_route_test.go @@ -0,0 +1,96 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/account" + "ds2api/internal/auth" + "ds2api/internal/config" +) + +func newResolverWithConfigJSON(t *testing.T, cfgJSON string) (*config.Store, *auth.Resolver) { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", cfgJSON) + store := config.LoadStore() + pool := account.NewPool(store) + resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "unused", nil + }) + return store, resolver +} + +func TestEmbeddingsRouteContract(t *testing.T) { + store, resolver := newResolverWithConfigJSON(t, `{"embeddings":{"provider":"deterministic"}}`) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + t.Run("unauthorized", func(t *testing.T) { + body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`) + req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("ok", func(t *testing.T) { + body := bytes.NewBufferString(`{"model":"gpt-4o","input":["a","b"]}`) + req := 
httptest.NewRequest(http.MethodPost, "/v1/embeddings", body) + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + if out["object"] != "list" { + t.Fatalf("unexpected object: %#v", out["object"]) + } + data, _ := out["data"].([]any) + if len(data) != 2 { + t.Fatalf("expected 2 embeddings, got %d", len(data)) + } + }) +} + +func TestEmbeddingsRouteProviderMissing(t *testing.T) { + store, resolver := newResolverWithConfigJSON(t, `{}`) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`) + req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body) + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusNotImplemented { + t.Fatalf("expected 501, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + errObj, _ := out["error"].(map[string]any) + if _, ok := errObj["code"]; !ok { + t.Fatalf("expected error.code in response: %#v", out) + } + if _, ok := errObj["param"]; !ok { + t.Fatalf("expected error.param in response: %#v", out) + } +} diff --git a/internal/adapter/openai/error_shape_test.go b/internal/adapter/openai/error_shape_test.go new file mode 100644 index 0000000..c169e04 --- /dev/null +++ b/internal/adapter/openai/error_shape_test.go @@ -0,0 +1,35 @@ +package openai + +import ( + "encoding/json" + "net/http" + 
"net/http/httptest" + "testing" +) + +func TestWriteOpenAIErrorIncludesUnifiedFields(t *testing.T) { + rec := httptest.NewRecorder() + writeOpenAIError(rec, http.StatusBadRequest, "invalid input") + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", rec.Code) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode body: %v", err) + } + errObj, _ := body["error"].(map[string]any) + if errObj["message"] != "invalid input" { + t.Fatalf("unexpected message: %v", errObj["message"]) + } + if errObj["type"] != "invalid_request_error" { + t.Fatalf("unexpected type: %v", errObj["type"]) + } + if errObj["code"] != "invalid_request" { + t.Fatalf("unexpected code: %v", errObj["code"]) + } + if _, ok := errObj["param"]; !ok { + t.Fatal("expected param field") + } +} + diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index d0a2f1d..5ef6e7b 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -11,6 +11,7 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/google/uuid" "ds2api/internal/auth" "ds2api/internal/config" @@ -30,6 +31,8 @@ type Handler struct { leaseMu sync.Mutex streamLeases map[string]streamLease + responsesMu sync.Mutex + responses *responseStore } type streamLease struct { @@ -39,13 +42,27 @@ type streamLease struct { func RegisterRoutes(r chi.Router, h *Handler) { r.Get("/v1/models", h.ListModels) + r.Get("/v1/models/{model_id}", h.GetModel) r.Post("/v1/chat/completions", h.ChatCompletions) + r.Post("/v1/responses", h.Responses) + r.Get("/v1/responses/{response_id}", h.GetResponseByID) + r.Post("/v1/embeddings", h.Embeddings) } func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { writeJSON(w, http.StatusOK, config.OpenAIModelsResponse()) } +func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) { + modelID := strings.TrimSpace(chi.URLParam(r, "model_id")) + model, ok := 
config.OpenAIModelByID(h.Store, modelID) + if !ok { + writeOpenAIError(w, http.StatusNotFound, "Model not found.") + return + } + writeJSON(w, http.StatusOK, model) +} + func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { if isVercelStreamReleaseRequest(r) { h.handleVercelStreamRelease(w, r) @@ -74,24 +91,11 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusBadRequest, "invalid json") return } - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == "" || len(messagesRaw) == 0 { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") + stdReq, err := normalizeOpenAIChatRequest(h.Store, req) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) return } - thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model) - if !ok { - writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model)) - return - } - - messages := normalizeMessages(messagesRaw) - toolNames := []string{} - if tools, ok := req["tools"].([]any); ok && len(tools) > 0 { - messages, toolNames = injectToolPrompt(messages, tools) - } - finalPrompt := util.MessagesPrepare(messages) sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { @@ -107,27 +111,20 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") return } - payload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } + payload := stdReq.CompletionPayload(sessionID) resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) if err != nil { writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") 
return } - if util.ToBool(req["stream"]) { - h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + if stdReq.Stream { + h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) return } - h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) } -func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { +func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { if resp.StatusCode != http.StatusOK { defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) @@ -139,36 +136,8 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re finalThinking := result.Thinking finalText := result.Text - detected := util.ParseToolCalls(finalText, toolNames) - finishReason := "stop" - messageObj := map[string]any{"role": "assistant", "content": finalText} - if thinkingEnabled && finalThinking != "" { - messageObj["reasoning_content"] = finalThinking - } - if len(detected) > 0 { - finishReason = "tool_calls" - messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected) - messageObj["content"] = nil - } - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) - - writeJSON(w, http.StatusOK, map[string]any{ - "id": completionID, - "object": "chat.completion", - "created": time.Now().Unix(), - "model": model, - "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": 
finishReason}}, - "usage": map[string]any{ - "prompt_tokens": promptTokens, - "completion_tokens": reasoningTokens + completionTokens, - "total_tokens": promptTokens + reasoningTokens + completionTokens, - "completion_tokens_details": map[string]any{ - "reasoning_tokens": reasoningTokens, - }, - }, - }) + respBody := util.BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames) + writeJSON(w, http.StatusOK, respBody) } func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { @@ -190,9 +159,11 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt created := time.Now().Unix() firstChunkSent := false - bufferToolContent := len(toolNames) > 0 + bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() + emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() var toolSieve toolStreamSieveState toolCallsEmitted := false + streamToolCallIDs := map[int]string{} initialType := "text" if thinkingEnabled { initialType = "thinking" @@ -235,13 +206,13 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt delta["role"] = "assistant" firstChunkSent = true } - sendChunk(map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]any{{"delta": delta, "index": 0}}, - }) + sendChunk(util.BuildOpenAIChatStreamChunk( + completionID, + created, + model, + []map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)}, + nil, + )) } else if bufferToolContent { for _, evt := range flushToolSieve(&toolSieve, toolNames) { if evt.Content == "" { @@ -254,36 +225,25 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt delta["role"] = "assistant" firstChunkSent = true } - sendChunk(map[string]any{ - "id": completionID, - "object": 
"chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]any{{"delta": delta, "index": 0}}, - }) + sendChunk(util.BuildOpenAIChatStreamChunk( + completionID, + created, + model, + []map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)}, + nil, + )) } } if len(detected) > 0 || toolCallsEmitted { finishReason = "tool_calls" } - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) - sendChunk(map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]any{{"delta": map[string]any{}, "index": 0, "finish_reason": finishReason}}, - "usage": map[string]any{ - "prompt_tokens": promptTokens, - "completion_tokens": reasoningTokens + completionTokens, - "total_tokens": promptTokens + reasoningTokens + completionTokens, - "completion_tokens_details": map[string]any{ - "reasoning_tokens": reasoningTokens, - }, - }, - }) + sendChunk(util.BuildOpenAIChatStreamChunk( + completionID, + created, + model, + []map[string]any{util.BuildOpenAIChatStreamFinishChoice(0, finishReason)}, + util.BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText), + )) sendDone() } @@ -357,6 +317,21 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt // Keep thinking delta only frame. 
} for _, evt := range events { + if len(evt.ToolCallDeltas) > 0 { + if !emitEarlyToolDeltas { + continue + } + toolCallsEmitted = true + tcDelta := map[string]any{ + "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs), + } + if !firstChunkSent { + tcDelta["role"] = "assistant" + firstChunkSent = true + } + newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta)) + continue + } if len(evt.ToolCalls) > 0 { toolCallsEmitted = true tcDelta := map[string]any{ @@ -366,10 +341,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt tcDelta["role"] = "assistant" firstChunkSent = true } - newChoices = append(newChoices, map[string]any{ - "delta": tcDelta, - "index": 0, - }) + newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta)) continue } if evt.Content != "" { @@ -380,42 +352,22 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt contentDelta["role"] = "assistant" firstChunkSent = true } - newChoices = append(newChoices, map[string]any{ - "delta": contentDelta, - "index": 0, - }) + newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, contentDelta)) } } } } if len(delta) > 0 { - newChoices = append(newChoices, map[string]any{"delta": delta, "index": 0}) + newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, delta)) } } if len(newChoices) > 0 { - sendChunk(map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": newChoices, - }) + sendChunk(util.BuildOpenAIChatStreamChunk(completionID, created, model, newChoices, nil)) } } } } -func normalizeMessages(raw []any) []map[string]any { - out := make([]map[string]any, 0, len(raw)) - for _, item := range raw { - m, ok := item.(map[string]any) - if ok { - out = append(out, m) - } - } - return out -} - func injectToolPrompt(messages []map[string]any, tools []any) 
([]map[string]any, []string) { toolSchemas := make([]string, 0, len(tools)) names := make([]string, 0, len(tools)) @@ -444,7 +396,7 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, if len(toolSchemas) == 0 { return messages, names } - toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT: If calling tools, output ONLY the JSON. The response must start with { and end with }" + toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error." 
for i := range messages { if messages[i]["role"] == "system" { @@ -457,11 +409,47 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, return messages, names } +func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any { + if len(deltas) == 0 { + return nil + } + out := make([]map[string]any, 0, len(deltas)) + for _, d := range deltas { + if d.Name == "" && d.Arguments == "" { + continue + } + callID, ok := ids[d.Index] + if !ok || callID == "" { + callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + ids[d.Index] = callID + } + item := map[string]any{ + "index": d.Index, + "id": callID, + "type": "function", + } + fn := map[string]any{} + if d.Name != "" { + fn["name"] = d.Name + } + if d.Arguments != "" { + fn["arguments"] = d.Arguments + } + if len(fn) > 0 { + item["function"] = fn + } + out = append(out, item) + } + return out +} + func writeOpenAIError(w http.ResponseWriter, status int, message string) { writeJSON(w, status, map[string]any{ "error": map[string]any{ "message": message, "type": openAIErrorType(status), + "code": openAIErrorCode(status), + "param": nil, }, }) } @@ -485,3 +473,47 @@ func openAIErrorType(status int) string { return "invalid_request_error" } } + +func openAIErrorCode(status int) string { + switch status { + case http.StatusBadRequest: + return "invalid_request" + case http.StatusUnauthorized: + return "authentication_failed" + case http.StatusForbidden: + return "forbidden" + case http.StatusTooManyRequests: + return "rate_limit_exceeded" + case http.StatusNotFound: + return "not_found" + case http.StatusServiceUnavailable: + return "service_unavailable" + default: + if status >= 500 { + return "internal_error" + } + return "invalid_request" + } +} + +func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) { + for k, v := range collectOpenAIChatPassThrough(req) { + payload[k] = v + } +} + +func (h *Handler) 
toolcallFeatureMatchEnabled() bool { + if h == nil || h.Store == nil { + return true + } + mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode())) + return mode == "" || mode == "feature_match" +} + +func (h *Handler) toolcallEarlyEmitHighConfidence() bool { + if h == nil || h.Store == nil { + return true + } + level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence())) + return level == "" || level == "high" +} diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index f9c44dd..dd2bb0f 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -100,6 +100,26 @@ func streamFinishReason(frames []map[string]any) string { return "" } +func streamToolCallArgumentChunks(frames []map[string]any) []string { + out := make([]string, 0, 4) + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + toolCalls, _ := delta["tool_calls"].([]any) + for _, tc := range toolCalls { + tcm, _ := tc.(map[string]any) + fn, _ := tcm["function"].(map[string]any) + if args, ok := fn["arguments"].(string); ok && args != "" { + out = append(out, args) + } + } + } + } + return out +} + func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( @@ -108,7 +128,7 @@ func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -141,7 +161,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) 
{ ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -169,7 +189,7 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -190,6 +210,66 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { } } +func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"下面是示例:"}`, + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: {"p":"response/content","v":"请勿执行。"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + + h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, []string{"search"}) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rec.Code) + } + + out := decodeJSONBody(t, rec.Body.String()) + choices, _ := out["choices"].([]any) + choice, _ := choices[0].(map[string]any) + if choice["finish_reason"] != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"]) + } + msg, _ := choice["message"].(map[string]any) + toolCalls, _ := msg["tool_calls"].([]any) + if len(toolCalls) == 0 { + t.Fatalf("expected tool_calls field for embedded example: %#v", msg["tool_calls"]) + } + if msg["content"] != nil { + t.Fatalf("expected content nil 
when tool_calls detected, got %#v", msg["content"]) + } +} + +func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}", + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + + h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, []string{"search"}) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rec.Code) + } + + out := decodeJSONBody(t, rec.Body.String()) + choices, _ := out["choices"].([]any) + choice, _ := choices[0].(map[string]any) + if choice["finish_reason"] != "stop" { + t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"]) + } + msg, _ := choice["message"].(map[string]any) + if _, ok := msg["tool_calls"]; ok { + t.Fatalf("did not expect tool_calls field for fenced example: %#v", msg["tool_calls"]) + } + content, _ := msg["content"].(string) + if !strings.Contains(content, "```json") || !strings.Contains(content, `"tool_calls"`) { + t.Fatalf("expected fenced tool example to pass through as text, got %q", content) + } +} + func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( @@ -377,9 +457,9 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) { func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( - `data: {"p":"response/content","v":"前置正文A。"}`, + `data: {"p":"response/content","v":"下面是示例:"}`, `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, - `data: {"p":"response/content","v":"后置正文B。"}`, + `data: {"p":"response/content","v":"请勿执行。"}`, `data: [DONE]`, ) rec := httptest.NewRecorder() @@ -392,10 +472,7 @@ func 
TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } if !streamHasToolCallsDelta(frames) { - t.Fatalf("expected tool_calls delta in mixed stream, body=%s", rec.Body.String()) - } - if streamHasRawToolJSONContent(frames) { - t.Fatalf("raw tool_calls JSON leaked in mixed stream: %s", rec.Body.String()) + t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String()) } content := strings.Builder{} for _, frame := range frames { @@ -409,9 +486,95 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { } } got := content.String() - if !strings.Contains(got, "前置正文A。") || !strings.Contains(got, "后置正文B。") { + if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") { t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got) } + if strings.Contains(strings.ToLower(got), `"tool_calls"`) { + t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String()) + } +} + +func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"我将调用工具。"}`, + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid7b", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + content := strings.Builder{} + for _, frame := range frames { + 
choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + if c, ok := delta["content"].(string); ok { + content.WriteString(c) + } + } + } + got := content.String() + if !strings.Contains(got, "我将调用工具。") { + t.Fatalf("expected leading text to keep streaming, got=%q", got) + } + if strings.Contains(strings.ToLower(got), "tool_calls") { + t.Fatalf("unexpected raw tool json leak, got=%q", got) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + } +} + +func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid7c", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + content := strings.Builder{} + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + if c, ok := delta["content"].(string); ok { + content.WriteString(c) + } + } + } + got := content.String() + if !strings.Contains(got, "接下来我会继续说明。") { + t.Fatalf("expected trailing plain text to be preserved, got=%q", got) + } + if strings.Contains(strings.ToLower(got), "tool_calls") { + t.Fatalf("unexpected raw tool json leak, got=%q", got) + } if streamFinishReason(frames) != "tool_calls" { 
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) } @@ -495,16 +658,16 @@ func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) { } } } - got := strings.ToLower(content.String()) - if strings.Contains(got, "tool_calls") { - t.Fatalf("unexpected raw tool_calls leak in content: %q", content.String()) - } - if !strings.Contains(content.String(), "前置正文D。") || !strings.Contains(content.String(), "后置正文E。") { + got := content.String() + if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") { t.Fatalf("expected pre/post plain text to remain, got=%q", content.String()) } + if !strings.Contains(strings.ToLower(got), "tool_calls") { + t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got) + } } -func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.T) { +func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`, @@ -533,7 +696,42 @@ func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing. 
} } } - if strings.Contains(strings.ToLower(content.String()), "tool_calls") || strings.Contains(content.String(), "{") { - t.Fatalf("unexpected incomplete tool json leak in content: %q", content.String()) + if !strings.Contains(strings.ToLower(content.String()), "tool_calls") || !strings.Contains(content.String(), "{") { + t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String()) + } +} + +func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`, + `data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + if streamHasRawToolJSONContent(frames) { + t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) + } + argChunks := streamToolCallArgumentChunks(frames) + if len(argChunks) < 2 { + t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String()) + } + joined := strings.Join(argChunks, "") + if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) { + t.Fatalf("unexpected merged arguments stream: %q", joined) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) } } diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go new file mode 100644 index 0000000..3ebd1e7 --- 
/dev/null +++ b/internal/adapter/openai/message_normalize.go @@ -0,0 +1,192 @@ +package openai + +import ( + "encoding/json" + "fmt" + "strings" +) + +func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any { + out := make([]map[string]any, 0, len(raw)) + for _, item := range raw { + msg, ok := item.(map[string]any) + if !ok { + continue + } + role := strings.ToLower(strings.TrimSpace(asString(msg["role"]))) + switch role { + case "assistant": + content := normalizeOpenAIContentForPrompt(msg["content"]) + toolCalls := formatAssistantToolCallsForPrompt(msg) + combined := joinNonEmpty(content, toolCalls) + if combined == "" { + continue + } + out = append(out, map[string]any{ + "role": "assistant", + "content": combined, + }) + case "tool", "function": + out = append(out, map[string]any{ + "role": "user", + "content": formatToolResultForPrompt(msg), + }) + case "user", "system": + out = append(out, map[string]any{ + "role": role, + "content": normalizeOpenAIContentForPrompt(msg["content"]), + }) + default: + content := normalizeOpenAIContentForPrompt(msg["content"]) + if content == "" { + continue + } + if role == "" { + role = "user" + } + out = append(out, map[string]any{ + "role": role, + "content": content, + }) + } + } + return out +} + +func formatAssistantToolCallsForPrompt(msg map[string]any) string { + entries := make([]string, 0) + if calls, ok := msg["tool_calls"].([]any); ok { + for i, item := range calls { + call, ok := item.(map[string]any) + if !ok { + continue + } + id := strings.TrimSpace(asString(call["id"])) + if id == "" { + id = fmt.Sprintf("call_%d", i+1) + } + name := strings.TrimSpace(asString(call["name"])) + args := "" + + if fn, ok := call["function"].(map[string]any); ok { + if name == "" { + name = strings.TrimSpace(asString(fn["name"])) + } + args = normalizeOpenAIArgumentsForPrompt(fn["arguments"]) + } + if name == "" { + name = "unknown" + } + if args == "" { + args = normalizeOpenAIArgumentsForPrompt(call["arguments"]) + } + 
if args == "" { + args = normalizeOpenAIArgumentsForPrompt(call["input"]) + } + if args == "" { + args = "{}" + } + entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: %s\n- function.name: %s\n- function.arguments: %s", id, name, args)) + } + } + + if legacy, ok := msg["function_call"].(map[string]any); ok { + name := strings.TrimSpace(asString(legacy["name"])) + if name == "" { + name = "unknown" + } + args := normalizeOpenAIArgumentsForPrompt(legacy["arguments"]) + if args == "" { + args = "{}" + } + entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: call_legacy\n- function.name: %s\n- function.arguments: %s", name, args)) + } + + return strings.Join(entries, "\n\n") +} + +func formatToolResultForPrompt(msg map[string]any) string { + toolCallID := strings.TrimSpace(asString(msg["tool_call_id"])) + if toolCallID == "" { + toolCallID = strings.TrimSpace(asString(msg["id"])) + } + if toolCallID == "" { + toolCallID = "unknown" + } + + name := strings.TrimSpace(asString(msg["name"])) + if name == "" { + name = "unknown" + } + + content := normalizeOpenAIContentForPrompt(msg["content"]) + if content == "" { + content = "null" + } + + return fmt.Sprintf("Tool result:\n- tool_call_id: %s\n- name: %s\n- content: %s", toolCallID, name, content) +} + +func normalizeOpenAIContentForPrompt(v any) string { + switch x := v.(type) { + case string: + return x + case []any: + parts := make([]string, 0, len(x)) + for _, item := range x { + m, ok := item.(map[string]any) + if !ok { + continue + } + t := strings.ToLower(strings.TrimSpace(asString(m["type"]))) + if t != "text" && t != "output_text" && t != "input_text" { + continue + } + if text := asString(m["text"]); text != "" { + parts = append(parts, text) + continue + } + if text := asString(m["content"]); text != "" { + parts = append(parts, text) + } + } + return strings.Join(parts, "\n") + default: + return marshalToPromptString(v) + } +} + +func normalizeOpenAIArgumentsForPrompt(v any) string 
{ + switch x := v.(type) { + case string: + return strings.TrimSpace(x) + default: + return marshalToPromptString(v) + } +} + +func marshalToPromptString(v any) string { + b, err := json.Marshal(v) + if err != nil { + return strings.TrimSpace(fmt.Sprintf("%v", v)) + } + return string(b) +} + +func asString(v any) string { + if s, ok := v.(string); ok { + return s + } + return "" +} + +func joinNonEmpty(parts ...string) string { + nonEmpty := make([]string, 0, len(parts)) + for _, p := range parts { + if strings.TrimSpace(p) == "" { + continue + } + nonEmpty = append(nonEmpty, p) + } + return strings.Join(nonEmpty, "\n\n") +} diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go new file mode 100644 index 0000000..bb648d3 --- /dev/null +++ b/internal/adapter/openai/message_normalize_test.go @@ -0,0 +1,121 @@ +package openai + +import ( + "strings" + "testing" + + "ds2api/internal/util" +) + +func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *testing.T) { + raw := []any{ + map[string]any{"role": "system", "content": "You are helpful"}, + map[string]any{"role": "user", "content": "查北京天气"}, + map[string]any{ + "role": "assistant", + "content": nil, + "tool_calls": []any{ + map[string]any{ + "id": "call_1", + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "arguments": "{\"city\":\"beijing\"}", + }, + }, + }, + }, + map[string]any{ + "role": "tool", + "tool_call_id": "call_1", + "name": "get_weather", + "content": "{\"temp\":18}", + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw) + if len(normalized) != 4 { + t.Fatalf("expected 4 normalized messages, got %d", len(normalized)) + } + assistantContent, _ := normalized[2]["content"].(string) + if !strings.Contains(assistantContent, "tool_call_id: call_1") || + !strings.Contains(assistantContent, "function.name: get_weather") || + !strings.Contains(assistantContent, "function.arguments: 
{\"city\":\"beijing\"}") { + t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent) + } + toolContent, _ := normalized[3]["content"].(string) + if !strings.Contains(toolContent, "Tool result:") || !strings.Contains(toolContent, "name: get_weather") { + t.Fatalf("tool result not serialized correctly: %q", toolContent) + } + + prompt := util.MessagesPrepare(normalized) + if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "Tool result:") { + t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "tool", + "tool_call_id": "call_2", + "name": "get_weather", + "content": map[string]any{ + "temp": 18, + "condition": "sunny", + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw) + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) { + t.Fatalf("expected serialized object in tool content, got %q", got) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "tool", + "tool_call_id": "call_3", + "name": "read_file", + "content": []any{ + map[string]any{"type": "input_text", "text": "line-1"}, + map[string]any{"type": "output_text", "text": "line-2"}, + map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"}, + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw) + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, "line-1\nline-2") { + t.Fatalf("expected joined text blocks, got %q", got) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "function", + "tool_call_id": "call_4", + "name": "legacy_tool", + "content": map[string]any{ + "ok": true, + 
}, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw) + if len(normalized) != 1 { + t.Fatalf("expected one normalized message, got %d", len(normalized)) + } + if normalized[0]["role"] != "user" { + t.Fatalf("expected function role mapped to user, got %#v", normalized[0]["role"]) + } + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, "name: legacy_tool") || !strings.Contains(got, `"ok":true`) { + t.Fatalf("unexpected normalized function-role content: %q", got) + } +} diff --git a/internal/adapter/openai/models_route_test.go b/internal/adapter/openai/models_route_test.go new file mode 100644 index 0000000..1ba3382 --- /dev/null +++ b/internal/adapter/openai/models_route_test.go @@ -0,0 +1,46 @@ +package openai + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" +) + +func TestGetModelRouteDirectAndAlias(t *testing.T) { + h := &Handler{} + r := chi.NewRouter() + RegisterRoutes(r, h) + + t.Run("direct", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/models/deepseek-chat", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("alias", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/models/gpt-4.1", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 for alias, got %d body=%s", rec.Code, rec.Body.String()) + } + }) +} + +func TestGetModelRouteNotFound(t *testing.T) { + h := &Handler{} + r := chi.NewRouter() + RegisterRoutes(r, h) + + req := httptest.NewRequest(http.MethodGet, "/v1/models/not-exists", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String()) + } +} diff --git a/internal/adapter/openai/prompt_build.go 
b/internal/adapter/openai/prompt_build.go new file mode 100644 index 0000000..a7bbc92 --- /dev/null +++ b/internal/adapter/openai/prompt_build.go @@ -0,0 +1,12 @@ +package openai + +import "ds2api/internal/util" + +func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) { + messages := normalizeOpenAIMessagesForPrompt(messagesRaw) + toolNames := []string{} + if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 { + messages, toolNames = injectToolPrompt(messages, tools) + } + return util.MessagesPrepare(messages), toolNames +} diff --git a/internal/adapter/openai/prompt_build_test.go b/internal/adapter/openai/prompt_build_test.go new file mode 100644 index 0000000..1833860 --- /dev/null +++ b/internal/adapter/openai/prompt_build_test.go @@ -0,0 +1,80 @@ +package openai + +import ( + "strings" + "testing" +) + +func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *testing.T) { + messages := []any{ + map[string]any{"role": "user", "content": "查北京天气"}, + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": "call_1", + "function": map[string]any{ + "name": "get_weather", + "arguments": "{\"city\":\"beijing\"}", + }, + }, + }, + }, + map[string]any{ + "role": "tool", + "tool_call_id": "call_1", + "name": "get_weather", + "content": map[string]any{"temp": 18, "condition": "sunny"}, + }, + } + tools := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "description": "Get weather", + "parameters": map[string]any{ + "type": "object", + }, + }, + }, + } + + finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools) + if len(toolNames) != 1 || toolNames[0] != "get_weather" { + t.Fatalf("unexpected tool names: %#v", toolNames) + } + if !strings.Contains(finalPrompt, "tool_call_id: call_1") || + !strings.Contains(finalPrompt, "function.name: get_weather") || + !strings.Contains(finalPrompt, "Tool result:") || + 
!strings.Contains(finalPrompt, `"condition":"sunny"`) { + t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt) + } +} + +func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *testing.T) { + messages := []any{ + map[string]any{"role": "system", "content": "You are helpful"}, + map[string]any{"role": "user", "content": "请调用工具"}, + } + tools := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + "description": "search docs", + "parameters": map[string]any{ + "type": "object", + }, + }, + }, + } + + finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools) + if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") { + t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt) + } + if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") { + t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt) + } +} diff --git a/internal/adapter/openai/response_store.go b/internal/adapter/openai/response_store.go new file mode 100644 index 0000000..63ebbaa --- /dev/null +++ b/internal/adapter/openai/response_store.go @@ -0,0 +1,109 @@ +package openai + +import ( + "sync" + "time" + + "ds2api/internal/auth" +) + +type storedResponse struct { + Owner string + Value map[string]any + ExpiresAt time.Time +} + +type responseStore struct { + mu sync.Mutex + ttl time.Duration + items map[string]storedResponse +} + +func newResponseStore(ttl time.Duration) *responseStore { + if ttl <= 0 { + ttl = 15 * time.Minute + } + return &responseStore{ + ttl: ttl, + items: make(map[string]storedResponse), + } +} + +func responseStoreKey(owner, id string) string { + return owner + "\x00" + id +} + +func responseStoreOwner(a *auth.RequestAuth) string { + if a == nil { + return "" + } + return a.CallerID +} + +func 
(s *responseStore) put(owner, id string, value map[string]any) { + if s == nil || owner == "" || id == "" || value == nil { + return + } + now := time.Now() + s.mu.Lock() + defer s.mu.Unlock() + s.sweepLocked(now) + s.items[responseStoreKey(owner, id)] = storedResponse{ + Owner: owner, + Value: cloneAnyMap(value), + ExpiresAt: now.Add(s.ttl), + } +} + +func (s *responseStore) get(owner, id string) (map[string]any, bool) { + if s == nil || owner == "" || id == "" { + return nil, false + } + now := time.Now() + s.mu.Lock() + defer s.mu.Unlock() + s.sweepLocked(now) + item, ok := s.items[responseStoreKey(owner, id)] + if !ok { + return nil, false + } + if item.Owner != owner { + return nil, false + } + return cloneAnyMap(item.Value), true +} + +func (s *responseStore) sweepLocked(now time.Time) { + for k, v := range s.items { + if now.After(v.ExpiresAt) { + delete(s.items, k) + } + } +} + +func cloneAnyMap(in map[string]any) map[string]any { + if in == nil { + return nil + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func (h *Handler) getResponseStore() *responseStore { + if h == nil { + return nil + } + h.responsesMu.Lock() + defer h.responsesMu.Unlock() + if h.responses == nil { + ttl := 15 * time.Minute + if h.Store != nil { + ttl = time.Duration(h.Store.ResponsesStoreTTLSeconds()) * time.Second + } + h.responses = newResponseStore(ttl) + } + return h.responses +} diff --git a/internal/adapter/openai/responses_embeddings_test.go b/internal/adapter/openai/responses_embeddings_test.go new file mode 100644 index 0000000..d270e1a --- /dev/null +++ b/internal/adapter/openai/responses_embeddings_test.go @@ -0,0 +1,73 @@ +package openai + +import ( + "testing" + "time" +) + +func TestNormalizeResponsesInputAsMessagesString(t *testing.T) { + msgs := normalizeResponsesInputAsMessages("hello") + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + if 
m["role"] != "user" || m["content"] != "hello" { + t.Fatalf("unexpected message: %#v", m) + } +} + +func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) { + req := map[string]any{ + "model": "gpt-4.1", + "input": "ping", + "instructions": "system text", + } + msgs := responsesMessagesFromRequest(req) + if len(msgs) != 2 { + t.Fatalf("expected two messages, got %d", len(msgs)) + } + sys, _ := msgs[0].(map[string]any) + if sys["role"] != "system" { + t.Fatalf("unexpected first message: %#v", sys) + } +} + +func TestExtractEmbeddingInputs(t *testing.T) { + got := extractEmbeddingInputs([]any{"a", "b"}) + if len(got) != 2 || got[0] != "a" || got[1] != "b" { + t.Fatalf("unexpected inputs: %#v", got) + } +} + +func TestDeterministicEmbeddingStable(t *testing.T) { + a := deterministicEmbedding("hello") + b := deterministicEmbedding("hello") + if len(a) != 64 || len(b) != 64 { + t.Fatalf("expected 64 dims, got %d and %d", len(a), len(b)) + } + for i := range a { + if a[i] != b[i] { + t.Fatalf("expected stable embedding at %d: %v != %v", i, a[i], b[i]) + } + } +} + +func TestResponseStorePutGet(t *testing.T) { + st := newResponseStore(100 * time.Millisecond) + st.put("owner_1", "resp_1", map[string]any{"id": "resp_1"}) + got, ok := st.get("owner_1", "resp_1") + if !ok { + t.Fatal("expected stored response") + } + if got["id"] != "resp_1" { + t.Fatalf("unexpected response payload: %#v", got) + } +} + +func TestResponseStoreTenantIsolation(t *testing.T) { + st := newResponseStore(100 * time.Millisecond) + st.put("owner_a", "resp_1", map[string]any{"id": "resp_1"}) + if _, ok := st.get("owner_b", "resp_1"); ok { + t.Fatal("expected owner_b to be isolated from owner_a response") + } +} diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go new file mode 100644 index 0000000..e04fb5f --- /dev/null +++ b/internal/adapter/openai/responses_handler.go @@ -0,0 +1,308 @@ +package openai + +import ( + "encoding/json" + 
"fmt" + "io" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + + "ds2api/internal/auth" + "ds2api/internal/sse" + "ds2api/internal/util" +) + +func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.DetermineCaller(r) + if err != nil { + writeOpenAIError(w, http.StatusUnauthorized, err.Error()) + return + } + + id := strings.TrimSpace(chi.URLParam(r, "response_id")) + if id == "" { + writeOpenAIError(w, http.StatusBadRequest, "response_id is required.") + return + } + owner := responseStoreOwner(a) + if owner == "" { + writeOpenAIError(w, http.StatusUnauthorized, "unauthorized") + return + } + st := h.getResponseStore() + item, ok := st.get(owner, id) + if !ok { + writeOpenAIError(w, http.StatusNotFound, "Response not found.") + return + } + writeJSON(w, http.StatusOK, item) +} + +func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeOpenAIError(w, status, detail) + return + } + defer h.Auth.Release(a) + r = r.WithContext(auth.WithAuth(r.Context(), a)) + owner := responseStoreOwner(a) + if owner == "" { + writeOpenAIError(w, http.StatusUnauthorized, "unauthorized") + return + } + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeOpenAIError(w, http.StatusBadRequest, "invalid json") + return + } + stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) + return + } + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + if a.UseConfigToken { + writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") + } else { + writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. 
If this should be a DS2API key, add it to config.keys first.") + } + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") + return + } + payload := stdReq.CompletionPayload(sessionID) + resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) + if err != nil { + writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") + return + } + + responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "") + if stdReq.Stream { + h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) + return + } + h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) +} + +func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + result := sse.CollectStream(resp, thinkingEnabled, true) + responseObj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames) + h.getResponseStore().put(owner, responseID, responseObj) + writeJSON(w, http.StatusOK, responseObj) +} + +func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + w.Header().Set("Content-Type", "text/event-stream") + 
w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + rc := http.NewResponseController(w) + canFlush := rc.Flush() == nil + + sendEvent := func(event string, payload map[string]any) { + b, _ := json.Marshal(payload) + _, _ = w.Write([]byte("event: " + event + "\n")) + _, _ = w.Write([]byte("data: ")) + _, _ = w.Write(b) + _, _ = w.Write([]byte("\n\n")) + if canFlush { + _ = rc.Flush() + } + } + + sendEvent("response.created", util.BuildOpenAIResponsesCreatedPayload(responseID, model)) + + initialType := "text" + if thinkingEnabled { + initialType = "thinking" + } + parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) + bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() + emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() + var sieve toolStreamSieveState + thinking := strings.Builder{} + text := strings.Builder{} + toolCallsEmitted := false + streamToolCallIDs := map[int]string{} + + finalize := func() { + finalThinking := thinking.String() + finalText := text.String() + if bufferToolContent { + for _, evt := range flushToolSieve(&sieve, toolNames) { + if evt.Content != "" { + sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content)) + } + if len(evt.ToolCalls) > 0 { + toolCallsEmitted = true + sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) + } + } + } + obj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText, toolNames) + if toolCallsEmitted { + obj["status"] = "completed" + } + h.getResponseStore().put(owner, responseID, obj) + sendEvent("response.completed", util.BuildOpenAIResponsesCompletedPayload(obj)) + _, _ = w.Write([]byte("data: [DONE]\n\n")) + if canFlush { + _ = rc.Flush() + } + } + + for { + 
select { + case <-r.Context().Done(): + return + case parsed, ok := <-parsedLines: + if !ok { + _ = <-done + finalize() + return + } + if !parsed.Parsed { + continue + } + if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { + finalize() + return + } + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) { + continue + } + if p.Type == "thinking" { + if !thinkingEnabled { + continue + } + thinking.WriteString(p.Text) + sendEvent("response.reasoning.delta", util.BuildOpenAIResponsesReasoningDeltaPayload(responseID, p.Text)) + continue + } + text.WriteString(p.Text) + if !bufferToolContent { + sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, p.Text)) + continue + } + for _, evt := range processToolSieveChunk(&sieve, p.Text, toolNames) { + if evt.Content != "" { + sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content)) + } + if len(evt.ToolCallDeltas) > 0 { + if !emitEarlyToolDeltas { + continue + } + toolCallsEmitted = true + sendEvent("response.output_tool_call.delta", util.BuildOpenAIResponsesToolCallDeltaPayload(responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs))) + } + if len(evt.ToolCalls) > 0 { + toolCallsEmitted = true + sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) + } + } + } + } + } +} + +func responsesMessagesFromRequest(req map[string]any) []any { + if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 { + return prependInstructionMessage(msgs, req["instructions"]) + } + if rawInput, ok := req["input"]; ok { + if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 { + return prependInstructionMessage(msgs, req["instructions"]) + } + } + return nil +} + +func prependInstructionMessage(messages []any, 
instructions any) []any { + sys, _ := instructions.(string) + sys = strings.TrimSpace(sys) + if sys == "" { + return messages + } + out := make([]any, 0, len(messages)+1) + out = append(out, map[string]any{"role": "system", "content": sys}) + out = append(out, messages...) + return out +} + +func normalizeResponsesInputAsMessages(input any) []any { + switch v := input.(type) { + case string: + if strings.TrimSpace(v) == "" { + return nil + } + return []any{map[string]any{"role": "user", "content": v}} + case []any: + if len(v) == 0 { + return nil + } + // If caller already provides role-shaped items, keep as-is. + if first, ok := v[0].(map[string]any); ok { + if _, hasRole := first["role"]; hasRole { + return v + } + } + parts := make([]string, 0, len(v)) + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + parts = append(parts, txt) + continue + } + } + } + if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" { + parts = append(parts, s) + } + } + if len(parts) == 0 { + return nil + } + return []any{map[string]any{"role": "user", "content": strings.Join(parts, "\n")}} + case map[string]any: + if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" { + return []any{map[string]any{"role": "user", "content": txt}} + } + if content, ok := v["content"].(string); ok && strings.TrimSpace(content) != "" { + return []any{map[string]any{"role": "user", "content": content}} + } + } + return nil +} diff --git a/internal/adapter/openai/responses_route_test.go b/internal/adapter/openai/responses_route_test.go new file mode 100644 index 0000000..574c6fa --- /dev/null +++ b/internal/adapter/openai/responses_route_test.go @@ -0,0 +1,176 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + 
"ds2api/internal/account" + "ds2api/internal/auth" + "ds2api/internal/config" +) + +func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`) + store := config.LoadStore() + pool := account.NewPool(store) + resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "unused", nil + }) + return store, resolver +} + +func newManagedKeyResolver(t *testing.T) (*config.Store, *auth.Resolver) { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[{"email":"acc@example.com","password":"pwd","token":"account-token"}] + }`) + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "0") + store := config.LoadStore() + pool := account.NewPool(store) + resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "unused", nil + }) + return store, resolver +} + +func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth { + t.Helper() + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer "+token) + a, err := resolver.Determine(req) + if err != nil { + t.Fatalf("determine auth failed: %v", err) + } + return a +} + +func TestGetResponseByIDRequiresAuthAndIsTenantIsolated(t *testing.T) { + store, resolver := newDirectTokenResolver(t) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + ownerA := responseStoreOwner(authForToken(t, resolver, "token-a")) + h.getResponseStore().put(ownerA, "resp_test", map[string]any{ + "id": "resp_test", + "object": "response", + }) + + t.Run("unauthorized", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + 
t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("cross-tenant-not-found", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer token-b") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("same-tenant-ok", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer token-a") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode body failed: %v", err) + } + if body["id"] != "resp_test" { + t.Fatalf("unexpected body: %#v", body) + } + }) +} + +func TestResponsesRouteValidationContract(t *testing.T) { + store, resolver := newDirectTokenResolver(t) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + tests := []struct { + name string + body string + }{ + {name: "missing_model", body: `{"input":"hello"}`}, + {name: "missing_input_and_messages", body: `{"model":"gpt-4o"}`}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewBufferString(tc.body)) + req.Header.Set("Authorization", "Bearer token-a") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + errObj, _ := 
out["error"].(map[string]any) + if _, ok := errObj["code"]; !ok { + t.Fatalf("expected error.code: %#v", out) + } + if _, ok := errObj["param"]; !ok { + t.Fatalf("expected error.param: %#v", out) + } + }) + } +} + +func TestGetResponseByIDManagedKeySkipsAccountPoolPressure(t *testing.T) { + store, resolver := newManagedKeyResolver(t) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + ownerReq := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + ownerReq.Header.Set("Authorization", "Bearer managed-key") + ownerAuth, err := resolver.DetermineCaller(ownerReq) + if err != nil { + t.Fatalf("determine caller failed: %v", err) + } + owner := responseStoreOwner(ownerAuth) + h.getResponseStore().put(owner, "resp_test", map[string]any{ + "id": "resp_test", + "object": "response", + }) + + occupyReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + occupyReq.Header.Set("Authorization", "Bearer managed-key") + occupied, err := resolver.Determine(occupyReq) + if err != nil { + t.Fatalf("expected first acquire to succeed: %v", err) + } + defer resolver.Release(occupied) + + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer managed-key") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 under pool pressure, got %d body=%s", rec.Code, rec.Body.String()) + } +} diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go new file mode 100644 index 0000000..9b0a5ac --- /dev/null +++ b/internal/adapter/openai/responses_stream_test.go @@ -0,0 +1,122 @@ +package openai + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) { + h := &Handler{} + req := 
httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + rawToolJSON := `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}` + streamBody := sseLine(rawToolJSON) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}) + + completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") + if !ok { + t.Fatalf("expected response.completed event, body=%s", rec.Body.String()) + } + responseObj, _ := completed["response"].(map[string]any) + outputText, _ := responseObj["output_text"].(string) + if outputText != "" { + t.Fatalf("expected empty output_text for tool_calls response, got output_text=%q", outputText) + } + output, _ := responseObj["output"].([]any) + if len(output) == 0 { + t.Fatalf("expected structured output entries, got %#v", responseObj["output"]) + } + first, _ := output[0].(map[string]any) + if first["type"] != "tool_calls" { + t.Fatalf("expected first output type tool_calls, got %#v", first["type"]) + } + toolCalls, _ := first["tool_calls"].([]any) + if len(toolCalls) == 0 { + t.Fatalf("expected at least one tool_call in output, got %#v", first["tool_calls"]) + } + call0, _ := toolCalls[0].(map[string]any) + if call0["name"] != "read_file" { + t.Fatalf("unexpected tool call name: %#v", call0["name"]) + } + if strings.Contains(outputText, `"tool_calls"`) { + t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText) + } +} + +func TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := 
httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + tail := `{"tool_calls":[{"name":"read_file","input":` + streamBody := sseLine("Before ") + sseLine(tail) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}) + + completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") + if !ok { + t.Fatalf("expected response.completed event, body=%s", rec.Body.String()) + } + responseObj, _ := completed["response"].(map[string]any) + outputText, _ := responseObj["output_text"].(string) + if strings.Count(outputText, tail) > 1 { + t.Fatalf("expected incomplete tail not to be duplicated, got output_text=%q", outputText) + } +} + +func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) { + scanner := bufio.NewScanner(strings.NewReader(body)) + matched := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "event: ") { + evt := strings.TrimSpace(strings.TrimPrefix(line, "event: ")) + matched = evt == targetEvent + continue + } + if !matched || !strings.HasPrefix(line, "data: ") { + continue + } + raw := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) + if raw == "" || raw == "[DONE]" { + continue + } + var payload map[string]any + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return nil, false + } + return payload, true + } + return nil, false +} diff --git a/internal/adapter/openai/standard_request.go b/internal/adapter/openai/standard_request.go new file mode 100644 index 0000000..52344d4 --- /dev/null +++ b/internal/adapter/openai/standard_request.go @@ -0,0 +1,104 @@ +package openai + +import ( + "fmt" + "strings" + + 
"ds2api/internal/config" + "ds2api/internal/util" +) + +func normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) { + model, _ := req["model"].(string) + messagesRaw, _ := req["messages"].([]any) + if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { + return util.StandardRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.") + } + resolvedModel, ok := config.ResolveModel(store, model) + if !ok { + return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model) + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + responseModel := strings.TrimSpace(model) + if responseModel == "" { + responseModel = resolvedModel + } + finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) + passThrough := collectOpenAIChatPassThrough(req) + + return util.StandardRequest{ + Surface: "openai_chat", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: resolvedModel, + ResponseModel: responseModel, + Messages: messagesRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, + }, nil +} + +func normalizeOpenAIResponsesRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) { + model, _ := req["model"].(string) + model = strings.TrimSpace(model) + if model == "" { + return util.StandardRequest{}, fmt.Errorf("Request must include 'model'.") + } + resolvedModel, ok := config.ResolveModel(store, model) + if !ok { + return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model) + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + + // Keep width-control as an explicit policy hook even if current default is true. 
+ allowWideInput := true + if store != nil { + allowWideInput = store.CompatWideInputStrictOutput() + } + var messagesRaw []any + if allowWideInput { + messagesRaw = responsesMessagesFromRequest(req) + } else if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 { + messagesRaw = msgs + } + if len(messagesRaw) == 0 { + return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.") + } + finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) + passThrough := collectOpenAIChatPassThrough(req) + + return util.StandardRequest{ + Surface: "openai_responses", + RequestedModel: model, + ResolvedModel: resolvedModel, + ResponseModel: model, + Messages: messagesRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, + }, nil +} + +func collectOpenAIChatPassThrough(req map[string]any) map[string]any { + out := map[string]any{} + for _, k := range []string{ + "temperature", + "top_p", + "max_tokens", + "max_completion_tokens", + "presence_penalty", + "frequency_penalty", + "stop", + } { + if v, ok := req[k]; ok { + out[k] = v + } + } + return out +} diff --git a/internal/adapter/openai/standard_request_test.go b/internal/adapter/openai/standard_request_test.go new file mode 100644 index 0000000..f3453a2 --- /dev/null +++ b/internal/adapter/openai/standard_request_test.go @@ -0,0 +1,60 @@ +package openai + +import ( + "testing" + + "ds2api/internal/config" +) + +func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", `{}`) + return config.LoadStore() +} + +func TestNormalizeOpenAIChatRequest(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-5-codex", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "temperature": 0.3, + "stream": true, + } + n, err := 
normalizeOpenAIChatRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ResolvedModel != "deepseek-reasoner" { + t.Fatalf("unexpected resolved model: %s", n.ResolvedModel) + } + if !n.Stream { + t.Fatalf("expected stream=true") + } + if _, ok := n.PassThrough["temperature"]; !ok { + t.Fatalf("expected temperature passthrough") + } + if n.FinalPrompt == "" { + t.Fatalf("expected non-empty final prompt") + } +} + +func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-4o", + "input": "ping", + "instructions": "system", + } + n, err := normalizeOpenAIResponsesRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ResolvedModel != "deepseek-chat" { + t.Fatalf("unexpected resolved model: %s", n.ResolvedModel) + } + if len(n.Messages) != 2 { + t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages)) + } +} diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve.go index d1a9014..fd7222b 100644 --- a/internal/adapter/openai/tool_sieve.go +++ b/internal/adapter/openai/tool_sieve.go @@ -7,14 +7,40 @@ import ( ) type toolStreamSieveState struct { - pending strings.Builder - capture strings.Builder - capturing bool + pending strings.Builder + capture strings.Builder + capturing bool + recentTextTail string + toolNameSent bool + toolName string + toolArgsStart int + toolArgsSent int + toolArgsString bool + toolArgsDone bool } type toolStreamEvent struct { - Content string - ToolCalls []util.ParsedToolCall + Content string + ToolCalls []util.ParsedToolCall + ToolCallDeltas []toolCallDelta +} + +type toolCallDelta struct { + Index int + Name string + Arguments string +} + +const toolSieveCaptureLimit = 8 * 1024 +const toolSieveContextTailLimit = 256 + +func (s *toolStreamSieveState) resetIncrementalToolState() { + s.toolNameSent = false + s.toolName = "" + s.toolArgsStart = -1 
+ s.toolArgsSent = -1 + s.toolArgsString = false + s.toolArgsDone = false } func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent { @@ -32,13 +58,27 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames state.capture.WriteString(state.pending.String()) state.pending.Reset() } - prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), toolNames) + if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 { + events = append(events, toolStreamEvent{ToolCallDeltas: deltas}) + } + prefix, calls, suffix, ready := consumeToolCapture(state, toolNames) if !ready { + if state.capture.Len() > toolSieveCaptureLimit { + content := state.capture.String() + state.capture.Reset() + state.capturing = false + state.resetIncrementalToolState() + state.noteText(content) + events = append(events, toolStreamEvent{Content: content}) + continue + } break } state.capture.Reset() state.capturing = false + state.resetIncrementalToolState() if prefix != "" { + state.noteText(prefix) events = append(events, toolStreamEvent{Content: prefix}) } if len(calls) > 0 { @@ -58,11 +98,13 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames if start >= 0 { prefix := pending[:start] if prefix != "" { + state.noteText(prefix) events = append(events, toolStreamEvent{Content: prefix}) } state.pending.Reset() state.capture.WriteString(pending[start:]) state.capturing = true + state.resetIncrementalToolState() continue } @@ -72,6 +114,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames } state.pending.Reset() state.pending.WriteString(hold) + state.noteText(safe) events = append(events, toolStreamEvent{Content: safe}) } @@ -84,25 +127,34 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea } events := processToolSieveChunk(state, "", toolNames) if state.capturing { - consumedPrefix, consumedCalls, 
consumedSuffix, ready := consumeToolCapture(state.capture.String(), toolNames) + consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames) if ready { if consumedPrefix != "" { + state.noteText(consumedPrefix) events = append(events, toolStreamEvent{Content: consumedPrefix}) } if len(consumedCalls) > 0 { events = append(events, toolStreamEvent{ToolCalls: consumedCalls}) } if consumedSuffix != "" { + state.noteText(consumedSuffix) events = append(events, toolStreamEvent{Content: consumedSuffix}) } } else { - // Incomplete captured tool JSON at stream end: suppress raw capture. + content := state.capture.String() + if content != "" { + state.noteText(content) + events = append(events, toolStreamEvent{Content: content}) + } } state.capture.Reset() state.capturing = false + state.resetIncrementalToolState() } if state.pending.Len() > 0 { - events = append(events, toolStreamEvent{Content: state.pending.String()}) + content := state.pending.String() + state.noteText(content) + events = append(events, toolStreamEvent{Content: content}) state.pending.Reset() } return events @@ -144,17 +196,26 @@ func findToolSegmentStart(s string) int { return -1 } lower := strings.ToLower(s) - keyIdx := strings.Index(lower, "tool_calls") - if keyIdx < 0 { - return -1 + offset := 0 + for { + keyRel := strings.Index(lower[offset:], "tool_calls") + if keyRel < 0 { + return -1 + } + keyIdx := offset + keyRel + start := strings.LastIndex(s[:keyIdx], "{") + if start < 0 { + start = keyIdx + } + if !insideCodeFence(s[:start]) { + return start + } + offset = keyIdx + len("tool_calls") } - if start := strings.LastIndex(s[:keyIdx], "{"); start >= 0 { - return start - } - return keyIdx } -func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { +func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { 
+ captured := state.capture.String() if captured == "" { return "", nil, "", false } @@ -171,13 +232,25 @@ func consumeToolCapture(captured string, toolNames []string) (prefix string, cal if !ok { return "", nil, "", false } - parsed := util.ParseToolCalls(obj, toolNames) - if len(parsed) == 0 { - // `tool_calls` key exists but strict JSON parse failed. - // Drop the captured object body to avoid leaking raw tool JSON. - return captured[:start], nil, captured[end:], true + prefixPart := captured[:start] + suffixPart := captured[end:] + if insideCodeFence(state.recentTextTail + prefixPart) { + return captured, nil, "", true } - return captured[:start], parsed, captured[end:], true + parsed := util.ParseStandaloneToolCalls(obj, toolNames) + if len(parsed) == 0 { + if state.toolNameSent { + return prefixPart, nil, suffixPart, true + } + return captured, nil, "", true + } + if state.toolNameSent { + if len(parsed) > 1 { + return prefixPart, parsed[1:], suffixPart, true + } + return prefixPart, nil, suffixPart, true + } + return prefixPart, parsed, suffixPart, true } func extractJSONObjectFrom(text string, start int) (string, int, bool) { @@ -221,3 +294,352 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) { } return "", 0, false } + +func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta { + captured := state.capture.String() + if captured == "" { + return nil + } + lower := strings.ToLower(captured) + keyIdx := strings.Index(lower, "tool_calls") + if keyIdx < 0 { + return nil + } + start := strings.LastIndex(captured[:keyIdx], "{") + if start < 0 { + return nil + } + if insideCodeFence(state.recentTextTail + captured[:start]) { + return nil + } + callStart, ok := findFirstToolCallObjectStart(captured, keyIdx) + if !ok { + return nil + } + deltas := make([]toolCallDelta, 0, 2) + if state.toolName == "" { + name, ok := extractToolCallName(captured, callStart) + if !ok || name == "" { + return nil + } + state.toolName = name 
+ } + if state.toolArgsStart < 0 { + argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart) + if ok { + state.toolArgsString = stringMode + if stringMode { + state.toolArgsStart = argsStart + 1 + } else { + state.toolArgsStart = argsStart + } + state.toolArgsSent = state.toolArgsStart + } + } + if !state.toolNameSent { + if state.toolArgsStart < 0 { + return nil + } + state.toolNameSent = true + deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName}) + } + if state.toolArgsStart < 0 || state.toolArgsDone { + return deltas + } + end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString) + if !ok { + return deltas + } + if end > state.toolArgsSent { + deltas = append(deltas, toolCallDelta{ + Index: 0, + Arguments: captured[state.toolArgsSent:end], + }) + state.toolArgsSent = end + } + if complete { + state.toolArgsDone = true + } + return deltas +} + +func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) { + arrStart, ok := findToolCallsArrayStart(text, keyIdx) + if !ok { + return -1, false + } + i := skipSpaces(text, arrStart+1) + if i >= len(text) || text[i] != '{' { + return -1, false + } + return i, true +} + +func findToolCallsArrayStart(text string, keyIdx int) (int, bool) { + i := keyIdx + len("tool_calls") + for i < len(text) && text[i] != ':' { + i++ + } + if i >= len(text) { + return -1, false + } + i = skipSpaces(text, i+1) + if i >= len(text) || text[i] != '[' { + return -1, false + } + return i, true +} + +func extractToolCallName(text string, callStart int) (string, bool) { + valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"}) + if !ok || valueStart >= len(text) || text[valueStart] != '"' { + fnStart, fnOK := findFunctionObjectStart(text, callStart) + if !fnOK { + return "", false + } + valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"}) + if !ok || valueStart >= len(text) || text[valueStart] != '"' { + return "", 
false + } + } + name, _, ok := parseJSONStringLiteral(text, valueStart) + if !ok { + return "", false + } + return name, true +} + +func findToolCallArgsStart(text string, callStart int) (int, bool, bool) { + keys := []string{"input", "arguments", "args", "parameters", "params"} + valueStart, ok := findObjectFieldValueStart(text, callStart, keys) + if !ok { + fnStart, fnOK := findFunctionObjectStart(text, callStart) + if !fnOK { + return -1, false, false + } + valueStart, ok = findObjectFieldValueStart(text, fnStart, keys) + if !ok { + return -1, false, false + } + } + if valueStart >= len(text) { + return -1, false, false + } + ch := text[valueStart] + if ch == '{' || ch == '[' { + return valueStart, false, true + } + if ch == '"' { + return valueStart, true, true + } + return -1, false, false +} + +func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) { + if start < 0 || start > len(text) { + return 0, false, false + } + if stringMode { + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + return i, true, true + } + } + return len(text), false, true + } + if start >= len(text) { + return start, false, false + } + if text[start] != '{' && text[start] != '[' { + return 0, false, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' || ch == '[' { + depth++ + continue + } + if ch == '}' || ch == ']' { + depth-- + if depth == 0 { + return i + 1, true, true + } + } + } + return len(text), false, true +} + +func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) { + if 
objStart < 0 || objStart >= len(text) || text[objStart] != '{' { + return 0, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := objStart; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + if depth == 1 { + key, end, ok := parseJSONStringLiteral(text, i) + if !ok { + return 0, false + } + j := skipSpaces(text, end) + if j >= len(text) || text[j] != ':' { + i = end - 1 + continue + } + j = skipSpaces(text, j+1) + if j >= len(text) { + return 0, false + } + if containsKey(keys, key) { + return j, true + } + i = j - 1 + continue + } + quote = ch + continue + } + if ch == '{' { + depth++ + continue + } + if ch == '}' { + depth-- + if depth == 0 { + break + } + } + } + return 0, false +} + +func findFunctionObjectStart(text string, callStart int) (int, bool) { + valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"}) + if !ok || valueStart >= len(text) || text[valueStart] != '{' { + return -1, false + } + return valueStart, true +} + +func parseJSONStringLiteral(text string, start int) (string, int, bool) { + if start < 0 || start >= len(text) || text[start] != '"' { + return "", 0, false + } + var b strings.Builder + escaped := false + for i := start + 1; i < len(text); i++ { + ch := text[i] + if escaped { + b.WriteByte(ch) + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + return b.String(), i + 1, true + } + b.WriteByte(ch) + } + return "", 0, false +} + +func containsKey(keys []string, value string) bool { + for _, k := range keys { + if k == value { + return true + } + } + return false +} + +func skipSpaces(text string, i int) int { + for i < len(text) { + switch text[i] { + case ' ', '\t', '\n', '\r': + i++ + default: + return i + } + } + return i +} + +func (s *toolStreamSieveState) 
noteText(content string) { + if strings.TrimSpace(content) == "" { + return + } + s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit) +} + +func appendTail(prev, next string, max int) string { + if max <= 0 { + return "" + } + combined := prev + next + if len(combined) <= max { + return combined + } + return combined[len(combined)-max:] +} + +func looksLikeToolExampleContext(text string) bool { + return insideCodeFence(text) +} + +func insideCodeFence(text string) bool { + if text == "" { + return false + } + return strings.Count(text, "```")%2 == 1 +} diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go index 653f3cf..65006c4 100644 --- a/internal/adapter/openai/vercel_stream.go +++ b/internal/adapter/openai/vercel_stream.go @@ -56,24 +56,16 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque writeOpenAIError(w, http.StatusBadRequest, "stream must be true") return } - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == "" || len(messagesRaw) == 0 { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") + stdReq, err := normalizeOpenAIChatRequest(h.Store, req) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) return } - thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model) - if !ok { - writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model)) + if !stdReq.Stream { + writeOpenAIError(w, http.StatusBadRequest, "stream must be true") return } - messages := normalizeMessages(messagesRaw) - if tools, ok := req["tools"].([]any); ok && len(tools) > 0 { - messages, _ = injectToolPrompt(messages, tools) - } - finalPrompt := util.MessagesPrepare(messages) - sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { if a.UseConfigToken { @@ -93,14 +85,7 @@ func (h *Handler) 
handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque return } - payload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } + payload := stdReq.CompletionPayload(sessionID) leaseID := h.holdStreamLease(a) if leaseID == "" { writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease") @@ -108,15 +93,18 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque } leased = true writeJSON(w, http.StatusOK, map[string]any{ - "session_id": sessionID, - "lease_id": leaseID, - "model": model, - "final_prompt": finalPrompt, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - "deepseek_token": a.DeepSeekToken, - "pow_header": powHeader, - "payload": payload, + "session_id": sessionID, + "lease_id": leaseID, + "model": stdReq.ResponseModel, + "final_prompt": stdReq.FinalPrompt, + "thinking_enabled": stdReq.Thinking, + "search_enabled": stdReq.Search, + "tool_names": stdReq.ToolNames, + "toolcall_feature_match": h.toolcallFeatureMatchEnabled(), + "toolcall_early_emit_high": h.toolcallEarlyEmitHighConfidence(), + "deepseek_token": a.DeepSeekToken, + "pow_header": powHeader, + "payload": payload, }) } diff --git a/internal/admin/handler_accounts.go b/internal/admin/handler_accounts.go index b95077d..5cb88cc 100644 --- a/internal/admin/handler_accounts.go +++ b/internal/admin/handler_accounts.go @@ -56,7 +56,14 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { preview = token } } - items = append(items, map[string]any{"email": acc.Email, "mobile": acc.Mobile, "has_password": acc.Password != "", "has_token": token != "", "token_preview": preview}) + items = append(items, map[string]any{ + "identifier": acc.Identifier(), + "email": acc.Email, + "mobile": acc.Mobile, + "has_password": acc.Password != "", + "has_token": 
token != "", + "token_preview": preview, + }) } writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages}) } @@ -94,7 +101,7 @@ func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) { err := h.Store.Update(func(c *config.Config) error { idx := -1 for i, a := range c.Accounts { - if a.Email == identifier || a.Mobile == identifier { + if accountMatchesIdentifier(a, identifier) { idx = i break } @@ -122,10 +129,10 @@ func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) { _ = json.NewDecoder(r.Body).Decode(&req) identifier, _ := req["identifier"].(string) if strings.TrimSpace(identifier) == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(email 或 mobile)"}) + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"}) return } - acc, ok := h.Store.FindAccount(identifier) + acc, ok := findAccountByIdentifier(h.Store, identifier) if !ok { writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"}) return diff --git a/internal/admin/handler_accounts_identifier_test.go b/internal/admin/handler_accounts_identifier_test.go new file mode 100644 index 0000000..591d43a --- /dev/null +++ b/internal/admin/handler_accounts_identifier_test.go @@ -0,0 +1,138 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/account" + "ds2api/internal/config" +) + +func newAdminTestHandler(t *testing.T, raw string) *Handler { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", raw) + t.Setenv("CONFIG_JSON", "") + store := config.LoadStore() + return &Handler{ + Store: store, + Pool: account.NewPool(store), + } +} + +func TestListAccountsIncludesTokenOnlyIdentifier(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[{"token":"token-only-account"}] + }`) + + req := 
httptest.NewRequest(http.MethodGet, "/admin/accounts?page=1&page_size=10", nil) + rec := httptest.NewRecorder() + h.listAccounts(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String()) + } + + var payload map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode response failed: %v", err) + } + items, _ := payload["items"].([]any) + if len(items) != 1 { + t.Fatalf("expected 1 item, got %d", len(items)) + } + first, _ := items[0].(map[string]any) + identifier, _ := first["identifier"].(string) + if identifier == "" { + t.Fatalf("expected non-empty identifier: %#v", first) + } + if !strings.HasPrefix(identifier, "token:") { + t.Fatalf("expected token synthetic identifier, got %q", identifier) + } +} + +func TestDeleteAccountSupportsTokenOnlyIdentifier(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[{"token":"token-only-account"}] + }`) + accounts := h.Store.Accounts() + if len(accounts) != 1 { + t.Fatalf("expected 1 account, got %d", len(accounts)) + } + id := accounts[0].Identifier() + if id == "" { + t.Fatal("expected token-only synthetic identifier") + } + + r := chi.NewRouter() + r.Delete("/admin/accounts/{identifier}", h.deleteAccount) + req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape(id), nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String()) + } + if got := len(h.Store.Accounts()); got != 0 { + t.Fatalf("expected account removed, remaining=%d", got) + } +} + +func TestDeleteAccountSupportsMobileAlias(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[{"email":"u@example.com","mobile":"13800138000","password":"pwd"}] + }`) + + r := chi.NewRouter() + r.Delete("/admin/accounts/{identifier}", h.deleteAccount) + req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/13800138000", 
nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String()) + } + if got := len(h.Store.Accounts()); got != 0 { + t.Fatalf("expected account removed, remaining=%d", got) + } +} + +func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[ + {"email":"u@example.com","mobile":"13800138000","password":"pwd"}, + {"token":"token-only-account"} + ] + }`) + + accByMobile, ok := findAccountByIdentifier(h.Store, "13800138000") + if !ok { + t.Fatal("expected find by mobile") + } + if accByMobile.Email != "u@example.com" { + t.Fatalf("unexpected account by mobile: %#v", accByMobile) + } + + tokenOnlyID := "" + for _, acc := range h.Store.Accounts() { + if strings.TrimSpace(acc.Email) == "" && strings.TrimSpace(acc.Mobile) == "" { + tokenOnlyID = acc.Identifier() + break + } + } + if tokenOnlyID == "" { + t.Fatal("expected token-only account identifier") + } + accByTokenOnly, ok := findAccountByIdentifier(h.Store, tokenOnlyID) + if !ok { + t.Fatalf("expected find by token-only id=%q", tokenOnlyID) + } + if accByTokenOnly.Token != "token-only-account" { + t.Fatalf("unexpected token-only account: %#v", accByTokenOnly) + } +} diff --git a/internal/admin/handler_config.go b/internal/admin/handler_config.go index 7627602..2b672c3 100644 --- a/internal/admin/handler_config.go +++ b/internal/admin/handler_config.go @@ -37,6 +37,7 @@ func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) { } } accounts = append(accounts, map[string]any{ + "identifier": acc.Identifier(), "email": acc.Email, "mobile": acc.Mobile, "has_password": strings.TrimSpace(acc.Password) != "", diff --git a/internal/admin/helpers.go b/internal/admin/helpers.go index fa75b59..d7d1198 100644 --- a/internal/admin/helpers.go +++ b/internal/admin/helpers.go @@ -81,3 +81,34 @@ func statusOr(v int, d int) int { } return v } + +func 
accountMatchesIdentifier(acc config.Account, identifier string) bool { + id := strings.TrimSpace(identifier) + if id == "" { + return false + } + if strings.TrimSpace(acc.Email) == id { + return true + } + if strings.TrimSpace(acc.Mobile) == id { + return true + } + return acc.Identifier() == id +} + +func findAccountByIdentifier(store *config.Store, identifier string) (config.Account, bool) { + id := strings.TrimSpace(identifier) + if id == "" { + return config.Account{}, false + } + if acc, ok := store.FindAccount(id); ok { + return acc, true + } + accounts := store.Snapshot().Accounts + for _, acc := range accounts { + if accountMatchesIdentifier(acc, id) { + return acc, true + } + } + return config.Account{}, false +} diff --git a/internal/admin/helpers_edge_test.go b/internal/admin/helpers_edge_test.go new file mode 100644 index 0000000..2a0bf20 --- /dev/null +++ b/internal/admin/helpers_edge_test.go @@ -0,0 +1,240 @@ +package admin + +import ( + "net/http" + "net/http/httptest" + "testing" + + "ds2api/internal/config" +) + +// ─── reverseAccounts ───────────────────────────────────────────────── + +func TestReverseAccountsEmpty(t *testing.T) { + a := []config.Account{} + reverseAccounts(a) + if len(a) != 0 { + t.Fatal("expected empty") + } +} + +func TestReverseAccountsTwoElements(t *testing.T) { + a := []config.Account{ + {Email: "a@test.com"}, + {Email: "b@test.com"}, + } + reverseAccounts(a) + if a[0].Email != "b@test.com" || a[1].Email != "a@test.com" { + t.Fatalf("unexpected order after reverse: %v", a) + } +} + +func TestReverseAccountsThreeElements(t *testing.T) { + a := []config.Account{ + {Email: "1@test.com"}, + {Email: "2@test.com"}, + {Email: "3@test.com"}, + } + reverseAccounts(a) + if a[0].Email != "3@test.com" || a[1].Email != "2@test.com" || a[2].Email != "1@test.com" { + t.Fatalf("unexpected order: %v", a) + } +} + +// ─── intFromQuery edge cases ───────────────────────────────────────── + +func TestIntFromQueryPresent(t *testing.T) { + req 
:= httptest.NewRequest("GET", "/?limit=5", nil) + if got := intFromQuery(req, "limit", 10); got != 5 { + t.Fatalf("expected 5, got %d", got) + } +} + +func TestIntFromQueryMissing(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + if got := intFromQuery(req, "limit", 10); got != 10 { + t.Fatalf("expected default 10, got %d", got) + } +} + +func TestIntFromQueryInvalid(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=abc", nil) + if got := intFromQuery(req, "limit", 10); got != 10 { + t.Fatalf("expected default 10 for invalid, got %d", got) + } +} + +func TestIntFromQueryNegative(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=-3", nil) + if got := intFromQuery(req, "limit", 10); got != -3 { + t.Fatalf("expected -3, got %d", got) + } +} + +func TestIntFromQueryZero(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=0", nil) + if got := intFromQuery(req, "limit", 10); got != 0 { + t.Fatalf("expected 0, got %d", got) + } +} + +// ─── nilIfEmpty ────────────────────────────────────────────────────── + +func TestNilIfEmptyEmpty(t *testing.T) { + if nilIfEmpty("") != nil { + t.Fatal("expected nil for empty string") + } +} + +func TestNilIfEmptyNonEmpty(t *testing.T) { + if nilIfEmpty("hello") != "hello" { + t.Fatal("expected 'hello'") + } +} + +// ─── nilIfZero ─────────────────────────────────────────────────────── + +func TestNilIfZeroZero(t *testing.T) { + if nilIfZero(0) != nil { + t.Fatal("expected nil for zero") + } +} + +func TestNilIfZeroNonZero(t *testing.T) { + if nilIfZero(42) != int64(42) { + t.Fatal("expected 42") + } +} + +func TestNilIfZeroNegative(t *testing.T) { + if nilIfZero(-1) != int64(-1) { + t.Fatal("expected -1") + } +} + +// ─── toStringSlice ─────────────────────────────────────────────────── + +func TestToStringSliceFromAnySlice(t *testing.T) { + input := []any{"a", "b", "c"} + got, ok := toStringSlice(input) + if !ok || len(got) != 3 { + t.Fatalf("expected 3 strings, got %#v ok=%v", got, ok) + } 
+ if got[0] != "a" || got[1] != "b" || got[2] != "c" { + t.Fatalf("unexpected values: %#v", got) + } +} + +func TestToStringSliceFromMixed(t *testing.T) { + input := []any{"hello", 42, true} + got, ok := toStringSlice(input) + if !ok { + t.Fatal("expected ok for mixed types") + } + if got[0] != "hello" || got[1] != "42" || got[2] != "true" { + t.Fatalf("unexpected values: %#v", got) + } +} + +func TestToStringSliceFromNonSlice(t *testing.T) { + _, ok := toStringSlice("not a slice") + if ok { + t.Fatal("expected not ok for string input") + } +} + +func TestToStringSliceFromNil(t *testing.T) { + _, ok := toStringSlice(nil) + if ok { + t.Fatal("expected not ok for nil input") + } +} + +func TestToStringSliceEmpty(t *testing.T) { + got, ok := toStringSlice([]any{}) + if !ok { + t.Fatal("expected ok for empty slice") + } + if len(got) != 0 { + t.Fatalf("expected empty result, got %#v", got) + } +} + +func TestToStringSliceTrimsWhitespace(t *testing.T) { + got, ok := toStringSlice([]any{" hello ", " world "}) + if !ok { + t.Fatal("expected ok") + } + if got[0] != "hello" || got[1] != "world" { + t.Fatalf("expected trimmed values, got %#v", got) + } +} + +// ─── toAccount edge cases ──────────────────────────────────────────── + +func TestToAccountAllFields(t *testing.T) { + acc := toAccount(map[string]any{ + "email": "user@test.com", + "mobile": "13800138000", + "password": "secret", + "token": "tok123", + }) + if acc.Email != "user@test.com" { + t.Fatalf("unexpected email: %q", acc.Email) + } + if acc.Mobile != "13800138000" { + t.Fatalf("unexpected mobile: %q", acc.Mobile) + } + if acc.Password != "secret" { + t.Fatalf("unexpected password: %q", acc.Password) + } + if acc.Token != "tok123" { + t.Fatalf("unexpected token: %q", acc.Token) + } +} + +func TestToAccountNumericValues(t *testing.T) { + acc := toAccount(map[string]any{ + "email": 12345, + }) + if acc.Email != "12345" { + t.Fatalf("expected numeric converted to string, got %q", acc.Email) + } +} + +// ─── 
fieldString edge cases ────────────────────────────────────────── + +func TestFieldStringNonString(t *testing.T) { + got := fieldString(map[string]any{"key": 42}, "key") + if got != "42" { + t.Fatalf("expected '42' for int, got %q", got) + } +} + +func TestFieldStringBool(t *testing.T) { + got := fieldString(map[string]any{"key": true}, "key") + if got != "true" { + t.Fatalf("expected 'true', got %q", got) + } +} + +func TestFieldStringWhitespace(t *testing.T) { + got := fieldString(map[string]any{"key": " hello "}, "key") + if got != "hello" { + t.Fatalf("expected trimmed 'hello', got %q", got) + } +} + +// ─── statusOr ──────────────────────────────────────────────────────── + +func TestStatusOrZeroReturnsDefault(t *testing.T) { + if got := statusOr(0, http.StatusOK); got != http.StatusOK { + t.Fatalf("expected %d, got %d", http.StatusOK, got) + } +} + +func TestStatusOrNonZeroReturnsValue(t *testing.T) { + if got := statusOr(http.StatusBadRequest, http.StatusOK); got != http.StatusBadRequest { + t.Fatalf("expected %d, got %d", http.StatusBadRequest, got) + } +} diff --git a/internal/auth/auth_edge_test.go b/internal/auth/auth_edge_test.go new file mode 100644 index 0000000..55c46ef --- /dev/null +++ b/internal/auth/auth_edge_test.go @@ -0,0 +1,375 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "testing" + + "ds2api/internal/account" + "ds2api/internal/config" +) + +// ─── extractCallerToken edge cases ─────────────────────────────────── + +func TestExtractCallerTokenBearerPrefix(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer my-token") + if got := extractCallerToken(req); got != "my-token" { + t.Fatalf("expected my-token, got %q", got) + } +} + +func TestExtractCallerTokenBearerCaseInsensitive(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "BEARER My-Token") + if got := extractCallerToken(req); got != "My-Token" { + t.Fatalf("expected 
My-Token, got %q", got) + } +} + +func TestExtractCallerTokenBearerEmpty(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer ") + if got := extractCallerToken(req); got != "" { + t.Fatalf("expected empty for 'Bearer ', got %q", got) + } +} + +func TestExtractCallerTokenXAPIKey(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("x-api-key", "x-api-key-token") + if got := extractCallerToken(req); got != "x-api-key-token" { + t.Fatalf("expected x-api-key-token, got %q", got) + } +} + +func TestExtractCallerTokenBearerPreferredOverXAPIKey(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer bearer-token") + req.Header.Set("x-api-key", "x-api-key-token") + if got := extractCallerToken(req); got != "bearer-token" { + t.Fatalf("expected bearer-token, got %q", got) + } +} + +func TestExtractCallerTokenMissingHeaders(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + if got := extractCallerToken(req); got != "" { + t.Fatalf("expected empty for missing headers, got %q", got) + } +} + +func TestExtractCallerTokenNonBearerAuth(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Basic abc123") + if got := extractCallerToken(req); got != "" { + t.Fatalf("expected empty for Basic auth, got %q", got) + } +} + +// ─── Context helpers ───────────────────────────────────────────────── + +func TestWithAuthAndFromContext(t *testing.T) { + a := &RequestAuth{DeepSeekToken: "test-token"} + ctx := WithAuth(context.Background(), a) + got, ok := FromContext(ctx) + if !ok || got.DeepSeekToken != "test-token" { + t.Fatalf("expected token from context, got ok=%v token=%q", ok, got.DeepSeekToken) + } +} + +func TestFromContextMissing(t *testing.T) { + _, ok := FromContext(context.Background()) + if ok { + t.Fatal("expected not ok from empty context") + } +} + +// ─── RefreshToken edge cases 
───────────────────────────────────────── + +func TestRefreshTokenNotConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false, resolver: r} + if r.RefreshToken(context.Background(), a) { + t.Fatal("expected false for non-config token") + } +} + +func TestRefreshTokenEmptyAccountID(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: true, AccountID: "", resolver: r} + if r.RefreshToken(context.Background(), a) { + t.Fatal("expected false for empty account ID") + } +} + +func TestRefreshTokenSuccess(t *testing.T) { + r := newTestResolver(t) + // First acquire an account + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + + if !r.RefreshToken(context.Background(), a) { + t.Fatal("expected refresh to succeed") + } + if a.DeepSeekToken != "fresh-token" { + t.Fatalf("expected fresh-token after refresh, got %q", a.DeepSeekToken) + } +} + +// ─── MarkTokenInvalid edge cases ───────────────────────────────────── + +func TestMarkTokenInvalidNotConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false, DeepSeekToken: "direct", resolver: r} + r.MarkTokenInvalid(a) + // Should not panic, token should be unchanged for non-config + if a.DeepSeekToken != "" { + // Actually it does clear it; that's fine - let's check behavior + } +} + +func TestMarkTokenInvalidEmptyAccountID(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: true, AccountID: "", DeepSeekToken: "tok", resolver: r} + r.MarkTokenInvalid(a) + // Should not panic +} + +func TestMarkTokenInvalidClearsToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) 
+ } + defer r.Release(a) + + r.MarkTokenInvalid(a) + if a.DeepSeekToken != "" { + t.Fatalf("expected empty token after invalidation, got %q", a.DeepSeekToken) + } + if a.Account.Token != "" { + t.Fatalf("expected empty account token after invalidation, got %q", a.Account.Token) + } +} + +// ─── SwitchAccount edge cases ──────────────────────────────────────── + +func TestSwitchAccountNotConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false, resolver: r} + if r.SwitchAccount(context.Background(), a) { + t.Fatal("expected false for non-config token") + } +} + +func TestSwitchAccountNilTriedAccounts(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[ + {"email":"acc1@test.com","token":"t1"}, + {"email":"acc2@test.com","token":"t2"} + ] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "new-token", nil + }) + + // First acquire + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + + oldID := a.AccountID + a.TriedAccounts = nil // test nil initialization in SwitchAccount + if !r.SwitchAccount(context.Background(), a) { + t.Fatal("expected switch to succeed") + } + if a.AccountID == oldID { + t.Fatalf("expected different account after switch") + } + r.Release(a) +} + +// ─── Release edge cases ───────────────────────────────────────────── + +func TestReleaseNilAuth(t *testing.T) { + r := newTestResolver(t) + r.Release(nil) // should not panic +} + +func TestReleaseNonConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false} + r.Release(a) // should not panic +} + +func TestReleaseEmptyAccountID(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: true, AccountID: ""} + 
r.Release(a) // should not panic +} + +// ─── JWT edge cases ────────────────────────────────────────────────── + +func TestVerifyJWTInvalidFormat(t *testing.T) { + _, err := VerifyJWT("not-a-jwt") + if err == nil { + t.Fatal("expected error for invalid JWT format") + } +} + +func TestVerifyJWTInvalidSignature(t *testing.T) { + token, _ := CreateJWT(1) + // Tamper with the signature + parts := splitJWT(token) + if len(parts) == 3 { + tampered := parts[0] + "." + parts[1] + ".invalid_signature" + _, err := VerifyJWT(tampered) + if err == nil { + t.Fatal("expected error for tampered signature") + } + } +} + +func TestVerifyJWTExpired(t *testing.T) { + // Create a token with 0 hours expiry - will use default, so we can't easily test + // Instead test with bad payload + _, err := VerifyJWT("eyJhbGciOiJIUzI1NiJ9.eyJleHAiOjF9.invalid") + if err == nil { + t.Fatal("expected error for expired/invalid JWT") + } +} + +func TestCreateJWTDefaultExpiry(t *testing.T) { + token, err := CreateJWT(0) // should use default + if err != nil { + t.Fatalf("create jwt failed: %v", err) + } + _, err = VerifyJWT(token) + if err != nil { + t.Fatalf("verify jwt failed: %v", err) + } +} + +// ─── VerifyAdminRequest edge cases ─────────────────────────────────── + +func TestVerifyAdminRequestNoHeader(t *testing.T) { + req, _ := http.NewRequest("GET", "/admin/config", nil) + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for missing auth") + } +} + +func TestVerifyAdminRequestEmptyBearer(t *testing.T) { + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Bearer ") + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for empty bearer") + } +} + +func TestVerifyAdminRequestWithAdminKey(t *testing.T) { + t.Setenv("DS2API_ADMIN_KEY", "test-admin-key") + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Bearer test-admin-key") + if err := VerifyAdminRequest(req); err != 
nil { + t.Fatalf("expected admin key accepted: %v", err) + } +} + +func TestVerifyAdminRequestInvalidCredentials(t *testing.T) { + t.Setenv("DS2API_ADMIN_KEY", "correct-key") + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Bearer wrong-key") + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for wrong key") + } +} + +func TestVerifyAdminRequestBasicAuth(t *testing.T) { + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Basic abc123") + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for Basic auth") + } +} + +// ─── Determine with login failure ──────────────────────────────────── + +func TestDetermineWithLoginFailure(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[{"email":"acc@test.com","password":"pwd"}] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "", errors.New("login failed") + }) + + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + _, err := r.Determine(req) + if err == nil { + t.Fatal("expected error when login fails") + } +} + +// ─── Determine with target account ─────────────────────────────────── + +func TestDetermineWithTargetAccount(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[ + {"email":"acc1@test.com","token":"t1"}, + {"email":"acc2@test.com","token":"t2"} + ] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "fresh-token", nil + }) + + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + req.Header.Set("X-Ds2-Target-Account", "acc2@test.com") + a, err := r.Determine(req) + if err != 
// splitJWT splits token on '.' without importing strings (this test file
// deliberately keeps its import list minimal). Behaves like
// strings.Split(token, "."): "a.b.c" -> ["a","b","c"], "" -> [""],
// "a..c" -> ["a","","c"]. The previous version also maintained a `count`
// accumulator that was never read; it has been removed.
func splitJWT(token string) []string {
	parts := make([]string, 0, 3)
	start := 0
	for i := 0; i < len(token); i++ {
		if token[i] == '.' {
			parts = append(parts, token[start:i])
			start = i + 1
		}
	}
	return append(parts, token[start:])
}
+// Use this for local-cache lookup routes that only need tenant isolation. +func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) { + callerKey := extractCallerToken(req) + if callerKey == "" { + return nil, ErrUnauthorized + } + callerID := callerTokenID(callerKey) + a := &RequestAuth{ + UseConfigToken: false, + CallerID: callerID, + resolver: r, + TriedAccounts: map[string]bool{}, + } + if r == nil || r.Store == nil || !r.Store.HasAPIKey(callerKey) { + a.DeepSeekToken = callerKey + } + return a, nil +} + func WithAuth(ctx context.Context, a *RequestAuth) context.Context { return context.WithValue(ctx, authCtxKey, a) } @@ -158,3 +189,12 @@ func extractCallerToken(req *http.Request) string { } return strings.TrimSpace(req.Header.Get("x-api-key")) } + +func callerTokenID(token string) string { + token = strings.TrimSpace(token) + if token == "" { + return "" + } + sum := sha256.Sum256([]byte(token)) + return "caller:" + hex.EncodeToString(sum[:8]) +} diff --git a/internal/auth/request_test.go b/internal/auth/request_test.go index 1d568f3..c292856 100644 --- a/internal/auth/request_test.go +++ b/internal/auth/request_test.go @@ -37,6 +37,9 @@ func TestDetermineWithXAPIKeyUsesDirectToken(t *testing.T) { if auth.DeepSeekToken != "direct-token" { t.Fatalf("unexpected token: %q", auth.DeepSeekToken) } + if auth.CallerID == "" { + t.Fatalf("expected caller id to be populated") + } } func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) { @@ -58,6 +61,44 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) { if auth.DeepSeekToken != "account-token" { t.Fatalf("unexpected account token: %q", auth.DeepSeekToken) } + if auth.CallerID == "" { + t.Fatalf("expected caller id to be populated") + } +} + +func TestDetermineCallerWithManagedKeySkipsAccountAcquire(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil) + req.Header.Set("x-api-key", "managed-key") + + a, 
err := r.DetermineCaller(req) + if err != nil { + t.Fatalf("determine caller failed: %v", err) + } + if a.CallerID == "" { + t.Fatalf("expected caller id to be populated") + } + if a.UseConfigToken { + t.Fatalf("expected no config-token lease for caller-only auth") + } + if a.AccountID != "" { + t.Fatalf("expected empty account id, got %q", a.AccountID) + } +} + +func TestCallerTokenIDStable(t *testing.T) { + a := callerTokenID("token-a") + b := callerTokenID("token-a") + c := callerTokenID("token-b") + if a == "" || b == "" || c == "" { + t.Fatalf("expected non-empty caller ids") + } + if a != b { + t.Fatalf("expected stable caller id, got %q and %q", a, b) + } + if a == c { + t.Fatalf("expected different caller id for different tokens") + } } func TestDetermineMissingToken(t *testing.T) { @@ -72,3 +113,16 @@ func TestDetermineMissingToken(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +func TestDetermineCallerMissingToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil) + + _, err := r.DetermineCaller(req) + if err == nil { + t.Fatal("expected unauthorized error") + } + if err != ErrUnauthorized { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 691df6d..d391462 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "encoding/json" "errors" + "fmt" "log/slog" "os" "path/filepath" @@ -61,11 +62,33 @@ type Config struct { Accounts []Account `json:"accounts,omitempty"` ClaudeMapping map[string]string `json:"claude_mapping,omitempty"` ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"` + ModelAliases map[string]string `json:"model_aliases,omitempty"` + Compat CompatConfig `json:"compat,omitempty"` + Toolcall ToolcallConfig `json:"toolcall,omitempty"` + Responses ResponsesConfig `json:"responses,omitempty"` + Embeddings EmbeddingsConfig 
`json:"embeddings,omitempty"` VercelSyncHash string `json:"_vercel_sync_hash,omitempty"` VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"` AdditionalFields map[string]any `json:"-"` } +type CompatConfig struct { + WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"` +} + +type ToolcallConfig struct { + Mode string `json:"mode,omitempty"` + EarlyEmitConfidence string `json:"early_emit_confidence,omitempty"` +} + +type ResponsesConfig struct { + StoreTTLSeconds int `json:"store_ttl_seconds,omitempty"` +} + +type EmbeddingsConfig struct { + Provider string `json:"provider,omitempty"` +} + func (c Config) MarshalJSON() ([]byte, error) { m := map[string]any{} for k, v := range c.AdditionalFields { @@ -83,6 +106,21 @@ func (c Config) MarshalJSON() ([]byte, error) { if len(c.ClaudeModelMap) > 0 { m["claude_model_mapping"] = c.ClaudeModelMap } + if len(c.ModelAliases) > 0 { + m["model_aliases"] = c.ModelAliases + } + if c.Compat.WideInputStrictOutput != nil { + m["compat"] = c.Compat + } + if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" { + m["toolcall"] = c.Toolcall + } + if c.Responses.StoreTTLSeconds > 0 { + m["responses"] = c.Responses + } + if strings.TrimSpace(c.Embeddings.Provider) != "" { + m["embeddings"] = c.Embeddings + } if c.VercelSyncHash != "" { m["_vercel_sync_hash"] = c.VercelSyncHash } @@ -101,17 +139,49 @@ func (c *Config) UnmarshalJSON(b []byte) error { for k, v := range raw { switch k { case "keys": - _ = json.Unmarshal(v, &c.Keys) + if err := json.Unmarshal(v, &c.Keys); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "accounts": - _ = json.Unmarshal(v, &c.Accounts) + if err := json.Unmarshal(v, &c.Accounts); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "claude_mapping": - _ = json.Unmarshal(v, &c.ClaudeMapping) + if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil { + return fmt.Errorf("invalid field %q: %w", 
k, err) + } case "claude_model_mapping": - _ = json.Unmarshal(v, &c.ClaudeModelMap) + if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "model_aliases": + if err := json.Unmarshal(v, &c.ModelAliases); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "compat": + if err := json.Unmarshal(v, &c.Compat); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "toolcall": + if err := json.Unmarshal(v, &c.Toolcall); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "responses": + if err := json.Unmarshal(v, &c.Responses); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "embeddings": + if err := json.Unmarshal(v, &c.Embeddings); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "_vercel_sync_hash": - _ = json.Unmarshal(v, &c.VercelSyncHash) + if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "_vercel_sync_time": - _ = json.Unmarshal(v, &c.VercelSyncTime) + if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } default: var anyVal any if err := json.Unmarshal(v, &anyVal); err == nil { @@ -124,10 +194,17 @@ func (c *Config) UnmarshalJSON(b []byte) error { func (c Config) Clone() Config { clone := Config{ - Keys: slices.Clone(c.Keys), - Accounts: slices.Clone(c.Accounts), - ClaudeMapping: cloneStringMap(c.ClaudeMapping), - ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), + Keys: slices.Clone(c.Keys), + Accounts: slices.Clone(c.Accounts), + ClaudeMapping: cloneStringMap(c.ClaudeMapping), + ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), + ModelAliases: cloneStringMap(c.ModelAliases), + Compat: CompatConfig{ + WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput), + }, + Toolcall: c.Toolcall, + Responses: c.Responses, + Embeddings: 
// cloneBoolPtr deep-copies an optional boolean flag: nil stays nil,
// otherwise a pointer to a fresh copy of *in is returned so the clone
// never aliases the original value.
func cloneBoolPtr(in *bool) *bool {
	if in == nil {
		return nil
	}
	out := *in
	return &out
}
// decodeConfigBase64 decodes raw by trying, in order, the standard,
// raw-standard (unpadded), URL-safe, and raw-URL-safe base64 alphabets,
// returning the first successful decode. When every decoder fails, the
// last decoder's error is returned.
func decodeConfigBase64(raw string) ([]byte, error) {
	var lastErr error
	for _, enc := range []*base64.Encoding{
		base64.StdEncoding,
		base64.RawStdEncoding,
		base64.URLEncoding,
		base64.RawURLEncoding,
	} {
		decoded, err := enc.DecodeString(raw)
		if err == nil {
			return decoded, nil
		}
		lastErr = err
	}
	// Defensive: with a non-empty encoding list lastErr is always set,
	// but keep a fallback error rather than returning (nil, nil).
	if lastErr == nil {
		lastErr = errors.New("base64 decode failed")
	}
	return nil, lastErr
}
level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence)) + if level == "" { + return "high" + } + return level +} + +func (s *Store) ResponsesStoreTTLSeconds() int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Responses.StoreTTLSeconds > 0 { + return s.cfg.Responses.StoreTTLSeconds + } + return 900 +} + +func (s *Store) EmbeddingsProvider() string { + s.mu.RLock() + defer s.mu.RUnlock() + return strings.TrimSpace(s.cfg.Embeddings.Provider) +} diff --git a/internal/config/config_edge_test.go b/internal/config/config_edge_test.go new file mode 100644 index 0000000..1138867 --- /dev/null +++ b/internal/config/config_edge_test.go @@ -0,0 +1,478 @@ +package config + +import ( + "encoding/base64" + "encoding/json" + "strings" + "testing" +) + +// ─── GetModelConfig edge cases ─────────────────────────────────────── + +func TestGetModelConfigDeepSeekChat(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-chat") + if !ok { + t.Fatal("expected ok for deepseek-chat") + } + if thinking || search { + t.Fatalf("expected no thinking/search for deepseek-chat, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigDeepSeekReasoner(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-reasoner") + if !ok { + t.Fatal("expected ok for deepseek-reasoner") + } + if !thinking || search { + t.Fatalf("expected thinking=true search=false, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigDeepSeekChatSearch(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-chat-search") + if !ok { + t.Fatal("expected ok for deepseek-chat-search") + } + if thinking || !search { + t.Fatalf("expected thinking=false search=true, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigDeepSeekReasonerSearch(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-reasoner-search") + if !ok { + t.Fatal("expected ok for deepseek-reasoner-search") + } + if 
!thinking || !search { + t.Fatalf("expected both true, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigCaseInsensitive(t *testing.T) { + thinking, search, ok := GetModelConfig("DeepSeek-Chat") + if !ok { + t.Fatal("expected ok for case-insensitive deepseek-chat") + } + if thinking || search { + t.Fatalf("expected no thinking/search for case-insensitive deepseek-chat") + } +} + +func TestGetModelConfigUnknownModel(t *testing.T) { + _, _, ok := GetModelConfig("gpt-4") + if ok { + t.Fatal("expected not ok for unknown model") + } +} + +func TestGetModelConfigEmpty(t *testing.T) { + _, _, ok := GetModelConfig("") + if ok { + t.Fatal("expected not ok for empty model") + } +} + +// ─── lower function ────────────────────────────────────────────────── + +func TestLowerFunction(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"Hello", "hello"}, + {"ALLCAPS", "allcaps"}, + {"already-lower", "already-lower"}, + {"Mixed-CASE-123", "mixed-case-123"}, + {"", ""}, + } + for _, tc := range tests { + got := lower(tc.input) + if got != tc.expected { + t.Errorf("lower(%q) = %q, want %q", tc.input, got, tc.expected) + } + } +} + +// ─── Config.MarshalJSON / UnmarshalJSON roundtrip ──────────────────── + +func TestConfigJSONRoundtrip(t *testing.T) { + cfg := Config{ + Keys: []string{"key1", "key2"}, + Accounts: []Account{{Email: "user@example.com", Password: "pass", Token: "tok"}}, + ClaudeMapping: map[string]string{ + "fast": "deepseek-chat", + "slow": "deepseek-reasoner", + }, + VercelSyncHash: "hash123", + VercelSyncTime: 1234567890, + AdditionalFields: map[string]any{ + "custom_field": "custom_value", + }, + } + + data, err := cfg.MarshalJSON() + if err != nil { + t.Fatalf("marshal error: %v", err) + } + + var decoded Config + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + if len(decoded.Keys) != 2 || decoded.Keys[0] != "key1" { + t.Fatalf("unexpected keys: %#v", 
decoded.Keys) + } + if len(decoded.Accounts) != 1 || decoded.Accounts[0].Email != "user@example.com" { + t.Fatalf("unexpected accounts: %#v", decoded.Accounts) + } + if decoded.ClaudeMapping["fast"] != "deepseek-chat" { + t.Fatalf("unexpected claude mapping: %#v", decoded.ClaudeMapping) + } + if decoded.VercelSyncHash != "hash123" { + t.Fatalf("unexpected vercel sync hash: %q", decoded.VercelSyncHash) + } + if decoded.AdditionalFields["custom_field"] != "custom_value" { + t.Fatalf("unexpected additional fields: %#v", decoded.AdditionalFields) + } +} + +func TestConfigUnmarshalJSONPreservesUnknownFields(t *testing.T) { + raw := `{"keys":["k1"],"accounts":[],"my_custom_field":"hello","number_field":42}` + var cfg Config + if err := json.Unmarshal([]byte(raw), &cfg); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if cfg.AdditionalFields["my_custom_field"] != "hello" { + t.Fatalf("expected custom field preserved, got %#v", cfg.AdditionalFields) + } + // number_field should also be preserved + if cfg.AdditionalFields["number_field"] != float64(42) { + t.Fatalf("expected number field preserved, got %#v", cfg.AdditionalFields["number_field"]) + } +} + +// ─── Config.Clone ──────────────────────────────────────────────────── + +func TestConfigCloneIsDeepCopy(t *testing.T) { + cfg := Config{ + Keys: []string{"key1"}, + Accounts: []Account{{Email: "user@test.com", Token: "token"}}, + ClaudeMapping: map[string]string{ + "fast": "deepseek-chat", + }, + AdditionalFields: map[string]any{"custom": "value"}, + } + + cloned := cfg.Clone() + + // Modify original + cfg.Keys[0] = "modified" + cfg.Accounts[0].Email = "modified@test.com" + cfg.ClaudeMapping["fast"] = "modified-model" + + // Cloned should not be affected + if cloned.Keys[0] != "key1" { + t.Fatalf("clone keys was affected by original change: %#v", cloned.Keys) + } + if cloned.Accounts[0].Email != "user@test.com" { + t.Fatalf("clone accounts was affected: %#v", cloned.Accounts) + } + if 
cloned.ClaudeMapping["fast"] != "deepseek-chat" { + t.Fatalf("clone claude mapping was affected: %#v", cloned.ClaudeMapping) + } +} + +func TestConfigCloneNilMaps(t *testing.T) { + cfg := Config{ + Keys: []string{"k"}, + Accounts: nil, + } + cloned := cfg.Clone() + if len(cloned.Keys) != 1 { + t.Fatalf("unexpected keys length: %d", len(cloned.Keys)) + } + if cloned.Accounts != nil { + t.Fatalf("expected nil accounts in clone, got %#v", cloned.Accounts) + } +} + +// ─── Account.Identifier edge cases ─────────────────────────────────── + +func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) { + acc := Account{Mobile: "13800138000", Token: "tok"} + if acc.Identifier() != "13800138000" { + t.Fatalf("expected mobile identifier, got %q", acc.Identifier()) + } +} + +func TestAccountIdentifierPreferenceEmailOverMobile(t *testing.T) { + acc := Account{Email: "user@test.com", Mobile: "13800138000"} + if acc.Identifier() != "user@test.com" { + t.Fatalf("expected email identifier, got %q", acc.Identifier()) + } +} + +func TestAccountIdentifierEmptyAccount(t *testing.T) { + acc := Account{} + if acc.Identifier() != "" { + t.Fatalf("expected empty identifier for empty account, got %q", acc.Identifier()) + } +} + +// ─── normalizeConfigInput ──────────────────────────────────────────── + +func TestNormalizeConfigInputStripsQuotes(t *testing.T) { + got := normalizeConfigInput(`"base64:abc"`) + if strings.HasPrefix(got, `"`) || strings.HasSuffix(got, `"`) { + t.Fatalf("expected quotes stripped, got %q", got) + } +} + +func TestNormalizeConfigInputStripsSingleQuotes(t *testing.T) { + got := normalizeConfigInput("'some-value'") + if strings.HasPrefix(got, "'") || strings.HasSuffix(got, "'") { + t.Fatalf("expected single quotes stripped, got %q", got) + } +} + +func TestNormalizeConfigInputTrimsWhitespace(t *testing.T) { + got := normalizeConfigInput(" hello ") + if got != "hello" { + t.Fatalf("expected trimmed, got %q", got) + } +} + +// ─── parseConfigString edge cases 
──────────────────────────────────── + +func TestParseConfigStringPlainJSON(t *testing.T) { + cfg, err := parseConfigString(`{"keys":["k1"],"accounts":[]}`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "k1" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestParseConfigStringBase64Prefix(t *testing.T) { + rawJSON := `{"keys":["base64-key"],"accounts":[]}` + b64 := base64.StdEncoding.EncodeToString([]byte(rawJSON)) + cfg, err := parseConfigString("base64:" + b64) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "base64-key" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestParseConfigStringInvalidBase64(t *testing.T) { + _, err := parseConfigString("base64:!!!invalid!!!") + if err == nil { + t.Fatal("expected error for invalid base64") + } +} + +func TestParseConfigStringEmptyString(t *testing.T) { + _, err := parseConfigString("") + if err == nil { + t.Fatal("expected error for empty string") + } +} + +// ─── Store methods ─────────────────────────────────────────────────── + +func TestStoreSnapshotReturnsClone(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com","token":"t1"}]}`) + store := LoadStore() + snap := store.Snapshot() + snap.Keys[0] = "modified" + if store.Keys()[0] != "k1" { + t.Fatal("snapshot modification should not affect store") + } +} + +func TestStoreHasAPIKeyMultipleKeys(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["key1","key2","key3"],"accounts":[]}`) + store := LoadStore() + if !store.HasAPIKey("key1") { + t.Fatal("expected key1 found") + } + if !store.HasAPIKey("key2") { + t.Fatal("expected key2 found") + } + if !store.HasAPIKey("key3") { + t.Fatal("expected key3 found") + } + if store.HasAPIKey("nonexistent") { + t.Fatal("expected nonexistent key not found") + } +} + +func TestStoreFindAccountNotFound(t *testing.T) { + 
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com"}]}`) + store := LoadStore() + _, ok := store.FindAccount("nonexistent@test.com") + if ok { + t.Fatal("expected account not found") + } +} + +func TestStoreCompatWideInputStrictOutputDefaultTrue(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + if !store.CompatWideInputStrictOutput() { + t.Fatal("expected default wide_input_strict_output=true when unset") + } +} + +func TestStoreCompatWideInputStrictOutputCanDisable(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[],"compat":{"wide_input_strict_output":false}}`) + store := LoadStore() + if store.CompatWideInputStrictOutput() { + t.Fatal("expected wide_input_strict_output=false when explicitly configured") + } + + snap := store.Snapshot() + data, err := snap.MarshalJSON() + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + var out map[string]any + if err := json.Unmarshal(data, &out); err != nil { + t.Fatalf("decode failed: %v", err) + } + rawCompat, ok := out["compat"].(map[string]any) + if !ok { + t.Fatalf("expected compat in marshaled output, got %#v", out) + } + if rawCompat["wide_input_strict_output"] != false { + t.Fatalf("expected explicit false in compat, got %#v", rawCompat) + } +} + +func TestStoreIsEnvBacked(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + if !store.IsEnvBacked() { + t.Fatal("expected env-backed store") + } +} + +func TestStoreReplace(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + newCfg := Config{ + Keys: []string{"new-key"}, + Accounts: []Account{{Email: "new@test.com"}}, + } + if err := store.Replace(newCfg); err != nil { + t.Fatalf("replace error: %v", err) + } + if !store.HasAPIKey("new-key") { + t.Fatal("expected new key after replace") + } + if store.HasAPIKey("k1") { + t.Fatal("expected old 
key removed after replace") + } +} + +func TestStoreUpdate(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + err := store.Update(func(cfg *Config) error { + cfg.Keys = append(cfg.Keys, "k2") + return nil + }) + if err != nil { + t.Fatalf("update error: %v", err) + } + if !store.HasAPIKey("k2") { + t.Fatal("expected k2 after update") + } +} + +func TestStoreClaudeMapping(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`) + store := LoadStore() + mapping := store.ClaudeMapping() + if mapping["fast"] != "deepseek-chat" { + t.Fatalf("unexpected fast mapping: %q", mapping["fast"]) + } + if mapping["slow"] != "deepseek-reasoner" { + t.Fatalf("unexpected slow mapping: %q", mapping["slow"]) + } +} + +func TestStoreClaudeMappingEmpty(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`) + store := LoadStore() + mapping := store.ClaudeMapping() + // Even without config mapping, there are defaults + if mapping == nil { + t.Fatal("expected non-nil mapping (may contain defaults)") + } +} + +func TestStoreSetVercelSync(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`) + store := LoadStore() + if err := store.SetVercelSync("hash123", 1234567890); err != nil { + t.Fatalf("setVercelSync error: %v", err) + } + snap := store.Snapshot() + if snap.VercelSyncHash != "hash123" || snap.VercelSyncTime != 1234567890 { + t.Fatalf("unexpected vercel sync: hash=%q time=%d", snap.VercelSyncHash, snap.VercelSyncTime) + } +} + +func TestStoreExportJSONAndBase64(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["export-key"],"accounts":[]}`) + store := LoadStore() + jsonStr, b64Str, err := store.ExportJSONAndBase64() + if err != nil { + t.Fatalf("export error: %v", err) + } + if !strings.Contains(jsonStr, "export-key") { + t.Fatalf("expected JSON to contain key: %q", jsonStr) + } + 
decoded, err := base64.StdEncoding.DecodeString(b64Str) + if err != nil { + t.Fatalf("base64 decode error: %v", err) + } + if !strings.Contains(string(decoded), "export-key") { + t.Fatalf("expected base64-decoded to contain key: %q", string(decoded)) + } +} + +// ─── OpenAIModelsResponse / ClaudeModelsResponse ───────────────────── + +func TestOpenAIModelsResponse(t *testing.T) { + resp := OpenAIModelsResponse() + if resp["object"] != "list" { + t.Fatalf("unexpected object: %v", resp["object"]) + } + data, ok := resp["data"].([]ModelInfo) + if !ok { + t.Fatalf("unexpected data type: %T", resp["data"]) + } + if len(data) == 0 { + t.Fatal("expected non-empty models list") + } +} + +func TestClaudeModelsResponse(t *testing.T) { + resp := ClaudeModelsResponse() + if resp["object"] != "list" { + t.Fatalf("unexpected object: %v", resp["object"]) + } + data, ok := resp["data"].([]ModelInfo) + if !ok { + t.Fatalf("unexpected data type: %T", resp["data"]) + } + if len(data) == 0 { + t.Fatal("expected non-empty models list") + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 58a8a2a..a409fd7 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "encoding/base64" "strings" "testing" ) @@ -70,3 +71,53 @@ func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T) t.Fatalf("expected find by old identifier alias") } } + +func TestLoadStoreRejectsInvalidFieldType(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":"not-array","accounts":[]}`) + store := LoadStore() + if len(store.Keys()) != 0 || len(store.Accounts()) != 0 { + t.Fatalf("expected empty store when config type is invalid") + } +} + +func TestParseConfigStringSupportsQuotedBase64Prefix(t *testing.T) { + rawJSON := `{"keys":["k1"],"accounts":[{"email":"u@example.com","password":"p"}]}` + b64 := base64.StdEncoding.EncodeToString([]byte(rawJSON)) + cfg, err := parseConfigString(`"base64:` 
+ b64 + `"`) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "k1" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestParseConfigStringSupportsRawURLBase64(t *testing.T) { + rawJSON := `{"keys":["k-url"],"accounts":[]}` + b64 := base64.RawURLEncoding.EncodeToString([]byte(rawJSON)) + cfg, err := parseConfigString(b64) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "k-url" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestLoadConfigOnVercelWithoutConfigFileFallsBackToMemory(t *testing.T) { + t.Setenv("VERCEL", "1") + t.Setenv("DS2API_CONFIG_JSON", "") + t.Setenv("CONFIG_JSON", "") + t.Setenv("DS2API_CONFIG_PATH", "testdata/does-not-exist.json") + + cfg, fromEnv, err := loadConfig() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !fromEnv { + t.Fatalf("expected fromEnv=true for vercel fallback") + } + if len(cfg.Keys) != 0 || len(cfg.Accounts) != 0 { + t.Fatalf("expected empty bootstrap config, got keys=%d accounts=%d", len(cfg.Keys), len(cfg.Accounts)) + } +} diff --git a/internal/config/model_alias_test.go b/internal/config/model_alias_test.go new file mode 100644 index 0000000..89e74b0 --- /dev/null +++ b/internal/config/model_alias_test.go @@ -0,0 +1,44 @@ +package config + +import "testing" + +func TestResolveModelDirectDeepSeek(t *testing.T) { + got, ok := ResolveModel(nil, "deepseek-chat") + if !ok || got != "deepseek-chat" { + t.Fatalf("expected deepseek-chat, got ok=%v model=%q", ok, got) + } +} + +func TestResolveModelAlias(t *testing.T) { + got, ok := ResolveModel(nil, "gpt-4.1") + if !ok || got != "deepseek-chat" { + t.Fatalf("expected alias gpt-4.1 -> deepseek-chat, got ok=%v model=%q", ok, got) + } +} + +func TestResolveModelHeuristicReasoner(t *testing.T) { + got, ok := ResolveModel(nil, "o3-super") + if !ok || got != "deepseek-reasoner" { + t.Fatalf("expected heuristic reasoner, 
got ok=%v model=%q", ok, got) + } +} + +func TestResolveModelUnknown(t *testing.T) { + _, ok := ResolveModel(nil, "totally-custom-model") + if ok { + t.Fatal("expected unknown model to fail resolve") + } +} + +func TestClaudeModelsResponsePaginationFields(t *testing.T) { + resp := ClaudeModelsResponse() + if _, ok := resp["first_id"]; !ok { + t.Fatalf("expected first_id in response: %#v", resp) + } + if _, ok := resp["last_id"]; !ok { + t.Fatalf("expected last_id in response: %#v", resp) + } + if _, ok := resp["has_more"]; !ok { + t.Fatalf("expected has_more in response: %#v", resp) + } +} diff --git a/internal/config/models.go b/internal/config/models.go index 13fa63d..017a2ee 100644 --- a/internal/config/models.go +++ b/internal/config/models.go @@ -1,5 +1,7 @@ package config +import "strings" + type ModelInfo struct { ID string `json:"id"` Object string `json:"object"` @@ -71,6 +73,91 @@ func GetModelConfig(model string) (thinking bool, search bool, ok bool) { } } +func IsSupportedDeepSeekModel(model string) bool { + _, _, ok := GetModelConfig(model) + return ok +} + +func DefaultModelAliases() map[string]string { + return map[string]string{ + "gpt-4o": "deepseek-chat", + "gpt-4.1": "deepseek-chat", + "gpt-4.1-mini": "deepseek-chat", + "gpt-4.1-nano": "deepseek-chat", + "gpt-5": "deepseek-chat", + "gpt-5-mini": "deepseek-chat", + "gpt-5-codex": "deepseek-reasoner", + "o1": "deepseek-reasoner", + "o1-mini": "deepseek-reasoner", + "o3": "deepseek-reasoner", + "o3-mini": "deepseek-reasoner", + "claude-sonnet-4-5": "deepseek-chat", + "claude-haiku-4-5": "deepseek-chat", + "claude-opus-4-6": "deepseek-reasoner", + "claude-3-5-sonnet": "deepseek-chat", + "claude-3-5-haiku": "deepseek-chat", + "claude-3-opus": "deepseek-reasoner", + "gemini-2.5-pro": "deepseek-chat", + "gemini-2.5-flash": "deepseek-chat", + "llama-3.1-70b-instruct": "deepseek-chat", + "qwen-max": "deepseek-chat", + } +} + +func ResolveModel(store *Store, requested string) (string, bool) { + model := 
lower(strings.TrimSpace(requested)) + if model == "" { + return "", false + } + if IsSupportedDeepSeekModel(model) { + return model, true + } + aliases := DefaultModelAliases() + if store != nil { + for k, v := range store.ModelAliases() { + aliases[lower(strings.TrimSpace(k))] = lower(strings.TrimSpace(v)) + } + } + if mapped, ok := aliases[model]; ok && IsSupportedDeepSeekModel(mapped) { + return mapped, true + } + if strings.HasPrefix(model, "deepseek-") { + return "", false + } + + knownFamily := false + for _, prefix := range []string{ + "gpt-", "o1", "o3", "claude-", "gemini-", "llama-", "qwen-", "mistral-", "command-", + } { + if strings.HasPrefix(model, prefix) { + knownFamily = true + break + } + } + if !knownFamily { + return "", false + } + + useReasoner := strings.Contains(model, "reason") || + strings.Contains(model, "reasoner") || + strings.HasPrefix(model, "o1") || + strings.HasPrefix(model, "o3") || + strings.Contains(model, "opus") || + strings.Contains(model, "r1") + useSearch := strings.Contains(model, "search") + + switch { + case useReasoner && useSearch: + return "deepseek-reasoner-search", true + case useReasoner: + return "deepseek-reasoner", true + case useSearch: + return "deepseek-chat-search", true + default: + return "deepseek-chat", true + } +} + func lower(s string) string { b := []byte(s) for i, c := range b { @@ -85,6 +172,28 @@ func OpenAIModelsResponse() map[string]any { return map[string]any{"object": "list", "data": DeepSeekModels} } -func ClaudeModelsResponse() map[string]any { - return map[string]any{"object": "list", "data": ClaudeModels} +func OpenAIModelByID(store *Store, id string) (ModelInfo, bool) { + canonical, ok := ResolveModel(store, id) + if !ok { + return ModelInfo{}, false + } + for _, model := range DeepSeekModels { + if model.ID == canonical { + return model, true + } + } + return ModelInfo{}, false +} + +func ClaudeModelsResponse() map[string]any { + resp := map[string]any{"object": "list", "data": 
ClaudeModels} + if len(ClaudeModels) > 0 { + resp["first_id"] = ClaudeModels[0].ID + resp["last_id"] = ClaudeModels[len(ClaudeModels)-1].ID + } else { + resp["first_id"] = nil + resp["last_id"] = nil + } + resp["has_more"] = false + return resp } diff --git a/internal/deepseek/deepseek_edge_test.go b/internal/deepseek/deepseek_edge_test.go new file mode 100644 index 0000000..92e6952 --- /dev/null +++ b/internal/deepseek/deepseek_edge_test.go @@ -0,0 +1,165 @@ +package deepseek + +import ( + "context" + "testing" +) + +// ─── toFloat64 edge cases ──────────────────────────────────────────── + +func TestToFloat64FromFloat64(t *testing.T) { + if got := toFloat64(float64(3.14), 0); got != 3.14 { + t.Fatalf("expected 3.14, got %f", got) + } +} + +func TestToFloat64FromInt(t *testing.T) { + if got := toFloat64(42, 0); got != 42.0 { + t.Fatalf("expected 42.0, got %f", got) + } +} + +func TestToFloat64FromInt64(t *testing.T) { + if got := toFloat64(int64(100), 0); got != 100.0 { + t.Fatalf("expected 100.0, got %f", got) + } +} + +func TestToFloat64FromStringDefault(t *testing.T) { + if got := toFloat64("42", 99.0); got != 99.0 { + t.Fatalf("expected default 99.0, got %f", got) + } +} + +func TestToFloat64FromNilDefault(t *testing.T) { + if got := toFloat64(nil, 5.5); got != 5.5 { + t.Fatalf("expected default 5.5, got %f", got) + } +} + +func TestToFloat64FromBoolDefault(t *testing.T) { + if got := toFloat64(true, 1.0); got != 1.0 { + t.Fatalf("expected default 1.0, got %f", got) + } +} + +// ─── toInt64 edge cases ────────────────────────────────────────────── + +func TestToInt64FromFloat64(t *testing.T) { + if got := toInt64(float64(42.9), 0); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestToInt64FromInt(t *testing.T) { + if got := toInt64(42, 0); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestToInt64FromInt64(t *testing.T) { + if got := toInt64(int64(100), 0); got != 100 { + t.Fatalf("expected 100, got %d", got) + } +} + 
+func TestToInt64FromStringDefault(t *testing.T) { + if got := toInt64("42", 99); got != 99 { + t.Fatalf("expected default 99, got %d", got) + } +} + +func TestToInt64FromNilDefault(t *testing.T) { + if got := toInt64(nil, 7); got != 7 { + t.Fatalf("expected default 7, got %d", got) + } +} + +// ─── BuildPowHeader edge cases ─────────────────────────────────────── + +func TestBuildPowHeaderBasicChallenge(t *testing.T) { + challenge := map[string]any{ + "algorithm": "DeepSeekHashV1", + "challenge": "abc123", + "salt": "salt456", + "signature": "sig789", + "target_path": "/path", + } + result, err := BuildPowHeader(challenge, 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == "" { + t.Fatal("expected non-empty result") + } +} + +func TestBuildPowHeaderEmptyChallenge(t *testing.T) { + result, err := BuildPowHeader(map[string]any{}, 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Should produce a base64 encoded JSON with nil values + if result == "" { + t.Fatal("expected non-empty result for empty challenge") + } +} + +// ─── PowSolver pool size ───────────────────────────────────────────── + +func TestPowPoolSizeFromEnvDefault(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "") + got := powPoolSizeFromEnv() + if got < 1 { + t.Fatalf("expected positive default pool size, got %d", got) + } +} + +func TestPowPoolSizeFromEnvInvalid(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "abc") + got := powPoolSizeFromEnv() + if got < 1 { + t.Fatalf("expected positive default for invalid, got %d", got) + } +} + +func TestPowPoolSizeFromEnvSpecificValue(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "5") + got := powPoolSizeFromEnv() + if got != 5 { + t.Fatalf("expected 5, got %d", got) + } +} + +// ─── NewClient ─────────────────────────────────────────────────────── + +func TestNewClientInitialState(t *testing.T) { + client := NewClient(nil, nil) + if client.powSolver == nil { + t.Fatal("expected powSolver to be 
initialized") + } +} + +func TestNewClientPreloadPowIdempotent(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "1") + client := NewClient(nil, nil) + if err := client.PreloadPow(context.Background()); err != nil { + t.Fatalf("first preload failed: %v", err) + } + if err := client.PreloadPow(context.Background()); err != nil { + t.Fatalf("second preload failed: %v", err) + } +} + +// ─── PowSolver init and module pool ────────────────────────────────── + +func TestPowSolverPoolSizeMatchesEnv(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "2") + solver := NewPowSolver("test.wasm") + if err := solver.init(context.Background()); err != nil { + t.Fatalf("init failed: %v", err) + } + if cap(solver.pool) != 2 { + t.Fatalf("expected pool capacity 2, got %d", cap(solver.pool)) + } +} diff --git a/internal/server/router.go b/internal/server/router.go index c6339fb..a81f0cb 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -92,7 +92,7 @@ func cors(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return diff --git a/internal/sse/consumer_edge_test.go b/internal/sse/consumer_edge_test.go new file mode 100644 index 0000000..8f78f01 --- /dev/null +++ b/internal/sse/consumer_edge_test.go @@ -0,0 +1,140 @@ +package sse + +import ( + "io" + "net/http" + "strings" + "testing" +) + +// ─── CollectStream edge cases ──────────────────────────────────────── + +func makeHTTPResponse(body string) *http.Response { + return &http.Response{ + StatusCode: http.StatusOK, + Header: 
make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func TestCollectStreamEmpty(t *testing.T) { + resp := makeHTTPResponse("") + result := CollectStream(resp, false, false) + if result.Text != "" || result.Thinking != "" { + t.Fatalf("expected empty result, got text=%q think=%q", result.Text, result.Thinking) + } +} + +func TestCollectStreamTextOnly(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\" World\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + if result.Text != "Hello World" { + t.Fatalf("expected 'Hello World', got %q", result.Text) + } + if result.Thinking != "" { + t.Fatalf("expected no thinking, got %q", result.Thinking) + } +} + +func TestCollectStreamThinkingAndText(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/thinking_content\",\"v\":\"Thinking...\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"Answer\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "Thinking..." 
{ + t.Fatalf("expected 'Thinking...', got %q", result.Thinking) + } + if result.Text != "Answer" { + t.Fatalf("expected 'Answer', got %q", result.Text) + } +} + +func TestCollectStreamOnlyThinking(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/thinking_content\",\"v\":\"Only thinking\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "Only thinking" { + t.Fatalf("expected 'Only thinking', got %q", result.Thinking) + } + if result.Text != "" { + t.Fatalf("expected empty text, got %q", result.Text) + } +} + +func TestCollectStreamSkipsInvalidLines(t *testing.T) { + resp := makeHTTPResponse( + "event: comment\n" + + "data: invalid_json\n" + + "data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + if result.Text != "valid" { + t.Fatalf("expected 'valid', got %q", result.Text) + } +} + +func TestCollectStreamWithFragments(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"Think\"}]}\n" + + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"Done\"}]}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "Think" { + t.Fatalf("expected 'Think' thinking, got %q", result.Thinking) + } + if result.Text != "Done" { + t.Fatalf("expected 'Done' text, got %q", result.Text) + } +} + +func TestCollectStreamWithCitation(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"[citation:1] cited text\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\" more\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + // CollectStream does NOT filter citations (that's done by the adapters) + // So citations are passed through as-is + if 
!strings.Contains(result.Text, "[citation:1]") { + t.Fatalf("expected citations to be passed through, got %q", result.Text) + } + if result.Text != "Hello[citation:1] cited text more" { + t.Fatalf("expected full text with citation, got %q", result.Text) + } +} + +func TestCollectStreamMultipleThinkingChunks(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/thinking_content\",\"v\":\"part1\"}\n" + + "data: {\"p\":\"response/thinking_content\",\"v\":\" part2\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"answer\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "part1 part2" { + t.Fatalf("expected 'part1 part2', got %q", result.Thinking) + } +} + +func TestCollectStreamStatusFinished(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" + + "data: {\"p\":\"response/status\",\"v\":\"FINISHED\"}\n", + ) + result := CollectStream(resp, false, false) + if result.Text != "Hello" { + t.Fatalf("expected 'Hello', got %q", result.Text) + } +} diff --git a/internal/sse/line_edge_test.go b/internal/sse/line_edge_test.go new file mode 100644 index 0000000..2ae53a6 --- /dev/null +++ b/internal/sse/line_edge_test.go @@ -0,0 +1,70 @@ +package sse + +import "testing" + +func TestParseDeepSeekContentLineNotParsed(t *testing.T) { + res := ParseDeepSeekContentLine([]byte("not a data line"), false, "text") + if res.Parsed { + t.Fatal("expected not parsed") + } + if res.NextType != "text" { + t.Fatalf("expected nextType preserved, got %q", res.NextType) + } +} + +func TestParseDeepSeekContentLinePreservesNextType(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/thinking_content","v":"think"}`), true, "thinking") + if !res.Parsed || res.Stop { + t.Fatalf("expected parsed non-stop: %#v", res) + } + if len(res.Parts) != 1 || res.Parts[0].Type != "thinking" { + t.Fatalf("unexpected parts: %#v", res.Parts) + } +} + +func 
TestParseDeepSeekContentLineFragmentSwitchType(t *testing.T) { + res := ParseDeepSeekContentLine( + []byte(`data: {"p":"response/fragments","o":"APPEND","v":[{"type":"RESPONSE","content":"hi"}]}`), + true, "thinking", + ) + if !res.Parsed || res.Stop { + t.Fatalf("expected parsed non-stop: %#v", res) + } + if res.NextType != "text" { + t.Fatalf("expected nextType text after RESPONSE fragment, got %q", res.NextType) + } +} + +func TestParseDeepSeekContentLineContentFilterMessage(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"code":"content_filter"}`), false, "text") + if !res.ContentFilter { + t.Fatal("expected content filter flag") + } + if res.ErrorMessage == "" { + t.Fatal("expected error message on content filter") + } +} + +func TestParseDeepSeekContentLineErrorObjectFormat(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"error":{"message":"rate limit","code":429}}`), false, "text") + if !res.Parsed || !res.Stop { + t.Fatalf("expected parsed stop: %#v", res) + } + if res.ErrorMessage == "" { + t.Fatal("expected non-empty error message") + } +} + +func TestParseDeepSeekContentLineInvalidJSON(t *testing.T) { + res := ParseDeepSeekContentLine([]byte("data: {broken"), false, "text") + if res.Parsed { + t.Fatal("expected not parsed for broken JSON") + } +} + +func TestParseDeepSeekContentLineEmptyBytes(t *testing.T) { + res := ParseDeepSeekContentLine([]byte{}, false, "text") + if res.Parsed { + t.Fatal("expected not parsed for empty bytes") + } +} diff --git a/internal/sse/parser_edge_test.go b/internal/sse/parser_edge_test.go new file mode 100644 index 0000000..c851c1f --- /dev/null +++ b/internal/sse/parser_edge_test.go @@ -0,0 +1,631 @@ +package sse + +import "testing" + +// ─── ParseDeepSeekSSELine edge cases ───────────────────────────────── + +func TestParseDeepSeekSSELineEmptyLine(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte("")) + if ok { + t.Fatal("expected not parsed for empty line") + } +} + +func 
TestParseDeepSeekSSELineNoDataPrefix(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte("event: message")) + if ok { + t.Fatal("expected not parsed for non-data line") + } +} + +func TestParseDeepSeekSSELineInvalidJSON(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte("data: {invalid json")) + if ok { + t.Fatal("expected not parsed for invalid JSON") + } +} + +func TestParseDeepSeekSSELineWhitespaceOnly(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte(" ")) + if ok { + t.Fatal("expected not parsed for whitespace-only line") + } +} + +func TestParseDeepSeekSSELineDataWithExtraSpaces(t *testing.T) { + chunk, done, ok := ParseDeepSeekSSELine([]byte(`data: {"v":"hello"} `)) + if !ok || done { + t.Fatalf("expected parsed chunk for spaced data line") + } + if chunk["v"] != "hello" { + t.Fatalf("unexpected chunk: %#v", chunk) + } +} + +// ─── shouldSkipPath edge cases ─────────────────────────────────────── + +func TestShouldSkipPathQuasiStatus(t *testing.T) { + if !shouldSkipPath("response/quasi_status") { + t.Fatal("expected skip for quasi_status path") + } +} + +func TestShouldSkipPathElapsedSecs(t *testing.T) { + if !shouldSkipPath("response/elapsed_secs") { + t.Fatal("expected skip for elapsed_secs path") + } +} + +func TestShouldSkipPathTokenUsage(t *testing.T) { + if !shouldSkipPath("response/token_usage") { + t.Fatal("expected skip for token_usage path") + } +} + +func TestShouldSkipPathPendingFragment(t *testing.T) { + if !shouldSkipPath("response/pending_fragment") { + t.Fatal("expected skip for pending_fragment path") + } +} + +func TestShouldSkipPathConversationMode(t *testing.T) { + if !shouldSkipPath("response/conversation_mode") { + t.Fatal("expected skip for conversation_mode path") + } +} + +func TestShouldSkipPathSearchStatus(t *testing.T) { + if !shouldSkipPath("response/search_status") { + t.Fatal("expected skip for search_status path") + } +} + +func TestShouldSkipPathFragmentStatus(t *testing.T) { + if 
!shouldSkipPath("response/fragments/-1/status") { + t.Fatal("expected skip for fragment -1 status") + } + if !shouldSkipPath("response/fragments/-2/status") { + t.Fatal("expected skip for fragment -2 status") + } + if !shouldSkipPath("response/fragments/-3/status") { + t.Fatal("expected skip for fragment -3 status") + } +} + +func TestShouldSkipPathRegularContent(t *testing.T) { + if shouldSkipPath("response/content") { + t.Fatal("expected not skip for content path") + } + if shouldSkipPath("response/thinking_content") { + t.Fatal("expected not skip for thinking_content path") + } +} + +// ─── ParseSSEChunkForContent edge cases ────────────────────────────── + +func TestParseSSEChunkForContentNoVField(t *testing.T) { + parts, finished, nextType := ParseSSEChunkForContent(map[string]any{"p": "response/content"}, false, "text") + if finished { + t.Fatal("expected not finished") + } + if len(parts) != 0 { + t.Fatalf("expected no parts when v is missing, got %#v", parts) + } + if nextType != "text" { + t.Fatalf("expected type preserved, got %q", nextType) + } +} + +func TestParseSSEChunkForContentSkippedPath(t *testing.T) { + parts, finished, nextType := ParseSSEChunkForContent(map[string]any{ + "p": "response/token_usage", + "v": "some data", + }, false, "text") + if finished || len(parts) > 0 { + t.Fatalf("expected skipped path to produce no output") + } + if nextType != "text" { + t.Fatalf("expected type preserved for skipped path") + } +} + +func TestParseSSEChunkForContentFinishedStatus(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "response/status", + "v": "FINISHED", + }, false, "text") + if !finished { + t.Fatal("expected finished on status FINISHED") + } + if len(parts) != 0 { + t.Fatalf("expected no parts on finished, got %d", len(parts)) + } +} + +func TestParseSSEChunkForContentStatusNotFinished(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "response/status", + "v": 
"IN_PROGRESS", + }, false, "text") + if finished { + t.Fatal("expected not finished for non-FINISHED status") + } + if len(parts) != 1 || parts[0].Text != "IN_PROGRESS" { + t.Fatalf("expected content for non-FINISHED status, got %#v", parts) + } +} + +func TestParseSSEChunkForContentEmptyStringV(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "response/content", + "v": "", + }, false, "text") + if finished { + t.Fatal("expected not finished") + } + if len(parts) != 0 { + t.Fatalf("expected no parts for empty string v, got %#v", parts) + } +} + +func TestParseSSEChunkForContentFinishedOnEmptyPath(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "", + "v": "FINISHED", + }, false, "text") + if !finished { + t.Fatal("expected finished on empty path with FINISHED value") + } + if len(parts) != 0 { + t.Fatalf("expected no parts on finished") + } +} + +func TestParseSSEChunkForContentFinishedOnStatusPath(t *testing.T) { + _, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "status", + "v": "FINISHED", + }, false, "text") + if !finished { + t.Fatal("expected finished on status path with FINISHED value") + } +} + +func TestParseSSEChunkForContentThinkingPathEmptyPath(t *testing.T) { + parts, _, nextType := ParseSSEChunkForContent(map[string]any{ + "v": "some thought", + }, true, "thinking") + if len(parts) != 1 || parts[0].Type != "thinking" { + t.Fatalf("expected thinking part on empty path, got %#v", parts) + } + if nextType != "thinking" { + t.Fatalf("expected nextType thinking, got %q", nextType) + } +} + +func TestParseSSEChunkForContentThinkingEnabledTextType(t *testing.T) { + parts, _, nextType := ParseSSEChunkForContent(map[string]any{ + "v": "text content", + }, true, "text") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text part when currentType=text, got %#v", parts) + } + if nextType != "text" { + t.Fatalf("expected nextType text, got %q", 
nextType) + } +} + +// ─── ParseSSEChunkForContent: fragments path with THINK type ───────── + +func TestParseSSEChunkForContentFragmentsAppendThink(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "THINK", + "content": "深入思考...", + }, + }, + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text") + if finished { + t.Fatal("expected not finished") + } + if nextType != "thinking" { + t.Fatalf("expected nextType thinking, got %q", nextType) + } + if len(parts) != 1 || parts[0].Type != "thinking" || parts[0].Text != "深入思考..." { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendEmptyContent(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "RESPONSE", + "content": "", + }, + }, + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "thinking") + if finished { + t.Fatal("expected not finished") + } + if nextType != "text" { + t.Fatalf("expected nextType text, got %q", nextType) + } + if len(parts) != 0 { + t.Fatalf("expected no parts for empty content, got %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendDefaultType(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "UNKNOWN", + "content": "some text", + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, true, "text") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for unknown fragment type, got %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendNonArray(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": "not an array", + } + parts, finished, _ := ParseSSEChunkForContent(chunk, true, "text") + if finished { + t.Fatal("expected not finished") + } + // "not an array" 
should be treated as string value at the end + if len(parts) != 1 || parts[0].Text != "not an array" { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendNonMap(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{"string item"}, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + // Non-map items in fragment array are skipped; the []any itself is handled later + _ = parts // just checking it doesn't panic +} + +// ─── ParseSSEChunkForContent: response path with nested fragment ───── + +func TestParseSSEChunkForContentResponsePathFragmentsAppend(t *testing.T) { + chunk := map[string]any{ + "p": "response", + "v": []any{ + map[string]any{ + "p": "fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "THINKING", + }, + }, + }, + }, + } + _, _, nextType := ParseSSEChunkForContent(chunk, true, "text") + if nextType != "thinking" { + t.Fatalf("expected nextType thinking from response path fragments, got %q", nextType) + } +} + +func TestParseSSEChunkForContentResponsePathResponseFragment(t *testing.T) { + chunk := map[string]any{ + "p": "response", + "v": []any{ + map[string]any{ + "p": "fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "RESPONSE", + }, + }, + }, + }, + } + _, _, nextType := ParseSSEChunkForContent(chunk, true, "thinking") + if nextType != "text" { + t.Fatalf("expected nextType text for RESPONSE fragment, got %q", nextType) + } +} + +// ─── ParseSSEChunkForContent: map value with wrapped response ──────── + +func TestParseSSEChunkForContentMapValueWithFragments(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "response": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "THINK", + "content": "思考...", + }, + map[string]any{ + "type": "RESPONSE", + "content": "回答...", + }, + }, + }, + }, + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text") + if 
finished { + t.Fatal("expected not finished") + } + if nextType != "text" { + t.Fatalf("expected nextType text after RESPONSE, got %q", nextType) + } + if len(parts) != 2 { + t.Fatalf("expected 2 parts, got %d: %#v", len(parts), parts) + } + if parts[0].Type != "thinking" || parts[0].Text != "思考..." { + t.Fatalf("first part mismatch: %#v", parts[0]) + } + if parts[1].Type != "text" || parts[1].Text != "回答..." { + t.Fatalf("second part mismatch: %#v", parts[1]) + } +} + +func TestParseSSEChunkForContentMapValueDirectFragments(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "RESPONSE", + "content": "直接回答", + }, + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + if len(parts) != 1 || parts[0].Text != "直接回答" || parts[0].Type != "text" { + t.Fatalf("unexpected parts for direct fragments: %#v", parts) + } +} + +func TestParseSSEChunkForContentMapValueUnknownType(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "CUSTOM", + "content": "custom content", + }, + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected partType fallback for unknown type, got %#v", parts) + } +} + +func TestParseSSEChunkForContentMapValueEmptyFragmentContent(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "RESPONSE", + "content": "", + }, + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for empty fragment content, got %#v", parts) + } +} + +// ─── ParseSSEChunkForContent: fragments/-1/content path ────────────── + +func TestParseSSEChunkForContentFragmentContentPathInheritsType(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments/-1/content", + "v": "继续思考", + } + parts, _, _ := 
ParseSSEChunkForContent(chunk, true, "thinking") + if len(parts) != 1 || parts[0].Type != "thinking" { + t.Fatalf("expected inherited thinking type, got %#v", parts) + } +} + +// ─── IsCitation edge cases ─────────────────────────────────────────── + +func TestIsCitationWithLeadingWhitespace(t *testing.T) { + if !IsCitation(" [citation:2] text") { + t.Fatal("expected citation true with leading whitespace") + } +} + +func TestIsCitationEmpty(t *testing.T) { + if IsCitation("") { + t.Fatal("expected citation false for empty string") + } +} + +func TestIsCitationSimilarPrefix(t *testing.T) { + if IsCitation("[cite:1] text") { + t.Fatal("expected citation false for [cite: prefix") + } +} + +// ─── extractContentRecursive edge cases ────────────────────────────── + +func TestExtractContentRecursiveFinishedStatus(t *testing.T) { + items := []any{ + map[string]any{"p": "status", "v": "FINISHED"}, + } + parts, finished := extractContentRecursive(items, "text") + if !finished { + t.Fatal("expected finished on status FINISHED") + } + if len(parts) != 0 { + t.Fatalf("expected no parts, got %#v", parts) + } +} + +func TestExtractContentRecursiveSkipsPath(t *testing.T) { + items := []any{ + map[string]any{"p": "token_usage", "v": "data"}, + } + parts, finished := extractContentRecursive(items, "text") + if finished { + t.Fatal("expected not finished") + } + if len(parts) != 0 { + t.Fatalf("expected no parts for skipped path, got %#v", parts) + } +} + +func TestExtractContentRecursiveContentField(t *testing.T) { + items := []any{ + map[string]any{"p": "x", "v": "val", "content": "actual content", "type": "RESPONSE"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 1 || parts[0].Text != "actual content" || parts[0].Type != "text" { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestExtractContentRecursiveContentFieldThinkType(t *testing.T) { + items := []any{ + map[string]any{"p": "x", "v": "val", "content": "think text", "type": "THINK"}, + 
} + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 1 || parts[0].Type != "thinking" { + t.Fatalf("expected thinking type for THINK content, got %#v", parts) + } +} + +func TestExtractContentRecursiveThinkingPath(t *testing.T) { + items := []any{ + map[string]any{"p": "thinking_content", "v": "deep thought"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 1 || parts[0].Type != "thinking" || parts[0].Text != "deep thought" { + t.Fatalf("unexpected parts for thinking path: %#v", parts) + } +} + +func TestExtractContentRecursiveContentPath(t *testing.T) { + items := []any{ + map[string]any{"p": "content", "v": "text content"}, + } + parts, _ := extractContentRecursive(items, "thinking") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for content path, got %#v", parts) + } +} + +func TestExtractContentRecursiveResponsePath(t *testing.T) { + items := []any{ + map[string]any{"p": "response", "v": "text content"}, + } + parts, _ := extractContentRecursive(items, "thinking") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for response path, got %#v", parts) + } +} + +func TestExtractContentRecursiveFragmentsPath(t *testing.T) { + items := []any{ + map[string]any{"p": "fragments", "v": "fragment text"}, + } + parts, _ := extractContentRecursive(items, "thinking") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for fragments path, got %#v", parts) + } +} + +func TestExtractContentRecursiveNestedArrayWithTypes(t *testing.T) { + items := []any{ + map[string]any{ + "p": "fragments", + "v": []any{ + map[string]any{"content": "thought", "type": "THINKING"}, + map[string]any{"content": "answer", "type": "RESPONSE"}, + "raw string", + }, + }, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 3 { + t.Fatalf("expected 3 parts, got %d: %#v", len(parts), parts) + } + if parts[0].Type != "thinking" || 
parts[0].Text != "thought" { + t.Fatalf("first part mismatch: %#v", parts[0]) + } + if parts[1].Type != "text" || parts[1].Text != "answer" { + t.Fatalf("second part mismatch: %#v", parts[1]) + } + if parts[2].Type != "text" || parts[2].Text != "raw string" { + t.Fatalf("third part mismatch: %#v", parts[2]) + } +} + +func TestExtractContentRecursiveEmptyContentSkipped(t *testing.T) { + items := []any{ + map[string]any{ + "p": "fragments", + "v": []any{ + map[string]any{"content": "", "type": "RESPONSE"}, + }, + }, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for empty nested content, got %#v", parts) + } +} + +func TestExtractContentRecursiveFinishedString(t *testing.T) { + items := []any{ + map[string]any{"p": "content", "v": "FINISHED"}, + } + parts, _ := extractContentRecursive(items, "text") + // "FINISHED" string value on non-status path should be skipped + if len(parts) != 0 { + t.Fatalf("expected FINISHED string to be skipped, got %#v", parts) + } +} + +func TestExtractContentRecursiveNoVField(t *testing.T) { + items := []any{ + map[string]any{"p": "content"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for missing v field, got %#v", parts) + } +} + +func TestExtractContentRecursiveNonMapItem(t *testing.T) { + items := []any{"just a string", 42} + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for non-map items, got %#v", parts) + } +} diff --git a/internal/sse/stream_edge_test.go b/internal/sse/stream_edge_test.go new file mode 100644 index 0000000..927b023 --- /dev/null +++ b/internal/sse/stream_edge_test.go @@ -0,0 +1,177 @@ +package sse + +import ( + "context" + "io" + "strings" + "testing" +) + +func TestStartParsedLinePumpEmptyBody(t *testing.T) { + body := strings.NewReader("") + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + 
collected := make([]LineResult, 0) + for r := range results { + collected = append(collected, r) + } + if err := <-done; err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(collected) != 0 { + t.Fatalf("expected no results for empty body, got %d", len(collected)) + } +} + +func TestStartParsedLinePumpMultipleLines(t *testing.T) { + body := strings.NewReader( + "data: {\"p\":\"response/thinking_content\",\"v\":\"think\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"text\"}\n" + + "data: [DONE]\n", + ) + results, done := StartParsedLinePump(context.Background(), body, true, "thinking") + + collected := make([]LineResult, 0) + for r := range results { + collected = append(collected, r) + } + if err := <-done; err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(collected) < 3 { + t.Fatalf("expected at least 3 results, got %d", len(collected)) + } + // First should be thinking + if collected[0].Parts[0].Type != "thinking" { + t.Fatalf("expected first part thinking, got %q", collected[0].Parts[0].Type) + } + // Last should be stop + last := collected[len(collected)-1] + if !last.Stop { + t.Fatal("expected last result to be stop") + } +} + +func TestStartParsedLinePumpTypeTracking(t *testing.T) { + body := strings.NewReader( + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"思\"}]}\n" + + "data: {\"p\":\"response/fragments/-1/content\",\"v\":\"考\"}\n" + + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"答\"}]}\n" + + "data: {\"p\":\"response/fragments/-1/content\",\"v\":\"案\"}\n" + + "data: [DONE]\n", + ) + results, done := StartParsedLinePump(context.Background(), body, true, "text") + + types := make([]string, 0) + for r := range results { + for _, p := range r.Parts { + types = append(types, p.Type) + } + } + <-done + + // Should have: thinking, thinking, text, text + expected := []string{"thinking", "thinking", "text", "text"} + if len(types) 
!= len(expected) { + t.Fatalf("expected types %v, got %v", expected, types) + } + for i, want := range expected { + if types[i] != want { + t.Fatalf("type[%d] mismatch: want %q got %q (all=%v)", i, want, types[i], types) + } + } +} + +func TestStartParsedLinePumpContextCancellation(t *testing.T) { + pr, pw := io.Pipe() + + ctx, cancel := context.WithCancel(context.Background()) + results, done := StartParsedLinePump(ctx, pr, false, "text") + + // Write one line to allow it to start + go func() { + _, _ = io.WriteString(pw, "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n") + // Don't close yet - wait for context cancel + }() + + // Read first result + r := <-results + if !r.Parsed || len(r.Parts) == 0 { + t.Fatalf("expected first parsed result, got %#v", r) + } + + // Cancel context - this will cause the pump to exit on next send + cancel() + // Close the pipe to unblock scanner.Scan() + pw.Close() + + // Drain remaining results + for range results { + } + + err := <-done + // Error may be context.Canceled or nil (if pipe closed first) + if err != nil && err != context.Canceled { + t.Fatalf("expected context.Canceled or nil error, got %v", err) + } +} + +func TestStartParsedLinePumpOnlyDONE(t *testing.T) { + body := strings.NewReader("data: [DONE]\n") + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + collected := make([]LineResult, 0) + for r := range results { + collected = append(collected, r) + } + <-done + + if len(collected) != 1 { + t.Fatalf("expected 1 result, got %d", len(collected)) + } + if !collected[0].Stop { + t.Fatal("expected stop on [DONE]") + } +} + +func TestStartParsedLinePumpNonSSELines(t *testing.T) { + body := strings.NewReader( + "event: update\n" + + ": comment line\n" + + "data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" + + "data: [DONE]\n", + ) + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + var validCount int + for r := range results { + if r.Parsed && 
len(r.Parts) > 0 { + validCount++ + } + } + <-done + + if validCount != 1 { + t.Fatalf("expected 1 valid result, got %d", validCount) + } +} + +func TestStartParsedLinePumpThinkingDisabled(t *testing.T) { + body := strings.NewReader( + "data: {\"p\":\"response/thinking_content\",\"v\":\"thought\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"response\"}\n" + + "data: [DONE]\n", + ) + // With thinking disabled, thinking content should still be emitted but marked differently + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + var parts []ContentPart + for r := range results { + parts = append(parts, r.Parts...) + } + <-done + + if len(parts) < 1 { + t.Fatalf("expected at least 1 part, got %d", len(parts)) + } +} diff --git a/internal/testsuite/runner.go b/internal/testsuite/runner.go index b48bce5..e6ae9a6 100644 --- a/internal/testsuite/runner.go +++ b/internal/testsuite/runner.go @@ -755,11 +755,15 @@ func (r *Runner) cases() []caseDef { {ID: "healthz_ok", Run: r.caseHealthz}, {ID: "readyz_ok", Run: r.caseReadyz}, {ID: "models_openai", Run: r.caseModelsOpenAI}, + {ID: "model_openai_by_id", Run: r.caseModelOpenAIByID}, {ID: "models_claude", Run: r.caseModelsClaude}, {ID: "admin_login_verify", Run: r.caseAdminLoginVerify}, {ID: "admin_queue_status", Run: r.caseAdminQueueStatus}, {ID: "chat_nonstream_basic", Run: r.caseChatNonstream}, {ID: "chat_stream_basic", Run: r.caseChatStream}, + {ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream}, + {ID: "responses_stream_basic", Run: r.caseResponsesStream}, + {ID: "embeddings_contract", Run: r.caseEmbeddings}, {ID: "reasoner_stream", Run: r.caseReasonerStream}, {ID: "toolcall_nonstream", Run: r.caseToolcallNonstream}, {ID: "toolcall_stream", Run: r.caseToolcallStream}, @@ -817,6 +821,19 @@ func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error { return nil } +func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error { + resp, err 
:= cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("object_model", asString(m["object"]) == "model", fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("id_deepseek_chat", asString(m["id"]) == "deepseek-chat", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error { resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true}) if err != nil { @@ -942,6 +959,115 @@ func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error { return nil } +func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/responses", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": "请简要回答 hello", + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("object_response", asString(m["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body))) + responseID := asString(m["id"]) + cc.assert("response_id_present", responseID != "", fmt.Sprintf("body=%s", string(resp.Body))) + if responseID != "" { + getResp, getErr := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/v1/responses/" + responseID, + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Retryable: true, + }) + if getErr != nil { + return getErr + } + cc.assert("get_status_200", getResp.StatusCode == http.StatusOK, 
fmt.Sprintf("status=%d", getResp.StatusCode)) + } + return nil +} + +func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/responses", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": "请流式回答 hello", + "stream": true, + }, + Stream: true, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + frames, done := parseSSEFrames(resp.Body) + cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames))) + hasCreated := false + hasCompleted := false + for _, f := range frames { + switch asString(f["type"]) { + case "response.created": + hasCreated = true + case "response.completed": + hasCompleted = true + } + } + cc.assert("has_response_created", hasCreated, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("has_response_completed", hasCompleted, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("done_terminated", done, "expected [DONE]") + return nil +} + +func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/embeddings", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": []string{"hello", "world"}, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200_or_501", resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + if resp.StatusCode == http.StatusOK { + cc.assert("object_list", asString(m["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body))) + data, _ := m["data"].([]any) + 
cc.assert("data_non_empty", len(data) > 0, fmt.Sprintf("body=%s", string(resp.Body))) + return nil + } + errObj, _ := m["error"].(map[string]any) + _, hasCode := errObj["code"] + _, hasParam := errObj["param"] + cc.assert("error_has_code", hasCode, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error { resp, err := cc.request(ctx, requestSpec{ Method: http.MethodPost, diff --git a/internal/util/messages.go b/internal/util/messages.go index 19f2948..fcc9484 100644 --- a/internal/util/messages.go +++ b/internal/util/messages.go @@ -1,6 +1,8 @@ package util import ( + "encoding/json" + "fmt" "regexp" "strings" @@ -68,15 +70,25 @@ func normalizeContent(v any) string { if !ok { continue } - if m["type"] == "text" { + typeStr, _ := m["type"].(string) + typeStr = strings.ToLower(strings.TrimSpace(typeStr)) + if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" { if txt, ok := m["text"].(string); ok { parts = append(parts, txt) + continue + } + if txt, ok := m["content"].(string); ok { + parts = append(parts, txt) } } } return strings.Join(parts, "\n") default: - return "" + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%v", v) + } + return string(b) } } diff --git a/internal/util/messages_test.go b/internal/util/messages_test.go index 30b8cc0..776853b 100644 --- a/internal/util/messages_test.go +++ b/internal/util/messages_test.go @@ -33,6 +33,33 @@ func TestMessagesPrepareRoles(t *testing.T) { } } +func TestMessagesPrepareObjectContent(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": map[string]any{"temp": 18, "ok": true}}, + } + got := MessagesPrepare(messages) + if !contains(got, `"temp":18`) || !contains(got, `"ok":true`) { + t.Fatalf("expected serialized object content, got %q", got) + } +} + +func 
TestMessagesPrepareArrayTextVariants(t *testing.T) { + messages := []map[string]any{ + { + "role": "user", + "content": []any{ + map[string]any{"type": "output_text", "text": "line1"}, + map[string]any{"type": "input_text", "text": "line2"}, + map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"}, + }, + }, + } + got := MessagesPrepare(messages) + if got != "line1\nline2" { + t.Fatalf("unexpected content from text variants: %q", got) + } +} + func TestConvertClaudeToDeepSeek(t *testing.T) { store := config.LoadStore() req := map[string]any{ diff --git a/internal/util/render.go b/internal/util/render.go new file mode 100644 index 0000000..b5e0a79 --- /dev/null +++ b/internal/util/render.go @@ -0,0 +1,140 @@ +package util + +import ( + "fmt" + "strings" + "time" + + "github.com/google/uuid" +) + +func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := ParseToolCalls(finalText, toolNames) + finishReason := "stop" + messageObj := map[string]any{"role": "assistant", "content": finalText} + if strings.TrimSpace(finalThinking) != "" { + messageObj["reasoning_content"] = finalThinking + } + if len(detected) > 0 { + finishReason = "tool_calls" + messageObj["tool_calls"] = FormatOpenAIToolCalls(detected) + messageObj["content"] = nil + } + promptTokens := EstimateTokens(finalPrompt) + reasoningTokens := EstimateTokens(finalThinking) + completionTokens := EstimateTokens(finalText) + + return map[string]any{ + "id": completionID, + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}}, + "usage": map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + "completion_tokens_details": map[string]any{ + "reasoning_tokens": 
reasoningTokens, + }, + }, + } +} + +func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := ParseToolCalls(finalText, toolNames) + exposedOutputText := finalText + output := make([]any, 0, 2) + if len(detected) > 0 { + // Keep structured tool output only; avoid leaking raw tool-call JSON + // into response.output_text for clients reading completed responses. + exposedOutputText = "" + toolCalls := make([]any, 0, len(detected)) + for _, tc := range detected { + toolCalls = append(toolCalls, map[string]any{ + "type": "tool_call", + "name": tc.Name, + "arguments": tc.Input, + }) + } + output = append(output, map[string]any{ + "type": "tool_calls", + "tool_calls": toolCalls, + }) + } else { + content := []any{ + map[string]any{ + "type": "output_text", + "text": finalText, + }, + } + if finalThinking != "" { + content = append([]any{map[string]any{ + "type": "reasoning", + "text": finalThinking, + }}, content...) 
+ } + output = append(output, map[string]any{ + "type": "message", + "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "role": "assistant", + "content": content, + }) + } + promptTokens := EstimateTokens(finalPrompt) + reasoningTokens := EstimateTokens(finalThinking) + completionTokens := EstimateTokens(finalText) + return map[string]any{ + "id": responseID, + "type": "response", + "object": "response", + "created_at": time.Now().Unix(), + "status": "completed", + "model": model, + "output": output, + "output_text": exposedOutputText, + "usage": map[string]any{ + "input_tokens": promptTokens, + "output_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + }, + } +} + +func BuildClaudeMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any { + detected := ParseToolCalls(finalText, toolNames) + content := make([]map[string]any, 0, 4) + if finalThinking != "" { + content = append(content, map[string]any{"type": "thinking", "thinking": finalThinking}) + } + stopReason := "end_turn" + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + content = append(content, map[string]any{ + "type": "tool_use", + "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), + "name": tc.Name, + "input": tc.Input, + }) + } + } else { + if finalText == "" { + finalText = "抱歉,没有生成有效的响应内容。" + } + content = append(content, map[string]any{"type": "text", "text": finalText}) + } + return map[string]any{ + "id": messageID, + "type": "message", + "role": "assistant", + "model": model, + "content": content, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": EstimateTokens(fmt.Sprintf("%v", normalizedMessages)), + "output_tokens": EstimateTokens(finalThinking) + EstimateTokens(finalText), + }, + } +} diff --git a/internal/util/render_stream.go 
b/internal/util/render_stream.go new file mode 100644 index 0000000..716c158 --- /dev/null +++ b/internal/util/render_stream.go @@ -0,0 +1,93 @@ +package util + +func BuildOpenAIChatStreamDeltaChoice(index int, delta map[string]any) map[string]any { + return map[string]any{ + "delta": delta, + "index": index, + } +} + +func BuildOpenAIChatStreamFinishChoice(index int, finishReason string) map[string]any { + return map[string]any{ + "delta": map[string]any{}, + "index": index, + "finish_reason": finishReason, + } +} + +func BuildOpenAIChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any { + out := map[string]any{ + "id": completionID, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": choices, + } + if len(usage) > 0 { + out["usage"] = usage + } + return out +} + +func BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := EstimateTokens(finalPrompt) + reasoningTokens := EstimateTokens(finalThinking) + completionTokens := EstimateTokens(finalText) + return map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + "completion_tokens_details": map[string]any{ + "reasoning_tokens": reasoningTokens, + }, + } +} + +func BuildOpenAIResponsesCreatedPayload(responseID, model string) map[string]any { + return map[string]any{ + "type": "response.created", + "id": responseID, + "object": "response", + "model": model, + "status": "in_progress", + } +} + +func BuildOpenAIResponsesTextDeltaPayload(responseID, delta string) map[string]any { + return map[string]any{ + "type": "response.output_text.delta", + "id": responseID, + "delta": delta, + } +} + +func BuildOpenAIResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { + return map[string]any{ + "type": "response.reasoning.delta", + 
"id": responseID, + "delta": delta, + } +} + +func BuildOpenAIResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_tool_call.delta", + "id": responseID, + "tool_calls": toolCalls, + } +} + +func BuildOpenAIResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_tool_call.done", + "id": responseID, + "tool_calls": toolCalls, + } +} + +func BuildOpenAIResponsesCompletedPayload(response map[string]any) map[string]any { + return map[string]any{ + "type": "response.completed", + "response": response, + } +} diff --git a/internal/util/render_stream_test.go b/internal/util/render_stream_test.go new file mode 100644 index 0000000..420a311 --- /dev/null +++ b/internal/util/render_stream_test.go @@ -0,0 +1,48 @@ +package util + +import "testing" + +func TestBuildOpenAIChatStreamChunk(t *testing.T) { + chunk := BuildOpenAIChatStreamChunk( + "cid", + 123, + "deepseek-chat", + []map[string]any{BuildOpenAIChatStreamDeltaChoice(0, map[string]any{"role": "assistant"})}, + nil, + ) + if chunk["object"] != "chat.completion.chunk" { + t.Fatalf("unexpected object: %#v", chunk["object"]) + } + choices, _ := chunk["choices"].([]map[string]any) + if len(choices) == 0 { + rawChoices, _ := chunk["choices"].([]any) + if len(rawChoices) == 0 { + t.Fatalf("expected choices") + } + } +} + +func TestBuildOpenAIChatUsage(t *testing.T) { + usage := BuildOpenAIChatUsage("prompt", "think", "answer") + if _, ok := usage["prompt_tokens"]; !ok { + t.Fatalf("expected prompt_tokens") + } + if _, ok := usage["completion_tokens_details"]; !ok { + t.Fatalf("expected completion_tokens_details") + } +} + +func TestBuildOpenAIResponsesEventPayloads(t *testing.T) { + created := BuildOpenAIResponsesCreatedPayload("resp_1", "gpt-4o") + if created["type"] != "response.created" { + t.Fatalf("unexpected type: %#v", created["type"]) + } 
+ done := BuildOpenAIResponsesToolCallDonePayload("resp_1", []map[string]any{{"index": 0}}) + if done["type"] != "response.output_tool_call.done" { + t.Fatalf("unexpected type: %#v", done["type"]) + } + completed := BuildOpenAIResponsesCompletedPayload(map[string]any{"id": "resp_1"}) + if completed["type"] != "response.completed" { + t.Fatalf("unexpected type: %#v", completed["type"]) + } +} diff --git a/internal/util/render_test.go b/internal/util/render_test.go new file mode 100644 index 0000000..9d4feec --- /dev/null +++ b/internal/util/render_test.go @@ -0,0 +1,94 @@ +package util + +import "testing" + +func TestBuildOpenAIChatCompletionWithToolCalls(t *testing.T) { + out := BuildOpenAIChatCompletion( + "cid1", + "deepseek-chat", + "prompt", + "", + `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`, + []string{"search"}, + ) + if out["object"] != "chat.completion" { + t.Fatalf("unexpected object: %#v", out["object"]) + } + choices, _ := out["choices"].([]map[string]any) + if len(choices) == 0 { + // json-like map from generic marshalling may be []any in some paths + rawChoices, _ := out["choices"].([]any) + if len(rawChoices) == 0 { + t.Fatalf("expected choices") + } + c0, _ := rawChoices[0].(map[string]any) + if c0["finish_reason"] != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %#v", c0["finish_reason"]) + } + return + } + if choices[0]["finish_reason"] != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %#v", choices[0]["finish_reason"]) + } +} + +func TestBuildOpenAIResponseObjectWithText(t *testing.T) { + out := BuildOpenAIResponseObject( + "resp_1", + "gpt-4o", + "prompt", + "reasoning", + "text", + nil, + ) + if out["object"] != "response" { + t.Fatalf("unexpected object: %#v", out["object"]) + } + output, _ := out["output"].([]any) + if len(output) == 0 { + t.Fatalf("expected output entries") + } + first, _ := output[0].(map[string]any) + if first["type"] != "message" { + t.Fatalf("expected first output type 
message, got %#v", first["type"]) + } +} + +func TestBuildOpenAIResponseObjectToolCallsHidesRawOutputText(t *testing.T) { + out := BuildOpenAIResponseObject( + "resp_2", + "gpt-4o", + "prompt", + "", + `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`, + []string{"search"}, + ) + if out["output_text"] != "" { + t.Fatalf("expected empty output_text for tool_calls, got %#v", out["output_text"]) + } + output, _ := out["output"].([]any) + if len(output) == 0 { + t.Fatalf("expected output entries") + } + first, _ := output[0].(map[string]any) + if first["type"] != "tool_calls" { + t.Fatalf("expected first output type tool_calls, got %#v", first["type"]) + } +} + +func TestBuildClaudeMessageResponseToolUse(t *testing.T) { + out := BuildClaudeMessageResponse( + "msg_1", + "claude-sonnet-4-5", + []any{map[string]any{"role": "user", "content": "hi"}}, + "", + `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`, + []string{"search"}, + ) + if out["type"] != "message" { + t.Fatalf("unexpected type: %#v", out["type"]) + } + if out["stop_reason"] != "tool_use" { + t.Fatalf("expected stop_reason=tool_use, got %#v", out["stop_reason"]) + } +} diff --git a/internal/util/standard_request.go b/internal/util/standard_request.go new file mode 100644 index 0000000..af73acf --- /dev/null +++ b/internal/util/standard_request.go @@ -0,0 +1,30 @@ +package util + +type StandardRequest struct { + Surface string + RequestedModel string + ResolvedModel string + ResponseModel string + Messages []any + FinalPrompt string + ToolNames []string + Stream bool + Thinking bool + Search bool + PassThrough map[string]any +} + +func (r StandardRequest) CompletionPayload(sessionID string) map[string]any { + payload := map[string]any{ + "chat_session_id": sessionID, + "parent_message_id": nil, + "prompt": r.FinalPrompt, + "ref_file_ids": []any{}, + "thinking_enabled": r.Thinking, + "search_enabled": r.Search, + } + for k, v := range r.PassThrough { + payload[k] = v + } + return payload +} diff 
--git a/internal/util/toolcalls.go b/internal/util/toolcalls.go index 9b9d4e6..9e44b94 100644 --- a/internal/util/toolcalls.go +++ b/internal/util/toolcalls.go @@ -10,6 +10,7 @@ import ( var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`) var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```") +var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```") type ParsedToolCall struct { Name string `json:"name"` @@ -20,6 +21,10 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall { if strings.TrimSpace(text) == "" { return nil } + text = stripFencedCodeBlocks(text) + if strings.TrimSpace(text) == "" { + return nil + } candidates := buildToolCallCandidates(text) var parsed []ParsedToolCall @@ -33,6 +38,34 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall { return nil } + return filterToolCalls(parsed, availableToolNames) +} + +func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return nil + } + if looksLikeToolExampleContext(trimmed) { + return nil + } + candidates := []string{trimmed} + for _, candidate := range candidates { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + continue + } + if !strings.HasPrefix(candidate, "{") && !strings.HasPrefix(candidate, "[") { + continue + } + if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 { + return filterToolCalls(parsed, availableToolNames) + } + } + return nil +} + +func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []ParsedToolCall { allowed := map[string]struct{}{} for _, name := range availableToolNames { allowed[name] = struct{}{} @@ -283,6 +316,21 @@ func extractJSONObject(text string, start int) (string, int, bool) { return "", 0, false } +func looksLikeToolExampleContext(text string) bool { + t := strings.ToLower(strings.TrimSpace(text)) + if t == 
"" { + return false + } + return strings.Contains(t, "```") +} + +func stripFencedCodeBlocks(text string) string { + if strings.TrimSpace(text) == "" { + return "" + } + return fencedBlockPattern.ReplaceAllString(text, " ") +} + func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any { out := make([]map[string]any, 0, len(calls)) for _, c := range calls { diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index 8c44320..f7c82d2 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -19,11 +19,8 @@ func TestParseToolCalls(t *testing.T) { func TestParseToolCallsFromFencedJSON(t *testing.T) { text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```" calls := ParseToolCalls(text, []string{"search"}) - if len(calls) != 1 { - t.Fatalf("expected 1 call, got %d", len(calls)) - } - if calls[0].Input["q"] != "news" { - t.Fatalf("unexpected args: %#v", calls[0].Input) + if len(calls) != 0 { + t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls) } } @@ -62,3 +59,23 @@ func TestFormatOpenAIToolCalls(t *testing.T) { t.Fatalf("unexpected function name: %#v", fn) } } + +func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) { + mixed := `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 0 { + t.Fatalf("expected standalone parser to ignore mixed prose, got %#v", calls) + } + + standalone := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + calls := ParseStandaloneToolCalls(standalone, []string{"search"}) + if len(calls) != 1 { + t.Fatalf("expected standalone parser to match, got %#v", calls) + } +} + +func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) { + fenced := "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```" + if calls := ParseStandaloneToolCalls(fenced, 
[]string{"search"}); len(calls) != 0 { + t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls) + } +} diff --git a/internal/util/util_edge_test.go b/internal/util/util_edge_test.go new file mode 100644 index 0000000..cba0ceb --- /dev/null +++ b/internal/util/util_edge_test.go @@ -0,0 +1,429 @@ +package util + +import ( + "encoding/json" + "net/http/httptest" + "strings" + "testing" + + "ds2api/internal/config" +) + +// ─── EstimateTokens edge cases ─────────────────────────────────────── + +func TestEstimateTokensEmpty(t *testing.T) { + if got := EstimateTokens(""); got != 0 { + t.Fatalf("expected 0 for empty string, got %d", got) + } +} + +func TestEstimateTokensShortASCII(t *testing.T) { + got := EstimateTokens("ab") + if got != 1 { + t.Fatalf("expected 1 for 2 ascii chars, got %d", got) + } +} + +func TestEstimateTokensLongASCII(t *testing.T) { + got := EstimateTokens(strings.Repeat("x", 100)) + if got != 25 { + t.Fatalf("expected 25 for 100 ascii chars, got %d", got) + } +} + +func TestEstimateTokensChinese(t *testing.T) { + got := EstimateTokens("你好世界") + if got < 1 { + t.Fatalf("expected at least 1 token for Chinese text, got %d", got) + } +} + +func TestEstimateTokensMixed(t *testing.T) { + got := EstimateTokens("Hello 你好世界") + if got < 2 { + t.Fatalf("expected at least 2 tokens for mixed text, got %d", got) + } +} + +func TestEstimateTokensSingleByte(t *testing.T) { + got := EstimateTokens("x") + if got != 1 { + t.Fatalf("expected 1 for single char (minimum), got %d", got) + } +} + +func TestEstimateTokensSingleChinese(t *testing.T) { + got := EstimateTokens("你") + if got != 1 { + t.Fatalf("expected 1 for single Chinese char, got %d", got) + } +} + +// ─── ToBool edge cases ─────────────────────────────────────────────── + +func TestToBoolTrue(t *testing.T) { + if !ToBool(true) { + t.Fatal("expected true") + } +} + +func TestToBoolFalse(t *testing.T) { + if ToBool(false) { + t.Fatal("expected false") + } +} + +func TestToBoolNonBool(t 
*testing.T) { + if ToBool("true") { + t.Fatal("expected false for string 'true'") + } + if ToBool(1) { + t.Fatal("expected false for int 1") + } + if ToBool(nil) { + t.Fatal("expected false for nil") + } +} + +// ─── IntFrom edge cases ───────────────────────────────────────────── + +func TestIntFromFloat64(t *testing.T) { + if got := IntFrom(float64(42.5)); got != 42 { + t.Fatalf("expected 42 for float64(42.5), got %d", got) + } +} + +func TestIntFromInt(t *testing.T) { + if got := IntFrom(int(42)); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestIntFromInt64(t *testing.T) { + if got := IntFrom(int64(42)); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestIntFromString(t *testing.T) { + if got := IntFrom("42"); got != 0 { + t.Fatalf("expected 0 for string, got %d", got) + } +} + +func TestIntFromNil(t *testing.T) { + if got := IntFrom(nil); got != 0 { + t.Fatalf("expected 0 for nil, got %d", got) + } +} + +// ─── WriteJSON ─────────────────────────────────────────────────────── + +func TestWriteJSON(t *testing.T) { + rec := httptest.NewRecorder() + WriteJSON(rec, 200, map[string]any{"key": "value"}) + if rec.Code != 200 { + t.Fatalf("expected 200, got %d", rec.Code) + } + if ct := rec.Header().Get("Content-Type"); ct != "application/json" { + t.Fatalf("expected application/json content type, got %q", ct) + } + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode error: %v", err) + } + if body["key"] != "value" { + t.Fatalf("unexpected body: %#v", body) + } +} + +func TestWriteJSONStatusCodes(t *testing.T) { + for _, code := range []int{200, 201, 400, 404, 500} { + rec := httptest.NewRecorder() + WriteJSON(rec, code, map[string]any{"status": code}) + if rec.Code != code { + t.Fatalf("expected %d, got %d", code, rec.Code) + } + } +} + +// ─── MessagesPrepare edge cases ────────────────────────────────────── + +func TestMessagesPrepareEmpty(t *testing.T) { + got := 
MessagesPrepare(nil) + if got != "" { + t.Fatalf("expected empty for nil messages, got %q", got) + } +} + +func TestMessagesPrepareMergesConsecutiveSameRole(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Hello"}, + {"role": "user", "content": "World"}, + } + got := MessagesPrepare(messages) + if !strings.Contains(got, "Hello") || !strings.Contains(got, "World") { + t.Fatalf("expected both messages, got %q", got) + } + // Should be merged without <|User|> between them + count := strings.Count(got, "<|User|>") + if count != 0 { + t.Fatalf("expected no User marker for first message pair, got %d occurrences", count) + } +} + +func TestMessagesPrepareAssistantMarkers(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + } + got := MessagesPrepare(messages) + if !strings.Contains(got, "<|Assistant|>") { + t.Fatalf("expected assistant marker, got %q", got) + } + if !strings.Contains(got, "<|end▁of▁sentence|>") { + t.Fatalf("expected end of sentence marker, got %q", got) + } +} + +func TestMessagesPrepareUnknownRole(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Hello"}, + {"role": "unknown_role", "content": "Unknown"}, + } + got := MessagesPrepare(messages) + if !strings.Contains(got, "Unknown") { + t.Fatalf("expected unknown role content, got %q", got) + } +} + +func TestMessagesPrepareMarkdownImageReplaced(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Look at this: ![alt](https://example.com/img.png)"}, + } + got := MessagesPrepare(messages) + if strings.Contains(got, "![alt]") { + t.Fatalf("expected markdown image to be replaced, got %q", got) + } +} + +func TestMessagesPrepareNilContent(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": nil}, + } + got := MessagesPrepare(messages) + if got != "null" { + t.Logf("nil content handled as: %q", got) + } +} + +// ─── 
normalizeContent edge cases ───────────────────────────────────── + +func TestNormalizeContentString(t *testing.T) { + got := normalizeContent("hello") + if got != "hello" { + t.Fatalf("expected 'hello', got %q", got) + } +} + +func TestNormalizeContentArray(t *testing.T) { + got := normalizeContent([]any{ + map[string]any{"type": "text", "text": "line1"}, + map[string]any{"type": "text", "text": "line2"}, + }) + if got != "line1\nline2" { + t.Fatalf("expected 'line1\\nline2', got %q", got) + } +} + +func TestNormalizeContentArrayWithContentField(t *testing.T) { + got := normalizeContent([]any{ + map[string]any{"type": "text", "content": "from-content"}, + }) + if got != "from-content" { + t.Fatalf("expected 'from-content', got %q", got) + } +} + +func TestNormalizeContentArraySkipsImage(t *testing.T) { + got := normalizeContent([]any{ + map[string]any{"type": "image_url", "image_url": "https://example.com/img.png"}, + map[string]any{"type": "text", "text": "caption"}, + }) + if strings.Contains(got, "image") { + t.Fatalf("expected image skipped, got %q", got) + } + if got != "caption" { + t.Fatalf("expected 'caption', got %q", got) + } +} + +func TestNormalizeContentArrayNonMapItems(t *testing.T) { + got := normalizeContent([]any{"string item", 42}) + if got != "" { + t.Fatalf("expected empty for non-map items, got %q", got) + } +} + +func TestNormalizeContentJSON(t *testing.T) { + got := normalizeContent(map[string]any{"key": "value"}) + if !strings.Contains(got, `"key":"value"`) { + t.Fatalf("expected JSON serialized, got %q", got) + } +} + +// ─── ConvertClaudeToDeepSeek edge cases ────────────────────────────── + +func TestConvertClaudeToDeepSeekDefaultModel(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["model"] == "" { + t.Fatal("expected default model") + } +} + +func 
TestConvertClaudeToDeepSeekWithStopSequences(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + "stop_sequences": []any{"\n\n"}, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["stop"] == nil { + t.Fatal("expected stop field from stop_sequences") + } +} + +func TestConvertClaudeToDeepSeekWithTemperature(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + "temperature": 0.7, + "top_p": 0.9, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["temperature"] != 0.7 { + t.Fatalf("expected temperature 0.7, got %v", out["temperature"]) + } + if out["top_p"] != 0.9 { + t.Fatalf("expected top_p 0.9, got %v", out["top_p"]) + } +} + +func TestConvertClaudeToDeepSeekNoSystem(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + } + out := ConvertClaudeToDeepSeek(req, store) + msgs, _ := out["messages"].([]any) + if len(msgs) != 1 { + t.Fatalf("expected 1 message without system, got %d", len(msgs)) + } +} + +func TestConvertClaudeToDeepSeekOpusUsesSlowMapping(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-opus-4-6", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["model"] != "deepseek-reasoner" { + t.Fatalf("expected opus to use slow mapping, got %q", out["model"]) + } +} + +// ─── FormatOpenAIStreamToolCalls ───────────────────────────────────── + +func TestFormatOpenAIStreamToolCalls(t *testing.T) { + formatted := 
FormatOpenAIStreamToolCalls([]ParsedToolCall{ + {Name: "search", Input: map[string]any{"q": "test"}}, + }) + if len(formatted) != 1 { + t.Fatalf("expected 1, got %d", len(formatted)) + } + fn, _ := formatted[0]["function"].(map[string]any) + if fn["name"] != "search" { + t.Fatalf("unexpected function name: %#v", fn) + } + if formatted[0]["index"] != 0 { + t.Fatalf("expected index 0, got %v", formatted[0]["index"]) + } +} + +// ─── ParseToolCalls more edge cases ────────────────────────────────── + +func TestParseToolCallsNoToolNames(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + calls := ParseToolCalls(text, nil) + if len(calls) != 1 { + t.Fatalf("expected 1 call with nil tool names, got %d", len(calls)) + } +} + +func TestParseToolCallsEmptyText(t *testing.T) { + calls := ParseToolCalls("", []string{"search"}) + if len(calls) != 0 { + t.Fatalf("expected 0 calls for empty text, got %d", len(calls)) + } +} + +func TestParseToolCallsMultipleTools(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":{"q":"go"}},{"name":"get_weather","input":{"city":"beijing"}}]}` + calls := ParseToolCalls(text, []string{"search", "get_weather"}) + if len(calls) != 2 { + t.Fatalf("expected 2 calls, got %d", len(calls)) + } +} + +func TestParseToolCallsInputAsString(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":"{\"q\":\"golang\"}"}]}` + calls := ParseToolCalls(text, []string{"search"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %d", len(calls)) + } + if calls[0].Input["q"] != "golang" { + t.Fatalf("expected parsed string input, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsWithFunctionWrapper(t *testing.T) { + text := `{"tool_calls":[{"function":{"name":"calc","arguments":{"x":1,"y":2}}}]}` + calls := ParseToolCalls(text, []string{"calc"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %d", len(calls)) + } + if calls[0].Name != "calc" { + t.Fatalf("expected calc, got %q", 
calls[0].Name) + } +} + +func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) { + fenced := "Here's an example:\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\nDon't execute this." + calls := ParseStandaloneToolCalls(fenced, []string{"search"}) + if len(calls) != 0 { + t.Fatalf("expected fenced code block ignored, got %d calls", len(calls)) + } +} + +// ─── looksLikeToolExampleContext ───────────────────────────────────── + +func TestLooksLikeToolExampleContextNone(t *testing.T) { + if looksLikeToolExampleContext("I will call the tool now") { + t.Fatal("expected false for non-example context") + } +} + +func TestLooksLikeToolExampleContextFenced(t *testing.T) { + if !looksLikeToolExampleContext("```json") { + t.Fatal("expected true for fenced code block context") + } +} diff --git a/opencode.json.example b/opencode.json.example index 2933e9f..ed18a63 100644 --- a/opencode.json.example +++ b/opencode.json.example @@ -9,6 +9,12 @@ "apiKey": "your-api-key" }, "models": { + "gpt-4o": { + "name": "GPT-4o (aliased to deepseek-chat)" + }, + "gpt-5-codex": { + "name": "GPT-5 Codex (aliased to deepseek-reasoner)" + }, "deepseek-chat": { "name": "DeepSeek Chat (DS2API)" }, @@ -18,5 +24,5 @@ } } }, - "model": "ds2api/deepseek-chat" + "model": "ds2api/gpt-5-codex" } diff --git a/webui/src/components/AccountManager.jsx b/webui/src/components/AccountManager.jsx index 773b84e..7ee3b97 100644 --- a/webui/src/components/AccountManager.jsx +++ b/webui/src/components/AccountManager.jsx @@ -39,6 +39,10 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch const [loadingAccounts, setLoadingAccounts] = useState(false) const apiFetch = authFetch || fetch + const resolveAccountIdentifier = (acc) => { + if (!acc || typeof acc !== 'object') return '' + return String(acc.identifier || acc.email || acc.mobile || '').trim() + } const fetchAccounts = async (targetPage = page) => { setLoadingAccounts(true) @@ -147,9 
+151,14 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch } const deleteAccount = async (id) => { + const identifier = String(id || '').trim() + if (!identifier) { + onMessage('error', t('accountManager.invalidIdentifier')) + return + } if (!confirm(t('accountManager.deleteAccountConfirm'))) return try { - const res = await apiFetch(`/admin/accounts/${encodeURIComponent(id)}`, { method: 'DELETE' }) + const res = await apiFetch(`/admin/accounts/${encodeURIComponent(identifier)}`, { method: 'DELETE' }) if (res.ok) { onMessage('success', t('messages.deleted')) fetchAccounts() // 刷新当前页 @@ -163,24 +172,29 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch } const testAccount = async (identifier) => { - setTesting(prev => ({ ...prev, [identifier]: true })) + const accountID = String(identifier || '').trim() + if (!accountID) { + onMessage('error', t('accountManager.invalidIdentifier')) + return + } + setTesting(prev => ({ ...prev, [accountID]: true })) try { const res = await apiFetch('/admin/accounts/test', { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ identifier }), + body: JSON.stringify({ identifier: accountID }), }) const data = await res.json() const statusMessage = data.success - ? t('apiTester.testSuccess', { account: identifier, time: data.response_time }) - : `${identifier}: ${data.message}` + ? t('apiTester.testSuccess', { account: accountID, time: data.response_time }) + : `${accountID}: ${data.message}` onMessage(data.success ? 
'success' : 'error', statusMessage) fetchAccounts() // 刷新当前页 onRefresh() } catch (e) { onMessage('error', t('accountManager.testFailed', { error: e.message })) } finally { - setTesting(prev => ({ ...prev, [identifier]: false })) + setTesting(prev => ({ ...prev, [accountID]: false })) } } @@ -197,7 +211,12 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch for (let i = 0; i < allAccounts.length; i++) { const acc = allAccounts[i] - const id = acc.email || acc.mobile + const id = resolveAccountIdentifier(acc) + if (!id) { + results.push({ id: '-', success: false, message: t('accountManager.invalidIdentifier') }) + setBatchProgress({ current: i + 1, total: allAccounts.length, results: [...results] }) + continue + } try { const res = await apiFetch('/admin/accounts/test', { @@ -387,7 +406,7 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
{t('actions.loading')}
) : accounts.length > 0 ? ( accounts.map((acc, i) => { - const id = acc.email || acc.mobile + const id = resolveAccountIdentifier(acc) return (
@@ -396,7 +415,7 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch acc.has_token ? "bg-emerald-500 shadow-[0_0_8px_rgba(16,185,129,0.5)]" : "bg-amber-500" )} />
-
{id}
+
{id || '-'}
{acc.has_token ? t('accountManager.sessionActive') : t('accountManager.reauthRequired')} {acc.token_preview && ( @@ -419,7 +438,7 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch onClick={() => deleteAccount(id)} className="p-1 lg:p-1.5 text-muted-foreground hover:text-destructive hover:bg-destructive/10 rounded-md transition-colors" > - +
diff --git a/webui/src/components/ApiTester.jsx b/webui/src/components/ApiTester.jsx index 7d49982..75af1c0 100644 --- a/webui/src/components/ApiTester.jsx +++ b/webui/src/components/ApiTester.jsx @@ -42,6 +42,10 @@ export default function ApiTester({ config, onMessage, authFetch }) { const apiFetch = authFetch || fetch const accounts = config.accounts || [] + const resolveAccountIdentifier = (acc) => { + if (!acc || typeof acc !== 'object') return '' + return String(acc.identifier || acc.email || acc.mobile || '').trim() + } const configuredKeys = config.keys || [] const trimmedApiKey = apiKey.trim() const defaultKey = configuredKeys[0] || '' @@ -297,11 +301,15 @@ return ( onChange={e => setSelectedAccount(e.target.value)} > - {accounts.map((acc, i) => ( - - ))} + {accounts.map((acc, i) => { + const id = resolveAccountIdentifier(acc) + if (!id) return null + return ( + + ) + })}
diff --git a/webui/src/locales/en.json b/webui/src/locales/en.json index 0daf15f..07610f5 100644 --- a/webui/src/locales/en.json +++ b/webui/src/locales/en.json @@ -86,6 +86,7 @@ "requiredFields": "Password and email/mobile are required.", "deleteKeyConfirm": "Are you sure you want to delete this API key?", "deleteAccountConfirm": "Are you sure you want to delete this account?", + "invalidIdentifier": "Invalid account identifier. Operation aborted.", "testAllConfirm": "Test API connectivity for all accounts?", "testAllCompleted": "Completed: {success}/{total} available", "testFailed": "Test failed: {error}", diff --git a/webui/src/locales/zh.json b/webui/src/locales/zh.json index b405ee4..d0780dd 100644 --- a/webui/src/locales/zh.json +++ b/webui/src/locales/zh.json @@ -86,6 +86,7 @@ "requiredFields": "需要填写密码以及邮箱或手机号", "deleteKeyConfirm": "确定要删除此 API 密钥吗?", "deleteAccountConfirm": "确定要删除此账号吗?", + "invalidIdentifier": "账号标识无效,无法执行操作", "testAllConfirm": "测试所有账号的 API 连通性?", "testAllCompleted": "完成:{success}/{total} 可用", "testFailed": "测试失败: {error}",