Merge pull request #37 from CJackHwang/dev

全渠道适配 工具调用优化 后端优化
This commit is contained in:
CJACK.
2026-02-19 01:16:23 +08:00
committed by GitHub
74 changed files with 8895 additions and 481 deletions

View File

@@ -52,6 +52,9 @@ DS2API_ADMIN_KEY=admin
# Option C: Base64 encoded JSON (recommended for Vercel env var)
# DS2API_CONFIG_JSON=eyJrZXlzIjpbInlvdXItYXBpLWtleSJdLCJhY2NvdW50cyI6W3siZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwicGFzc3dvcmQiOiJ4eHgiLCJ0b2tlbiI6IiJ9XX0=
#
# Generate from local config.json:
# DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
# ---------------------------------------------------------------
# Paths (optional)

View File

@@ -73,16 +73,6 @@ jobs:
rm -rf "${STAGE}"
done
(cd dist && sha256sum *.tar.gz *.zip > sha256sums.txt)
- name: Upload Release Assets
uses: softprops/action-gh-release@v2
with:
files: |
dist/*.tar.gz
dist/*.zip
dist/sha256sums.txt
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
@@ -96,11 +86,19 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Extract Docker metadata
id: meta
id: meta_release
uses: docker/metadata-action@v5
with:
images: ghcr.io/${{ github.repository }}
images: |
ghcr.io/${{ github.repository }}
cjackhwang/ds2api
tags: |
type=raw,value=${{ github.event.release.tag_name }}
type=raw,value=latest
@@ -112,5 +110,36 @@ jobs:
file: ./Dockerfile
push: true
platforms: linux/amd64,linux/arm64
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
tags: ${{ steps.meta_release.outputs.tags }}
labels: ${{ steps.meta_release.outputs.labels }}
- name: Export Docker image archives for release assets
run: |
set -euo pipefail
TAG="${{ github.event.release.tag_name }}"
docker buildx build \
--platform linux/amd64 \
--output type=docker,dest="dist/ds2api_${TAG}_docker_linux_amd64.tar" \
.
docker buildx build \
--platform linux/arm64 \
--output type=docker,dest="dist/ds2api_${TAG}_docker_linux_arm64.tar" \
.
gzip -f "dist/ds2api_${TAG}_docker_linux_amd64.tar"
gzip -f "dist/ds2api_${TAG}_docker_linux_arm64.tar"
- name: Generate checksums
run: |
set -euo pipefail
(cd dist && sha256sum *.tar.gz *.zip > sha256sums.txt)
- name: Upload Release Assets
uses: softprops/action-gh-release@v2
with:
files: |
dist/*.tar.gz
dist/*.zip
dist/sha256sums.txt

3
.gitignore vendored
View File

@@ -81,6 +81,9 @@ ds2api-tests
htmlcov/
.pytest_cache/
.tox/
*.coverprofile
coverage*.out
cover/
# Misc
*.pyc

157
API.en.md
View File

@@ -9,6 +9,7 @@ This document describes the actual behavior of the current Go codebase.
## Table of Contents
- [Basics](#basics)
- [Configuration Best Practice](#configuration-best-practice)
- [Authentication](#authentication)
- [Route Index](#route-index)
- [Health Endpoints](#health-endpoints)
@@ -27,7 +28,29 @@ This document describes the actual behavior of the current Go codebase.
| Base URL | `http://localhost:5001` or your deployment domain |
| Default Content-Type | `application/json` |
| Health probes | `GET /healthz`, `GET /readyz` |
| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`) |
| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
---
## Configuration Best Practice
Use `config.json` as the single source of truth:
```bash
cp config.example.json config.json
# Edit config.json (keys/accounts)
```
Use it per deployment mode:
- Local run: read `config.json` directly
- Docker / Vercel: generate Base64 from `config.json`, then set `DS2API_CONFIG_JSON`
```bash
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
```
For Vercel one-click bootstrap, you can set only `DS2API_ADMIN_KEY` first, then import config at `/admin` and sync env vars from the "Vercel Sync" page.
---
@@ -66,7 +89,11 @@ Two header formats accepted:
| GET | `/healthz` | None | Liveness probe |
| GET | `/readyz` | None | Readiness probe |
| GET | `/v1/models` | None | OpenAI model list |
| GET | `/v1/models/{id}` | None | OpenAI single-model query (alias accepted) |
| POST | `/v1/chat/completions` | Business | OpenAI chat completions |
| POST | `/v1/responses` | Business | OpenAI Responses API (stream/non-stream) |
| GET | `/v1/responses/{response_id}` | Business | Query stored response (in-memory TTL) |
| POST | `/v1/embeddings` | Business | OpenAI Embeddings API |
| GET | `/anthropic/v1/models` | None | Claude model list |
| POST | `/anthropic/v1/messages` | Business | Claude messages |
| POST | `/anthropic/v1/messages/count_tokens` | Business | Claude token counting |
@@ -127,6 +154,15 @@ No auth required. Returns supported models.
}
```
### Model Alias Resolution
For `chat` / `responses` / `embeddings`, DS2API follows a wide-input/strict-output policy:
1. Match DeepSeek native model IDs first.
2. Then match exact keys in `model_aliases`.
3. If still unmatched, fall back by known family heuristics (`o*`, `gpt-*`, `claude-*`, etc.).
4. If still unmatched, return `invalid_request_error`.
### `POST /v1/chat/completions`
**Headers**:
@@ -140,7 +176,7 @@ Content-Type: application/json
| Field | Type | Required | Notes |
| --- | --- | --- | --- |
| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` |
| `model` | string | ✅ | DeepSeek native models + common aliases (`gpt-4o`, `gpt-5-codex`, `o3`, `claude-sonnet-4-5`, etc.) |
| `messages` | array | ✅ | OpenAI-style messages |
| `stream` | boolean | ❌ | Default `false` |
| `tools` | array | ❌ | Function calling schema |
@@ -230,7 +266,63 @@ When `tools` is present, DS2API performs anti-leak handling:
}
```
**Stream**: DS2API buffers text first. If tool call detected → only structured `delta.tool_calls` (each with `index`); otherwise emits buffered text at once.
**Stream**: Once high-confidence toolcall features are matched, DS2API emits `delta.tool_calls` immediately (without waiting for full JSON closure), then keeps sending argument deltas; confirmed raw tool JSON is never forwarded as `delta.content`.
---
### `GET /v1/models/{id}`
No auth required. Alias values are accepted as path params (for example `gpt-4o`), and the returned object is the mapped DeepSeek model.
### `POST /v1/responses`
OpenAI Responses-style endpoint, accepting either `input` or `messages`.
| Field | Type | Required | Notes |
| --- | --- | --- | --- |
| `model` | string | ✅ | Supports native models + alias mapping |
| `input` | string/array/object | ❌ | One of `input` or `messages` is required |
| `messages` | array | ❌ | One of `input` or `messages` is required |
| `instructions` | string | ❌ | Prepended as a system message |
| `stream` | boolean | ❌ | Default `false` |
| `tools` | array | ❌ | Same tool detection/translation policy as chat |
**Non-stream**: Returns a standard `response` object with an ID like `resp_xxx`, and stores it in in-memory TTL cache.
**Stream (SSE)**: minimal event sequence:
```text
event: response.created
data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...}
event: response.output_text.delta
data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."}
event: response.output_tool_call.delta
data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]}
event: response.completed
data: {"type":"response.completed","response":{...}}
data: [DONE]
```
### `GET /v1/responses/{response_id}`
Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read).
> Backed by in-memory TTL store. Default TTL is `900s` (configurable via `responses.store_ttl_seconds`).
### `POST /v1/embeddings`
Business auth required. Returns OpenAI-compatible embeddings shape.
| Field | Type | Required | Notes |
| --- | --- | --- | --- |
| `model` | string | ✅ | Supports native models + alias mapping |
| `input` | string/array | ✅ | Supports string, string array, token array |
> Requires `embeddings.provider`. Current supported values: `mock` / `deterministic` / `builtin`. If missing/unsupported, returns standard error shape with HTTP 501.
---
@@ -249,7 +341,10 @@ No auth required.
{"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
{"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
{"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"}
]
],
"first_id": "claude-opus-4-6",
"last_id": "claude-instant-1.0",
"has_more": false
}
```
@@ -265,13 +360,15 @@ Content-Type: application/json
anthropic-version: 2023-06-01
```
> `anthropic-version` is optional; DS2API auto-fills `2023-06-01` when absent.
**Request body**:
| Field | Type | Required | Notes |
| --- | --- | --- | --- |
| `model` | string | ✅ | For example `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5` (compatible with `claude-3-5-haiku-latest`), plus historical Claude model IDs |
| `messages` | array | ✅ | Claude-style messages |
| `max_tokens` | number | ❌ | Not strictly enforced by upstream bridge |
| `max_tokens` | number | ❌ | Auto-filled to `8192` when omitted; not strictly enforced by upstream bridge |
| `stream` | boolean | ❌ | Default `false` |
| `system` | string | ❌ | Optional system prompt |
| `tools` | array | ❌ | Claude tool schema |
@@ -416,6 +513,7 @@ Returns sanitized config.
"keys": ["k1", "k2"],
"accounts": [
{
"identifier": "user@example.com",
"email": "user@example.com",
"mobile": "",
"has_password": true,
@@ -476,6 +574,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
{
"items": [
{
"identifier": "user@example.com",
"email": "user@example.com",
"mobile": "",
"has_password": true,
@@ -500,7 +599,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
### `DELETE /admin/accounts/{identifier}`
`identifier` is email or mobile.
`identifier` can be email, mobile, or the synthetic id for token-only accounts (`token:<hash>`).
**Response**: `{"success": true, "total_accounts": 5}`
@@ -530,7 +629,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
| Field | Required | Notes |
| --- | --- | --- |
| `identifier` | ✅ | email or mobile |
| `identifier` | ✅ | email / mobile / token-only synthetic id |
| `model` | ❌ | default `deepseek-chat` |
| `message` | ❌ | if empty, only session creation is tested |
@@ -659,13 +758,20 @@ Or manual deploy required:
## Error Payloads
Error formats vary by module:
Compatible routes (`/v1/*`, `/anthropic/*`) use the same error envelope:
| Module | Format |
| --- | --- |
| OpenAI routes | `{"error": {"message": "...", "type": "..."}}` |
| Claude routes | `{"error": {"type": "...", "message": "..."}}` |
| Admin routes | `{"detail": "..."}` |
```json
{
"error": {
"message": "...",
"type": "invalid_request_error",
"code": "invalid_request",
"param": null
}
}
```
Admin routes keep `{"detail":"..."}`.
Clients should handle HTTP status code plus `error` / `detail` fields.
@@ -707,6 +813,31 @@ curl http://localhost:5001/v1/chat/completions \
}'
```
### OpenAI Responses (Stream)
```bash
curl http://localhost:5001/v1/responses \
-H "Authorization: Bearer your-api-key" \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-5-codex",
"input": "Write a hello world in golang",
"stream": true
}'
```
### OpenAI Embeddings
```bash
curl http://localhost:5001/v1/embeddings \
-H "Authorization: Bearer your-api-key" \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-4o",
"input": ["first text", "second text"]
}'
```
### OpenAI with Search
```bash

157
API.md
View File

@@ -9,6 +9,7 @@
## 目录
- [基础信息](#基础信息)
- [配置最佳实践](#配置最佳实践)
- [鉴权规则](#鉴权规则)
- [路由总览](#路由总览)
- [健康检查](#健康检查)
@@ -27,7 +28,29 @@
| Base URL | `http://localhost:5001` 或你的部署域名 |
| 默认 Content-Type | `application/json` |
| 健康检查 | `GET /healthz`、`GET /readyz` |
| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`) |
| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
---
## 配置最佳实践
推荐把 `config.json` 作为唯一配置源:
```bash
cp config.example.json config.json
# 编辑 config.json(keys/accounts)
```
按部署方式使用:
- 本地运行:直接读取 `config.json`
- Docker / Vercel:由 `config.json` 生成 Base64,填入 `DS2API_CONFIG_JSON`
```bash
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
```
Vercel 一键部署可先只填 `DS2API_ADMIN_KEY`,部署后在 `/admin` 导入配置,再通过 “Vercel 同步” 写回环境变量。
---
@@ -66,7 +89,11 @@
| GET | `/healthz` | 无 | 存活探针 |
| GET | `/readyz` | 无 | 就绪探针 |
| GET | `/v1/models` | 无 | OpenAI 模型列表 |
| GET | `/v1/models/{id}` | 无 | OpenAI 单模型查询(支持 alias 入参) |
| POST | `/v1/chat/completions` | 业务 | OpenAI 对话补全 |
| POST | `/v1/responses` | 业务 | OpenAI Responses 接口(流式/非流式) |
| GET | `/v1/responses/{response_id}` | 业务 | 查询已生成 response(内存 TTL) |
| POST | `/v1/embeddings` | 业务 | OpenAI Embeddings 接口 |
| GET | `/anthropic/v1/models` | 无 | Claude 模型列表 |
| POST | `/anthropic/v1/messages` | 业务 | Claude 消息接口 |
| POST | `/anthropic/v1/messages/count_tokens` | 业务 | Claude token 计数 |
@@ -127,6 +154,15 @@
}
```
### 模型 alias 解析策略
`chat` / `responses` / `embeddings` 的 `model` 字段采用“宽进严出”:
1. 先匹配 DeepSeek 原生模型。
2. 再匹配 `model_aliases` 精确映射。
3. 未命中时按模型家族规则回退(如 `o*`、`gpt-*`、`claude-*`)。
4. 仍未命中则返回 `invalid_request_error`。
### `POST /v1/chat/completions`
**请求头**
@@ -140,7 +176,7 @@ Content-Type: application/json
| 字段 | 类型 | 必填 | 说明 |
| --- | --- | --- | --- |
| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` |
| `model` | string | ✅ | 支持 DeepSeek 原生模型 + 常见 alias(`gpt-4o`、`gpt-5-codex`、`o3`、`claude-sonnet-4-5` 等) |
| `messages` | array | ✅ | OpenAI 风格消息数组 |
| `stream` | boolean | ❌ | 默认 `false` |
| `tools` | array | ❌ | Function Calling 定义 |
@@ -230,7 +266,63 @@ data: [DONE]
}
```
**流式**:先缓冲正文片段。识别到工具调用 → 仅输出结构化 `delta.tool_calls`(每个 tool call 带 `index`);否则一次性输出普通文本。
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。
---
### `GET /v1/models/{id}`
无需鉴权。入参支持 alias(例如 `gpt-4o`),返回的是映射后的 DeepSeek 模型对象。
### `POST /v1/responses`
OpenAI Responses 风格接口,兼容 `input` 或 `messages`。
| 字段 | 类型 | 必填 | 说明 |
| --- | --- | --- | --- |
| `model` | string | ✅ | 支持原生模型 + alias 自动映射 |
| `input` | string/array/object | ❌ | 与 `messages` 二选一 |
| `messages` | array | ❌ | 与 `input` 二选一 |
| `instructions` | string | ❌ | 自动前置为 system 消息 |
| `stream` | boolean | ❌ | 默认 `false` |
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 |
**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。
**流式响应(SSE)**:最小事件序列如下。
```text
event: response.created
data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...}
event: response.output_text.delta
data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."}
event: response.output_tool_call.delta
data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]}
event: response.completed
data: {"type":"response.completed","response":{...}}
data: [DONE]
```
### `GET /v1/responses/{response_id}`
需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)。
> 当前为内存 TTL 存储,默认过期时间 `900s`(可用 `responses.store_ttl_seconds` 调整)。
### `POST /v1/embeddings`
需要业务鉴权。返回 OpenAI Embeddings 兼容结构。
| 字段 | 类型 | 必填 | 说明 |
| --- | --- | --- | --- |
| `model` | string | ✅ | 支持原生模型 + alias 自动映射 |
| `input` | string/array | ✅ | 支持字符串、字符串数组、token 数组 |
> 需配置 `embeddings.provider`。当前支持:`mock` / `deterministic` / `builtin`。未配置或不支持时返回标准错误结构(HTTP 501)。
---
@@ -249,7 +341,10 @@ data: [DONE]
{"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
{"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
{"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"}
]
],
"first_id": "claude-opus-4-6",
"last_id": "claude-instant-1.0",
"has_more": false
}
```
@@ -265,13 +360,15 @@ Content-Type: application/json
anthropic-version: 2023-06-01
```
> `anthropic-version` 可省略,服务端会自动补为 `2023-06-01`。
**请求体**
| 字段 | 类型 | 必填 | 说明 |
| --- | --- | --- | --- |
| `model` | string | ✅ | 例如 `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5`(兼容 `claude-3-5-haiku-latest`),并支持历史 Claude 模型 ID |
| `messages` | array | ✅ | Claude 风格消息数组 |
| `max_tokens` | number | ❌ | 当前实现不会硬性截断上游输出 |
| `max_tokens` | number | ❌ | 缺省自动补 `8192`;当前实现不会硬性截断上游输出 |
| `stream` | boolean | ❌ | 默认 `false` |
| `system` | string | ❌ | 可选系统提示 |
| `tools` | array | ❌ | Claude tool 定义 |
@@ -416,6 +513,7 @@ data: {"type":"message_stop"}
"keys": ["k1", "k2"],
"accounts": [
{
"identifier": "user@example.com",
"email": "user@example.com",
"mobile": "",
"has_password": true,
@@ -476,6 +574,7 @@ data: {"type":"message_stop"}
{
"items": [
{
"identifier": "user@example.com",
"email": "user@example.com",
"mobile": "",
"has_password": true,
@@ -500,7 +599,7 @@ data: {"type":"message_stop"}
### `DELETE /admin/accounts/{identifier}`
`identifier` 为 email 或 mobile。
`identifier` 为 email、mobile,或 token-only 账号的合成标识(`token:<hash>`)。
**响应**:`{"success": true, "total_accounts": 5}`
@@ -530,7 +629,7 @@ data: {"type":"message_stop"}
| 字段 | 必填 | 说明 |
| --- | --- | --- |
| `identifier` | ✅ | email 或 mobile |
| `identifier` | ✅ | email / mobile / token-only 合成标识 |
| `model` | ❌ | 默认 `deepseek-chat` |
| `message` | ❌ | 空字符串时仅测试会话创建 |
@@ -659,13 +758,20 @@ data: {"type":"message_stop"}
## 错误响应格式
不同模块的错误格式略有差异:
兼容路由(`/v1/*`、`/anthropic/*`)统一使用以下结构:
| 模块 | 格式 |
| --- | --- |
| OpenAI 接口 | `{"error": {"message": "...", "type": "..."}}` |
| Claude 接口 | `{"error": {"type": "...", "message": "..."}}` |
| Admin 接口 | `{"detail": "..."}` |
```json
{
"error": {
"message": "...",
"type": "invalid_request_error",
"code": "invalid_request",
"param": null
}
}
```
Admin 接口保持 `{"detail":"..."}`。
建议客户端处理逻辑:检查 HTTP 状态码 + 解析 `error` / `detail` 字段。
@@ -707,6 +813,31 @@ curl http://localhost:5001/v1/chat/completions \
}'
```
### OpenAI Responses(流式)
```bash
curl http://localhost:5001/v1/responses \
-H "Authorization: Bearer your-api-key" \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-5-codex",
"input": "写一个 golang 的 hello world",
"stream": true
}'
```
### OpenAI Embeddings
```bash
curl http://localhost:5001/v1/embeddings \
-H "Authorization: Bearer your-api-key" \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-4o",
"input": ["第一段文本", "第二段文本"]
}'
```
### OpenAI 带搜索
```bash

View File

@@ -33,6 +33,17 @@ Config source (choose one):
- **File**: `config.json` (recommended for local/Docker)
- **Environment variable**: `DS2API_CONFIG_JSON` (recommended for Vercel; supports raw JSON or Base64)
Unified recommendation (best practice):
```bash
cp config.example.json config.json
# Edit config.json
```
Use `config.json` as the single source of truth:
- Local run: read `config.json` directly
- Docker / Vercel: generate `DS2API_CONFIG_JSON` (Base64) from `config.json` and inject it
---
## 1. Local Run
@@ -99,11 +110,15 @@ go build -o ds2api ./cmd/ds2api
### 2.1 Basic Steps
```bash
# Copy and edit environment
# Copy env template
cp .env.example .env
# Edit .env, at minimum set:
# Generate single-line Base64 from config.json
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
# Edit .env and set:
# DS2API_ADMIN_KEY=your-admin-key
# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]}
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
# Start
docker-compose up -d
@@ -167,15 +182,49 @@ If container logs look normal but the admin panel is unreachable, check these fi
1. **Fork** the repo to your GitHub account
2. **Import** the project on Vercel
3. **Set environment variables** (at minimum):
3. **Set environment variables** (minimum required: one variable):
| Variable | Description |
| --- | --- |
| `DS2API_ADMIN_KEY` | Admin key (required) |
| `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (required) |
| `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (optional, recommended) |
4. **Deploy**
### 3.1.1 Recommended Input (avoid `DS2API_CONFIG_JSON` mistakes)
If you prefer faster one-click bootstrap, you can leave `DS2API_CONFIG_JSON` empty first, then open `/admin` after deployment, import config, and sync it back to Vercel env vars from the "Vercel Sync" page.
Recommended: in repo root, copy the template first and fill your real accounts:
```bash
cp config.example.json config.json
# Edit config.json
```
Do not hand-edit large JSON directly in Vercel. Generate Base64 locally and paste it:
```bash
# Run in repo root
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
echo "$DS2API_CONFIG_JSON"
```
If you choose to preconfigure before first deploy, set these vars in Vercel Project Settings -> Environment Variables:
```text
DS2API_ADMIN_KEY=replace-with-a-strong-secret
DS2API_CONFIG_JSON=<the single-line Base64 output above>
```
Optional but recommended (for WebUI one-click Vercel sync):
```text
VERCEL_TOKEN=your-vercel-token
VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx
VERCEL_TEAM_ID=team_xxxxxxxxxxxx # optional for personal accounts
```
### 3.2 Optional Environment Variables
| Variable | Description | Default |

View File

@@ -33,6 +33,17 @@
- **文件方式**:`config.json`(推荐本地/Docker 使用)
- **环境变量方式**:`DS2API_CONFIG_JSON`(推荐 Vercel 使用,支持 JSON 字符串或 Base64 编码)
统一建议(最优实践):
```bash
cp config.example.json config.json
# 编辑 config.json
```
建议把 `config.json` 作为唯一配置源:
- 本地运行:直接读 `config.json`
- Docker / Vercel:由 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量
---
## 一、本地运行
@@ -99,11 +110,15 @@ go build -o ds2api ./cmd/ds2api
### 2.1 基本步骤
```bash
# 复制并编辑环境变量
# 复制环境变量模板
cp .env.example .env
# 编辑 .env,至少设置:
# 从 config.json 生成单行 Base64
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
# 编辑 .env(请改成你的强密码),设置:
# DS2API_ADMIN_KEY=your-admin-key
# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]}
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
# 启动
docker-compose up -d
@@ -167,15 +182,49 @@ healthcheck:
1. **Fork 仓库**到你的 GitHub 账号
2. **在 Vercel 上导入项目**
3. **配置环境变量**(至少设置以下项):
3. **配置环境变量**(最少只需设置以下项):
| 变量 | 说明 |
| --- | --- |
| `DS2API_ADMIN_KEY` | 管理密钥(必填) |
| `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(必填) |
| `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(可选,建议) |
4. **部署**
### 3.1.1 推荐填写方式(避免 `DS2API_CONFIG_JSON` 填错)
如果你想先完成一键部署,也可以先不填 `DS2API_CONFIG_JSON`,部署后进入 `/admin` 导入配置,再在「Vercel 同步」里写回环境变量。
建议先在仓库目录复制示例配置,再按实际账号填写:
```bash
cp config.example.json config.json
# 编辑 config.json
```
不要在 Vercel 面板里手写复杂 JSON,建议本地生成 Base64 后粘贴:
```bash
# 在仓库根目录执行
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
echo "$DS2API_CONFIG_JSON"
```
如果你选择在部署前就预置配置,请在 Vercel Project Settings -> Environment Variables 配置:
```text
DS2API_ADMIN_KEY=请替换为强密码
DS2API_CONFIG_JSON=上一步生成的一整行 Base64
```
可选但推荐(用于 WebUI 一键同步 Vercel 配置):
```text
VERCEL_TOKEN=你的 Vercel Token
VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx
VERCEL_TEAM_ID=team_xxxxxxxxxxxx # 个人账号可留空
```
### 3.2 可选环境变量
| 变量 | 说明 | 默认值 |

View File

@@ -54,16 +54,27 @@ flowchart LR
| 能力 | 说明 |
| --- | --- |
| OpenAI 兼容 | `GET /v1/models`、`POST /v1/chat/completions`(流式/非流式) |
| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings` |
| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens` |
| 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 |
| 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 |
| DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 |
| Tool Calling | 防泄漏处理:自动缓冲、识别、结构化输出 |
| Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 |
| Admin API | 配置管理、账号测试 / 批量测试、导入导出、Vercel 同步 |
| WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) |
| 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) |
## 平台兼容矩阵
| 级别 | 平台 | 当前状态 |
| --- | --- | --- |
| P0 | Codex CLI/SDK(`wire_api=chat` / `wire_api=responses`) | ✅ |
| P0 | OpenAI SDK(JS/Python,chat + responses) | ✅ |
| P0 | Vercel AI SDK(openai-compatible) | ✅ |
| P0 | Anthropic SDK(messages) | ✅ |
| P1 | LangChain / LlamaIndex / OpenWebUI(OpenAI 兼容接入) | ✅ |
| P2 | MCP 独立桥接层 | 规划中 |
## 模型支持
### OpenAI 接口
@@ -88,6 +99,19 @@ flowchart LR
## 快速开始
### 通用第一步(所有部署方式)
把 `config.json` 作为唯一配置源(推荐做法):
```bash
cp config.example.json config.json
# 编辑 config.json
```
后续部署建议:
- 本地运行:直接读取 `config.json`
- Docker / Vercel:由 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量
### 方式一:本地运行
**前置要求**:Go 1.24+,Node.js 20+(仅在需要构建 WebUI 时)
@@ -112,14 +136,20 @@ go run ./cmd/ds2api
### 方式二:Docker 运行
```bash
# 1. 配置环境变量
# 1. 准备环境变量文件
cp .env.example .env
# 编辑 .env
# 2. 启动
# 2. 从 config.json 生成 DS2API_CONFIG_JSON(单行 Base64)
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
# 3. 编辑 .env,设置:
# DS2API_ADMIN_KEY=请替换为强密码
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
# 4. 启动
docker-compose up -d
# 3. 查看日志
# 5. 查看日志
docker-compose logs -f
```
@@ -129,9 +159,22 @@ docker-compose logs -f
1. Fork 仓库到自己的 GitHub
2. 在 Vercel 上导入项目
3. 配置环境变量(至少设置 `DS2API_ADMIN_KEY` 和 `DS2API_CONFIG_JSON`)
3. 配置环境变量(最少设置 `DS2API_ADMIN_KEY`;推荐同时设置 `DS2API_CONFIG_JSON`)
4. 部署
建议先在仓库目录复制模板并填写:
```bash
cp config.example.json config.json
# 编辑 config.json
```
推荐:先本地把 `config.json` 转成 Base64,再粘贴到 `DS2API_CONFIG_JSON`,避免 JSON 格式错误:
```bash
base64 < config.json | tr -d '\n'
```
> **流式说明**:`/v1/chat/completions` 在 Vercel 上默认走 `api/chat-stream.js`(Node Runtime),以保证实时 SSE。鉴权、账号选择、会话/PoW 准备仍由 Go 内部 prepare 接口完成;流式响应(含 `tools`)在 Node 侧执行与 Go 对齐的输出组装与防泄漏处理。
详细部署说明请参阅 [部署指南](DEPLOY.md)。
@@ -164,6 +207,7 @@ cp opencode.json.example opencode.json
3. 在项目目录启动 OpenCode CLI(按你的安装方式运行 `opencode`)。
> 建议优先使用 OpenAI 兼容路径(`/v1/*`),即示例里的 `@ai-sdk/openai-compatible` provider。
> 若客户端支持 `wire_api`,可分别测试 `responses` 与 `chat`,DS2API 两条链路都兼容。
## 配置说明
@@ -184,6 +228,24 @@ cp opencode.json.example opencode.json
"token": ""
}
],
"model_aliases": {
"gpt-4o": "deepseek-chat",
"gpt-5-codex": "deepseek-reasoner",
"o3": "deepseek-reasoner"
},
"compat": {
"wide_input_strict_output": true
},
"toolcall": {
"mode": "feature_match",
"early_emit_confidence": "high"
},
"responses": {
"store_ttl_seconds": 900
},
"embeddings": {
"provider": "deterministic"
},
"claude_model_mapping": {
"fast": "deepseek-chat",
"slow": "deepseek-reasoner"
@@ -194,6 +256,11 @@ cp opencode.json.example opencode.json
- `keys`:API 访问密钥列表,客户端通过 `Authorization: Bearer <key>` 鉴权
- `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录
- `token`:留空则首次请求时自动登录获取;也可预填已有 token
- `model_aliases`:常见模型名(如 GPT/Codex/Claude)到 DeepSeek 模型的映射
- `compat.wide_input_strict_output`:建议保持 `true`(当前实现默认宽进严出)
- `toolcall`:固定采用特征匹配 + 高置信早发策略
- `responses.store_ttl_seconds`:`/v1/responses/{id}` 的内存缓存 TTL
- `embeddings.provider`:embedding 提供方(当前内置 `deterministic/mock/builtin`)
- `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型
### 环境变量
@@ -249,10 +316,10 @@ cp opencode.json.example opencode.json
当请求中带 `tools` 时,DS2API 会做防泄漏处理:
1. `stream=true` 时先**缓冲**正文片段
2. 若识别到工具调用 → 仅输出结构化 `tool_calls`,不透传原始 JSON 文本
3. 若最终不是工具调用 → 一次性输出普通文本
4. 解析器支持混合文本、fenced JSON、`function.arguments` 字符串等格式
1. 只在**非代码块上下文**启用 toolcall 特征识别(代码块示例不会触发)
2. 一旦命中高置信特征(`tool_calls` + `name` + `arguments/input` 起始)就立即输出 `delta.tool_calls`
3. 已确认的 toolcall JSON 片段不会泄漏到 `delta.content`
4. 前文/后文自然语言保持顺序透传,支持混合文本与增量参数输出
## 项目结构

View File

@@ -54,16 +54,27 @@ flowchart LR
| Capability | Details |
| --- | --- |
| OpenAI compatible | `GET /v1/models`, `POST /v1/chat/completions` (stream/non-stream) |
| OpenAI compatible | `GET /v1/models`, `GET /v1/models/{id}`, `POST /v1/chat/completions`, `POST /v1/responses`, `GET /v1/responses/{response_id}`, `POST /v1/embeddings` |
| Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` |
| Multi-account rotation | Auto token refresh, email/mobile dual login |
| Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency |
| DeepSeek PoW | WASM solving via `wazero`, no external Node.js dependency |
| Tool Calling | Anti-leak handling: auto buffer, detect, structured output |
| Tool Calling | Anti-leak handling: non-code-block feature match, early `delta.tool_calls`, structured incremental output |
| Admin API | Config management, account testing/batch test, import/export, Vercel sync |
| WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) |
| Health Probes | `GET /healthz` (liveness), `GET /readyz` (readiness) |
## Platform Compatibility Matrix
| Tier | Platform | Status |
| --- | --- | --- |
| P0 | Codex CLI/SDK (`wire_api=chat` / `wire_api=responses`) | ✅ |
| P0 | OpenAI SDK (JS/Python, chat + responses) | ✅ |
| P0 | Vercel AI SDK (openai-compatible) | ✅ |
| P0 | Anthropic SDK (messages) | ✅ |
| P1 | LangChain / LlamaIndex / OpenWebUI (OpenAI-compatible integration) | ✅ |
| P2 | MCP standalone bridge | Planned |
## Model Support
### OpenAI Endpoint
@@ -88,6 +99,19 @@ In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4
## Quick Start
### Universal First Step (all deployment modes)
Use `config.json` as the single source of truth (recommended):
```bash
cp config.example.json config.json
# Edit config.json
```
Recommended per deployment mode:
- Local run: read `config.json` directly
- Docker / Vercel: generate Base64 from `config.json` and inject as `DS2API_CONFIG_JSON`
### Option 1: Local Run
**Prerequisites**: Go 1.24+, Node.js 20+ (only if building WebUI locally)
@@ -112,14 +136,20 @@ Default URL: `http://localhost:5001`
### Option 2: Docker
```bash
# 1. Configure environment
# 1. Prepare env file
cp .env.example .env
# Edit .env
# 2. Start
# 2. Generate DS2API_CONFIG_JSON from config.json (single-line Base64)
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
# 3. Edit .env and set:
# DS2API_ADMIN_KEY=replace-with-a-strong-secret
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
# 4. Start
docker-compose up -d
# 3. View logs
# 5. View logs
docker-compose logs -f
```
@@ -129,9 +159,22 @@ Rebuild after updates: `docker-compose up -d --build`
1. Fork this repo to your GitHub account
2. Import the project on Vercel
3. Set environment variables (minimum: `DS2API_ADMIN_KEY` and `DS2API_CONFIG_JSON`)
3. Set environment variables (minimum: `DS2API_ADMIN_KEY`; recommended to also set `DS2API_CONFIG_JSON`)
4. Deploy
Recommended first step in repo root:
```bash
cp config.example.json config.json
# Edit config.json
```
Recommended: convert `config.json` to Base64 locally, then paste into `DS2API_CONFIG_JSON` to avoid JSON formatting mistakes:
```bash
base64 < config.json | tr -d '\n'
```
> **Streaming note**: `/v1/chat/completions` on Vercel is routed to `api/chat-stream.js` (Node Runtime) for real-time SSE. Auth, account selection, and session/PoW preparation are still handled by the Go internal prepare endpoint; streaming output (including `tools`) is assembled on Node with Go-aligned anti-leak handling.
For detailed deployment instructions, see the [Deployment Guide](DEPLOY.en.md).
@@ -164,6 +207,7 @@ cp opencode.json.example opencode.json
3. Start OpenCode CLI in the project directory (run `opencode` using your installed method).
> Recommended: use the OpenAI-compatible path (`/v1/*`) via `@ai-sdk/openai-compatible` as shown in the example.
> If your client supports `wire_api`, test both `responses` and `chat`; DS2API supports both paths.
## Configuration
@@ -184,6 +228,24 @@ cp opencode.json.example opencode.json
"token": ""
}
],
"model_aliases": {
"gpt-4o": "deepseek-chat",
"gpt-5-codex": "deepseek-reasoner",
"o3": "deepseek-reasoner"
},
"compat": {
"wide_input_strict_output": true
},
"toolcall": {
"mode": "feature_match",
"early_emit_confidence": "high"
},
"responses": {
"store_ttl_seconds": 900
},
"embeddings": {
"provider": "deterministic"
},
"claude_model_mapping": {
"fast": "deepseek-chat",
"slow": "deepseek-reasoner"
@@ -194,6 +256,11 @@ cp opencode.json.example opencode.json
- `keys`: API access keys; clients authenticate via `Authorization: Bearer <key>`
- `accounts`: DeepSeek account list, supports `email` or `mobile` login
- `token`: Leave empty for auto-login on first request; or pre-fill an existing token
- `model_aliases`: Map common model names (GPT/Codex/Claude) to DeepSeek models
- `compat.wide_input_strict_output`: Keep `true` (current default policy)
- `toolcall`: Fixed to feature matching + high-confidence early emit
- `responses.store_ttl_seconds`: In-memory TTL for `/v1/responses/{id}`
- `embeddings.provider`: Embeddings provider (built-in values: `deterministic`, `mock`, `builtin`)
- `claude_model_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models
### Environment Variables
@@ -249,10 +316,10 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
When `tools` is present in the request, DS2API performs anti-leak handling:
1. With `stream=true`, DS2API **buffers** text deltas first
2. If a tool call is detected → only structured `tool_calls` are emitted, raw JSON is not leaked
3. If no tool call → buffered text is emitted at once
4. Parser supports mixed text, fenced JSON, and `function.arguments` payloads
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
2. Once high-confidence features are matched (`tool_calls` + `name` + `arguments/input` start), `delta.tool_calls` is emitted immediately
3. Confirmed toolcall JSON fragments are never leaked into `delta.content`
4. Natural language before/after toolcalls keeps original order, with incremental argument output supported
## Project Structure

View File

@@ -1,5 +1,7 @@
'use strict';
const crypto = require('crypto');
const {
extractToolNames,
createToolSieveState,
@@ -83,23 +85,57 @@ module.exports = async function handler(req, res) {
const finalPrompt = asString(prep.body.final_prompt);
const thinkingEnabled = toBool(prep.body.thinking_enabled);
const searchEnabled = toBool(prep.body.search_enabled);
const toolNames = extractToolNames(payload.tools);
const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools);
const toolNames = toolPolicy.toolNames;
if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) {
writeOpenAIError(res, 500, 'invalid vercel prepare response');
return;
}
const releaseLease = createLeaseReleaser(req, leaseID);
const upstreamController = new AbortController();
let clientClosed = false;
let reader = null;
const markClientClosed = () => {
if (clientClosed) {
return;
}
clientClosed = true;
upstreamController.abort();
if (reader && typeof reader.cancel === 'function') {
Promise.resolve(reader.cancel()).catch(() => {});
}
};
const onReqAborted = () => markClientClosed();
const onResClose = () => {
if (!res.writableEnded) {
markClientClosed();
}
};
req.on('aborted', onReqAborted);
res.on('close', onResClose);
try {
const completionRes = await fetch(DEEPSEEK_COMPLETION_URL, {
method: 'POST',
headers: {
...BASE_HEADERS,
authorization: `Bearer ${deepseekToken}`,
'x-ds-pow-response': powHeader,
},
body: JSON.stringify(completionPayload),
});
let completionRes;
try {
completionRes = await fetch(DEEPSEEK_COMPLETION_URL, {
method: 'POST',
headers: {
...BASE_HEADERS,
authorization: `Bearer ${deepseekToken}`,
'x-ds-pow-response': powHeader,
},
body: JSON.stringify(completionPayload),
signal: upstreamController.signal,
});
} catch (err) {
if (clientClosed || isAbortError(err)) {
return;
}
throw err;
}
if (clientClosed) {
return;
}
if (!completionRes.ok || !completionRes.body) {
const detail = await safeReadText(completionRes);
@@ -121,15 +157,20 @@ module.exports = async function handler(req, res) {
let currentType = thinkingEnabled ? 'thinking' : 'text';
let thinkingText = '';
let outputText = '';
const toolSieveEnabled = toolNames.length > 0;
const toolSieveEnabled = toolPolicy.toolSieveEnabled;
const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
const toolSieveState = createToolSieveState();
let toolCallsEmitted = false;
const streamToolCallIDs = new Map();
const decoder = new TextDecoder();
const reader = completionRes.body.getReader();
reader = completionRes.body.getReader();
let buffered = '';
let ended = false;
const sendFrame = (obj) => {
if (clientClosed || res.writableEnded || res.destroyed) {
return;
}
res.write(`data: ${JSON.stringify(obj)}\n\n`);
if (typeof res.flush === 'function') {
res.flush();
@@ -156,6 +197,10 @@ module.exports = async function handler(req, res) {
return;
}
ended = true;
if (clientClosed || res.writableEnded || res.destroyed) {
await releaseLease();
return;
}
const detected = parseToolCalls(outputText, toolNames);
if (detected.length > 0 && !toolCallsEmitted) {
toolCallsEmitted = true;
@@ -179,14 +224,22 @@ module.exports = async function handler(req, res) {
choices: [{ delta: {}, index: 0, finish_reason: reason }],
usage: buildUsage(finalPrompt, thinkingText, outputText),
});
res.write('data: [DONE]\n\n');
if (!res.writableEnded && !res.destroyed) {
res.write('data: [DONE]\n\n');
}
await releaseLease();
res.end();
if (!res.writableEnded && !res.destroyed) {
res.end();
}
};
try {
// eslint-disable-next-line no-constant-condition
while (true) {
if (clientClosed) {
await finish('stop');
return;
}
const { value, done } = await reader.read();
if (done) {
break;
@@ -245,6 +298,14 @@ module.exports = async function handler(req, res) {
}
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
for (const evt of events) {
if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) {
if (!emitEarlyToolDeltas) {
continue;
}
toolCallsEmitted = true;
sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) });
continue;
}
if (evt.type === 'tool_calls') {
toolCallsEmitted = true;
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) });
@@ -259,10 +320,16 @@ module.exports = async function handler(req, res) {
}
}
await finish('stop');
} catch (_err) {
} catch (err) {
if (clientClosed || isAbortError(err)) {
await finish('stop');
return;
}
await finish('stop');
}
} finally {
req.removeListener('aborted', onReqAborted);
res.removeListener('close', onResClose);
await releaseLease();
}
};
@@ -345,6 +412,37 @@ function relayPreparedFailure(res, prep) {
writeOpenAIError(res, prep.status || 500, 'vercel prepare failed');
}
/**
 * Work out how tool calls should be handled for one streamed completion.
 *
 * Tool names reported by the prepare response win; the tools declared in the
 * client payload are only consulted when the prepare response lists none.
 * Both feature flags are treated as enabled unless they are explicitly `false`.
 *
 * @param {object} prepBody - Body of the internal prepare response.
 * @param {Array} payloadTools - `tools` array from the client request.
 * @returns {{toolNames: string[], toolSieveEnabled: boolean, emitEarlyToolDeltas: boolean}}
 */
function resolveToolcallPolicy(prepBody, payloadTools) {
  let toolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names);
  if (toolNames.length === 0) {
    toolNames = extractToolNames(payloadTools);
  }
  const matchFlag = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match);
  const earlyEmitFlag = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high);
  const hasTools = toolNames.length > 0;
  return {
    toolNames,
    toolSieveEnabled: hasTools && matchFlag,
    emitEarlyToolDeltas: earlyEmitFlag,
  };
}
/**
 * Turn the `tool_names` value of a prepare response into a clean name list.
 *
 * @param {*} v - Candidate value; anything that is not a non-empty array yields `[]`.
 * @returns {string[]} Each entry passed through `asString`, with empty results dropped.
 */
function normalizePreparedToolNames(v) {
  const usable = Array.isArray(v) && v.length > 0;
  if (!usable) {
    return [];
  }
  return v
    .map((entry) => asString(entry))
    .filter((name) => Boolean(name));
}
/**
 * Interpret an optional boolean flag whose default is `true`.
 *
 * @param {*} v - Flag value.
 * @returns {boolean} `false` only when `v` is strictly `false`; `true` for everything else.
 */
function boolDefaultTrue(v) {
  if (v === false) {
    return false;
  }
  return true;
}
async function safeReadText(resp) {
if (!resp) {
return '';
@@ -656,6 +754,55 @@ function buildUsage(prompt, thinking, output) {
};
}
/**
 * Convert sieve `tool_call_deltas` into `delta.tool_calls` entries.
 *
 * Every entry carries `index`, a per-index id taken from `idStore`, and
 * `type: 'function'`. A `function` object is attached only when the delta
 * supplies a non-empty name and/or a non-empty arguments string.
 *
 * @param {Array} deltas - Delta objects with optional `index`, `name`, `arguments`.
 * @param {Map} idStore - Map from tool-call index to the id already assigned to it.
 * @returns {Array<object>} Formatted entries; `[]` for a missing or empty input.
 */
function formatIncrementalToolCallDeltas(deltas, idStore) {
  if (!Array.isArray(deltas) || deltas.length === 0) {
    return [];
  }
  const formatted = [];
  for (const delta of deltas) {
    if (!delta || typeof delta !== 'object') {
      continue;
    }
    const index = Number.isInteger(delta.index) ? delta.index : 0;
    const entry = {
      index,
      id: ensureStreamToolCallID(idStore, index),
      type: 'function',
    };
    const fnPart = {};
    const toolName = asString(delta.name);
    if (toolName) {
      fnPart.name = toolName;
    }
    const hasArgs = typeof delta.arguments === 'string' && delta.arguments !== '';
    if (hasArgs) {
      fnPart.arguments = delta.arguments;
    }
    if (Object.keys(fnPart).length > 0) {
      entry.function = fnPart;
    }
    formatted.push(entry);
  }
  return formatted;
}
/**
 * Return the id assigned to a tool-call index, creating one on first use.
 *
 * @param {Map} idStore - Map from index to id; updated when a new id is created.
 * @param {number} index - Tool-call index; non-integers are treated as 0.
 * @returns {string} The stored id, or a fresh `call_…` id that is now stored.
 */
function ensureStreamToolCallID(idStore, index) {
  const slot = Number.isInteger(index) ? index : 0;
  const known = idStore.get(slot);
  if (!known) {
    const created = `call_${newCallID()}`;
    idStore.set(slot, created);
    return created;
  }
  return known;
}
/**
 * Produce the unique suffix used for streamed tool-call ids.
 *
 * @returns {string} A UUID with the dashes removed when `crypto.randomUUID`
 *   is available; otherwise the current timestamp followed by a random integer.
 */
function newCallID() {
  const canUseUUID = typeof crypto.randomUUID === 'function';
  if (!canUseUUID) {
    return `${Date.now()}${Math.floor(Math.random() * 1e9)}`;
  }
  const uuid = crypto.randomUUID();
  return uuid.replace(/-/g, '');
}
function estimateTokens(text) {
const t = asString(text);
if (!t) {
@@ -667,44 +814,92 @@ function estimateTokens(text) {
async function proxyToGo(req, res, rawBody) {
const url = buildInternalGoURL(req);
const upstream = await fetch(url.toString(), {
method: 'POST',
headers: buildInternalGoHeaders(req, { withContentType: true }),
body: rawBody,
});
res.statusCode = upstream.status;
upstream.headers.forEach((value, key) => {
if (key.toLowerCase() === 'content-length') {
const controller = new AbortController();
let clientClosed = false;
const markClientClosed = () => {
if (clientClosed) {
return;
}
res.setHeader(key, value);
});
clientClosed = true;
controller.abort();
};
const onReqAborted = () => markClientClosed();
const onResClose = () => {
if (!res.writableEnded) {
markClientClosed();
}
};
req.on('aborted', onReqAborted);
res.on('close', onResClose);
if (!upstream.body || typeof upstream.body.getReader !== 'function') {
const bytes = Buffer.from(await upstream.arrayBuffer());
res.end(bytes);
return;
}
const reader = upstream.body.getReader();
try {
// eslint-disable-next-line no-constant-condition
while (true) {
const { value, done } = await reader.read();
if (done) {
break;
let upstream;
try {
upstream = await fetch(url.toString(), {
method: 'POST',
headers: buildInternalGoHeaders(req, { withContentType: true }),
body: rawBody,
signal: controller.signal,
});
} catch (err) {
if (clientClosed || isAbortError(err)) {
if (!res.writableEnded) {
res.end();
}
return;
}
if (value && value.length > 0) {
res.write(Buffer.from(value));
if (typeof res.flush === 'function') {
res.flush();
throw err;
}
if (clientClosed) {
if (!res.writableEnded) {
res.end();
}
return;
}
res.statusCode = upstream.status;
upstream.headers.forEach((value, key) => {
if (key.toLowerCase() === 'content-length') {
return;
}
res.setHeader(key, value);
});
if (!upstream.body || typeof upstream.body.getReader !== 'function') {
const bytes = Buffer.from(await upstream.arrayBuffer());
res.end(bytes);
return;
}
const reader = upstream.body.getReader();
try {
// eslint-disable-next-line no-constant-condition
while (true) {
if (clientClosed) {
break;
}
const { value, done } = await reader.read();
if (done) {
break;
}
if (value && value.length > 0) {
res.write(Buffer.from(value));
if (typeof res.flush === 'function') {
res.flush();
}
}
}
if (!res.writableEnded) {
res.end();
}
} catch (err) {
if (!isAbortError(err) && !res.writableEnded) {
res.end();
}
}
res.end();
} catch (_err) {
} finally {
req.removeListener('aborted', onReqAborted);
res.removeListener('close', onResClose);
if (!res.writableEnded) {
res.end();
}
@@ -762,9 +957,19 @@ function asString(v) {
return String(v).trim();
}
/**
 * Tell whether a caught value is an abort error.
 *
 * @param {*} err - Value caught from a rejected promise or a thrown exception.
 * @returns {boolean} `true` when `err` is an object whose `name` is
 *   `'AbortError'` or whose `code` is `'ABORT_ERR'`.
 */
function isAbortError(err) {
  if (err === null || typeof err !== 'object') {
    return false;
  }
  if (err.name === 'AbortError') {
    return true;
  }
  return err.code === 'ABORT_ERR';
}
module.exports.__test = {
parseChunkForContent,
extractContentRecursive,
shouldSkipPath,
asString,
resolveToolcallPolicy,
normalizePreparedToolNames,
boolDefaultTrue,
};

View File

@@ -10,10 +10,50 @@ const {
flushToolSieve,
} = require('./helpers/stream-tool-sieve');
const { parseChunkForContent } = handler.__test;
const {
parseChunkForContent,
resolveToolcallPolicy,
normalizePreparedToolNames,
boolDefaultTrue,
} = handler.__test;
test('chat-stream exposes parser test hooks', () => {
assert.equal(typeof parseChunkForContent, 'function');
assert.equal(typeof resolveToolcallPolicy, 'function');
});
test('resolveToolcallPolicy defaults to feature-match + early emit when prepare flags missing', () => {
const policy = resolveToolcallPolicy(
{},
[{ type: 'function', function: { name: 'read_file', parameters: { type: 'object' } } }],
);
assert.deepEqual(policy.toolNames, ['read_file']);
assert.equal(policy.toolSieveEnabled, true);
assert.equal(policy.emitEarlyToolDeltas, true);
});
test('resolveToolcallPolicy respects prepare flags and prepared tool names', () => {
const policy = resolveToolcallPolicy(
{
tool_names: [' prepped_tool ', '', null],
toolcall_feature_match: false,
toolcall_early_emit_high: false,
},
[{ type: 'function', function: { name: 'fallback_tool', parameters: { type: 'object' } } }],
);
assert.deepEqual(policy.toolNames, ['prepped_tool']);
assert.equal(policy.toolSieveEnabled, false);
assert.equal(policy.emitEarlyToolDeltas, false);
});
test('normalizePreparedToolNames filters empty values', () => {
assert.deepEqual(normalizePreparedToolNames([' a ', '', null, 'b']), ['a', 'b']);
});
test('boolDefaultTrue keeps false only when explicitly false', () => {
assert.equal(boolDefaultTrue(false), false);
assert.equal(boolDefaultTrue(true), true);
assert.equal(boolDefaultTrue(undefined), true);
});
test('parseChunkForContent keeps split response/content fragments inside response array', () => {
@@ -49,12 +89,13 @@ test('parseChunkForContent + sieve does not leak suspicious prefix in split tool
events.push(...flushToolSieve(state, ['read_file']));
const hasToolCalls = events.some((evt) => evt.type === 'tool_calls' && evt.calls && evt.calls.length > 0);
const hasToolDeltas = events.some((evt) => evt.type === 'tool_call_deltas' && evt.deltas && evt.deltas.length > 0);
const leakedText = events
.filter((evt) => evt.type === 'text' && evt.text)
.map((evt) => evt.text)
.join('');
assert.equal(hasToolCalls, true);
assert.equal(hasToolCalls || hasToolDeltas, true);
assert.equal(leakedText.includes('{'), false);
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
});

View File

@@ -2,6 +2,8 @@
const crypto = require('crypto');
const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024;
const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256;
function extractToolNames(tools) {
if (!Array.isArray(tools) || tools.length === 0) {
@@ -26,9 +28,25 @@ function createToolSieveState() {
pending: '',
capture: '',
capturing: false,
recentTextTail: '',
toolNameSent: false,
toolName: '',
toolArgsStart: -1,
toolArgsSent: -1,
toolArgsString: false,
toolArgsDone: false,
};
}
/**
 * Clear the per-capture bookkeeping used for incremental tool-call deltas.
 *
 * Only the `tool*` fields are reset; other fields on `state` are left alone.
 *
 * @param {object} state - Sieve state object, mutated in place.
 */
function resetIncrementalToolState(state) {
  Object.assign(state, {
    toolNameSent: false,
    toolName: '',
    toolArgsStart: -1,
    toolArgsSent: -1,
    toolArgsString: false,
    toolArgsDone: false,
  });
}
function processToolSieveChunk(state, chunk, toolNames) {
if (!state) {
return [];
@@ -44,13 +62,27 @@ function processToolSieveChunk(state, chunk, toolNames) {
state.capture += state.pending;
state.pending = '';
}
const consumed = consumeToolCapture(state.capture, toolNames);
const deltas = buildIncrementalToolDeltas(state);
if (deltas.length > 0) {
events.push({ type: 'tool_call_deltas', deltas });
}
const consumed = consumeToolCapture(state, toolNames);
if (!consumed.ready) {
if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) {
noteText(state, state.capture);
events.push({ type: 'text', text: state.capture });
state.capture = '';
state.capturing = false;
resetIncrementalToolState(state);
continue;
}
break;
}
state.capture = '';
state.capturing = false;
resetIncrementalToolState(state);
if (consumed.prefix) {
noteText(state, consumed.prefix);
events.push({ type: 'text', text: consumed.prefix });
}
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
@@ -70,11 +102,13 @@ function processToolSieveChunk(state, chunk, toolNames) {
if (start >= 0) {
const prefix = state.pending.slice(0, start);
if (prefix) {
noteText(state, prefix);
events.push({ type: 'text', text: prefix });
}
state.capture = state.pending.slice(start);
state.pending = '';
state.capturing = true;
resetIncrementalToolState(state);
continue;
}
@@ -83,6 +117,7 @@ function processToolSieveChunk(state, chunk, toolNames) {
break;
}
state.pending = hold;
noteText(state, safe);
events.push({ type: 'text', text: safe });
}
return events;
@@ -94,24 +129,29 @@ function flushToolSieve(state, toolNames) {
}
const events = processToolSieveChunk(state, '', toolNames);
if (state.capturing) {
const consumed = consumeToolCapture(state.capture, toolNames);
const consumed = consumeToolCapture(state, toolNames);
if (consumed.ready) {
if (consumed.prefix) {
noteText(state, consumed.prefix);
events.push({ type: 'text', text: consumed.prefix });
}
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
events.push({ type: 'tool_calls', calls: consumed.calls });
}
if (consumed.suffix) {
noteText(state, consumed.suffix);
events.push({ type: 'text', text: consumed.suffix });
}
} else if (state.capture) {
// Incomplete captured tool JSON at stream end: flush the raw capture as plain text.
noteText(state, state.capture);
events.push({ type: 'text', text: state.capture });
}
state.capture = '';
state.capturing = false;
resetIncrementalToolState(state);
}
if (state.pending) {
noteText(state, state.pending);
events.push({ type: 'text', text: state.pending });
state.pending = '';
}
@@ -151,15 +191,25 @@ function findToolSegmentStart(s) {
return -1;
}
const lower = s.toLowerCase();
const keyIdx = lower.indexOf('tool_calls');
if (keyIdx < 0) {
return -1;
let offset = 0;
// eslint-disable-next-line no-constant-condition
while (true) {
const keyRel = lower.indexOf('tool_calls', offset);
if (keyRel < 0) {
return -1;
}
const keyIdx = keyRel;
const start = s.slice(0, keyIdx).lastIndexOf('{');
const candidateStart = start >= 0 ? start : keyIdx;
if (!insideCodeFence(s.slice(0, candidateStart))) {
return candidateStart;
}
offset = keyIdx + 'tool_calls'.length;
}
const start = s.slice(0, keyIdx).lastIndexOf('{');
return start >= 0 ? start : keyIdx;
}
function consumeToolCapture(captured, toolNames) {
function consumeToolCapture(state, toolNames) {
const captured = state.capture;
if (!captured) {
return { ready: false, prefix: '', calls: [], suffix: '' };
}
@@ -176,25 +226,367 @@ function consumeToolCapture(captured, toolNames) {
if (!obj.ok) {
return { ready: false, prefix: '', calls: [], suffix: '' };
}
const parsed = parseToolCalls(captured.slice(start, obj.end), toolNames);
if (parsed.length === 0) {
// `tool_calls` key exists but strict JSON parse failed.
// Drop the captured object body to avoid leaking raw tool JSON.
const prefixPart = captured.slice(0, start);
const suffixPart = captured.slice(obj.end);
if (insideCodeFence((state.recentTextTail || '') + prefixPart)) {
return {
ready: true,
prefix: captured.slice(0, start),
prefix: captured,
calls: [],
suffix: captured.slice(obj.end),
suffix: '',
};
}
const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames);
if (parsed.length === 0) {
if (state.toolNameSent) {
return {
ready: true,
prefix: prefixPart,
calls: [],
suffix: suffixPart,
};
}
return {
ready: true,
prefix: captured,
calls: [],
suffix: '',
};
}
if (state.toolNameSent) {
if (parsed.length > 1) {
return {
ready: true,
prefix: prefixPart,
calls: parsed.slice(1),
suffix: suffixPart,
};
}
return {
ready: true,
prefix: prefixPart,
calls: [],
suffix: suffixPart,
};
}
return {
ready: true,
prefix: captured.slice(0, start),
prefix: prefixPart,
calls: parsed,
suffix: captured.slice(obj.end),
suffix: suffixPart,
};
}
/**
 * Derive incremental tool-call deltas from the text captured so far.
 *
 * Inspects `state.capture` for a `tool_calls` payload and, for the first call
 * object in the array only (every delta uses `index: 0`), produces a name
 * delta followed by argument deltas covering the text that has arrived since
 * the previous invocation. Progress is remembered on `state` via the
 * `toolName*` / `toolArgs*` fields, so the function is meant to be called
 * repeatedly as the capture grows.
 *
 * @param {object} state - Sieve state; reads `capture` and `recentTextTail`,
 *   reads and writes `toolName`, `toolNameSent`, `toolArgsStart`,
 *   `toolArgsSent`, `toolArgsString`, `toolArgsDone`.
 * @returns {Array<object>} Deltas shaped `{ index: 0, name }` or
 *   `{ index: 0, arguments }`; empty when nothing new can be emitted yet.
 */
function buildIncrementalToolDeltas(state) {
const captured = state.capture || '';
if (!captured) {
return [];
}
// Text already emitted ends inside an open ``` fence: treat the capture as an example.
if (looksLikeToolExampleContext(state.recentTextTail)) {
return [];
}
const lower = captured.toLowerCase();
const keyIdx = lower.indexOf('tool_calls');
if (keyIdx < 0) {
return [];
}
// The payload object is taken to start at the last `{` before the key.
const start = captured.slice(0, keyIdx).lastIndexOf('{');
if (start < 0) {
return [];
}
// Re-check fencing including any captured text that precedes the object.
if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) {
return [];
}
const callStart = findFirstToolCallObjectStart(captured, keyIdx);
if (callStart < 0) {
return [];
}
const deltas = [];
// Resolve the tool name once and cache it on the state.
if (!state.toolName) {
const name = extractToolCallName(captured, callStart);
if (!name) {
return [];
}
state.toolName = name;
}
// Locate where the arguments value begins; for a string value skip the opening quote.
if (state.toolArgsStart < 0) {
const args = findToolCallArgsStart(captured, callStart);
if (args) {
state.toolArgsString = Boolean(args.stringMode);
state.toolArgsStart = state.toolArgsString ? args.start + 1 : args.start;
state.toolArgsSent = state.toolArgsStart;
}
}
// The name delta is held back until the start of the arguments has been seen.
if (!state.toolNameSent) {
if (state.toolArgsStart < 0) {
return [];
}
state.toolNameSent = true;
deltas.push({ index: 0, name: state.toolName });
}
if (state.toolArgsStart < 0 || state.toolArgsDone) {
return deltas;
}
const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString);
if (!progress) {
return deltas;
}
// Emit only the raw slice of captured text that has not been sent yet.
if (progress.end > state.toolArgsSent) {
deltas.push({
index: 0,
arguments: captured.slice(state.toolArgsSent, progress.end),
});
state.toolArgsSent = progress.end;
}
if (progress.complete) {
state.toolArgsDone = true;
}
return deltas;
}
/**
 * Find the `{` that opens the first element of the `tool_calls` array.
 *
 * @param {string} text - Captured text.
 * @param {number} keyIdx - Index of the `tool_calls` key inside `text`.
 * @returns {number} Index of the opening brace, or -1 when the array has not
 *   started yet or its first non-space element is not an object.
 */
function findFirstToolCallObjectStart(text, keyIdx) {
  const bracketAt = findToolCallsArrayStart(text, keyIdx);
  if (bracketAt < 0) {
    return -1;
  }
  const firstItemAt = skipSpaces(text, bracketAt + 1);
  const opensObject = firstItemAt < text.length && text[firstItemAt] === '{';
  return opensObject ? firstItemAt : -1;
}
/**
 * Find the `[` that opens the value of the `tool_calls` key.
 *
 * Looks for the first `:` after the key, skips whitespace, and expects `[`.
 *
 * @param {string} text - Captured text.
 * @param {number} keyIdx - Index of the `tool_calls` key inside `text`.
 * @returns {number} Index of the `[`, or -1 when no colon follows the key or
 *   the value does not start with `[`.
 */
function findToolCallsArrayStart(text, keyIdx) {
  const colonAt = text.indexOf(':', keyIdx + 'tool_calls'.length);
  if (colonAt < 0) {
    return -1;
  }
  const valueAt = skipSpaces(text, colonAt + 1);
  if (valueAt < text.length && text[valueAt] === '[') {
    return valueAt;
  }
  return -1;
}
/**
 * Read the tool name of a tool-call object.
 *
 * Tries a top-level `name` field first; when that is absent or not a
 * double-quoted string, falls back to `function.name`.
 *
 * @param {string} text - Captured text.
 * @param {number} callStart - Index of the `{` opening the tool-call object.
 * @returns {string} The name, or `''` when it is missing or not fully received.
 */
function extractToolCallName(text, callStart) {
  const pointsAtString = (pos) => pos >= 0 && text[pos] === '"';

  let nameAt = findObjectFieldValueStart(text, callStart, ['name']);
  if (!pointsAtString(nameAt)) {
    const fnAt = findFunctionObjectStart(text, callStart);
    if (fnAt < 0) {
      return '';
    }
    nameAt = findObjectFieldValueStart(text, fnAt, ['name']);
    if (!pointsAtString(nameAt)) {
      return '';
    }
  }
  const literal = parseJSONStringLiteral(text, nameAt);
  return literal ? literal.value : '';
}
/**
 * Locate the start of a tool call's arguments value.
 *
 * Accepts any of the keys `input`, `arguments`, `args`, `parameters`,
 * `params`, first on the call object itself and then inside its `function`
 * object.
 *
 * @param {string} text - Captured text.
 * @param {number} callStart - Index of the `{` opening the tool-call object.
 * @returns {{start: number, stringMode: boolean}|null} Position of the value's
 *   first character; `stringMode` is `true` when the value is a double-quoted
 *   string and `false` when it is an object or array. `null` when not found
 *   or when the value starts with anything else.
 */
function findToolCallArgsStart(text, callStart) {
  const argKeys = ['input', 'arguments', 'args', 'parameters', 'params'];
  let valueAt = findObjectFieldValueStart(text, callStart, argKeys);
  if (valueAt < 0) {
    const fnAt = findFunctionObjectStart(text, callStart);
    valueAt = fnAt < 0 ? -1 : findObjectFieldValueStart(text, fnAt, argKeys);
    if (valueAt < 0) {
      return null;
    }
  }
  if (valueAt >= text.length) {
    return null;
  }
  switch (text[valueAt]) {
    case '{':
    case '[':
      return { start: valueAt, stringMode: false };
    case '"':
      return { start: valueAt, stringMode: true };
    default:
      return null;
  }
}
/**
 * Measure how much of a tool call's arguments value is present in `text`.
 *
 * @param {string} text - Captured text.
 * @param {number} start - In string mode, the index just after the opening
 *   quote; otherwise the index of the opening `{` or `[`.
 * @param {boolean} stringMode - Whether the arguments value is a JSON string.
 * @returns {{end: number, complete: boolean}|null} `end` is the exclusive end
 *   of the argument text seen so far (in string mode it stops before the
 *   closing quote; otherwise it includes the closing bracket). `complete`
 *   reports whether the value is closed. `null` when `start` is out of range
 *   or, outside string mode, does not point at `{` or `[`.
 */
function scanToolCallArgsProgress(text, start, stringMode) {
  if (start < 0 || start > text.length) {
    return null;
  }
  const unfinished = { end: text.length, complete: false };
  let pos = start;

  if (stringMode) {
    let skipNext = false;
    while (pos < text.length) {
      const c = text[pos];
      if (skipNext) {
        skipNext = false;
      } else if (c === '\\') {
        skipNext = true;
      } else if (c === '"') {
        return { end: pos, complete: true };
      }
      pos += 1;
    }
    return unfinished;
  }

  const opensContainer = start < text.length && (text[start] === '{' || text[start] === '[');
  if (!opensContainer) {
    return null;
  }

  let nesting = 0;
  let openQuote = '';
  let skipNext = false;
  while (pos < text.length) {
    const c = text[pos];
    if (openQuote) {
      // Inside a quoted string: only escapes and the matching quote matter.
      if (skipNext) {
        skipNext = false;
      } else if (c === '\\') {
        skipNext = true;
      } else if (c === openQuote) {
        openQuote = '';
      }
    } else if (c === '"' || c === "'") {
      openQuote = c;
    } else if (c === '{' || c === '[') {
      nesting += 1;
    } else if (c === '}' || c === ']') {
      nesting -= 1;
      if (nesting === 0) {
        return { end: pos + 1, complete: true };
      }
    }
    pos += 1;
  }
  return unfinished;
}
/**
 * Find where the value of one of `keys` begins inside a (possibly still
 * incomplete) JSON object.
 *
 * Only string literals that sit directly inside the object opening at
 * `objStart` (brace depth 1) and are followed by `:` are considered keys.
 *
 * @param {string} text - Text containing the object.
 * @param {number} objStart - Index of the object's opening `{`.
 * @param {string[]} keys - Accepted field names; the first one encountered wins.
 * @returns {number} Index of the first non-space character of the matched
 *   field's value, or -1 when no key matches, a string at depth 1 cannot be
 *   parsed by `parseJSONStringLiteral`, or the value has not arrived yet.
 */
function findObjectFieldValueStart(text, objStart, keys) {
if (!text || objStart < 0 || objStart >= text.length || text[objStart] !== '{') {
return -1;
}
let depth = 0;
let quote = '';
let escaped = false;
for (let i = objStart; i < text.length; i += 1) {
const ch = text[i];
// Inside a nested quoted string: skip until the matching unescaped quote.
if (quote) {
if (escaped) {
escaped = false;
continue;
}
if (ch === '\\') {
escaped = true;
continue;
}
if (ch === quote) {
quote = '';
}
continue;
}
if (ch === '"' || ch === "'") {
if (depth === 1) {
// parseJSONStringLiteral returns null for unterminated or non-double-quoted strings.
const parsed = parseJSONStringLiteral(text, i);
if (!parsed) {
return -1;
}
let j = skipSpaces(text, parsed.end);
// Not followed by a colon: this string is not a key, resume right after it.
if (j >= text.length || text[j] !== ':') {
i = parsed.end - 1;
continue;
}
j = skipSpaces(text, j + 1);
// Key seen but nothing after the colon yet.
if (j >= text.length) {
return -1;
}
if (keys.includes(parsed.value)) {
return j;
}
// Different key: continue scanning from the start of its value.
i = j - 1;
continue;
}
quote = ch;
continue;
}
if (ch === '{') {
depth += 1;
continue;
}
if (ch === '}') {
depth -= 1;
// Closing brace of the object itself: stop searching.
if (depth === 0) {
break;
}
}
}
return -1;
}
/**
 * Locate the object stored under the `function` field of a tool call.
 *
 * @param {string} text - Captured text.
 * @param {number} callStart - Index of the `{` opening the tool-call object.
 * @returns {number} Index of the `{` opening the `function` value, or -1 when
 *   the field is missing or its value is not an object.
 */
function findFunctionObjectStart(text, callStart) {
  const fnValueAt = findObjectFieldValueStart(text, callStart, ['function']);
  const isObject = fnValueAt >= 0 && fnValueAt < text.length && text[fnValueAt] === '{';
  return isObject ? fnValueAt : -1;
}
// Single-character JSON escapes whose decoded value differs from the escaped letter.
const JSON_SIMPLE_ESCAPES = Object.freeze({
  b: '\b',
  f: '\f',
  n: '\n',
  r: '\r',
  t: '\t',
});

/**
 * Parse a double-quoted JSON string literal starting at `text[start]`.
 *
 * Escape sequences are decoded: `\b \f \n \r \t` map to their control
 * characters, `\uXXXX` (four hex digits) maps to the corresponding UTF-16
 * code unit, and any other escaped character (`\"`, `\\`, `\/`, or an unknown
 * escape) yields that character itself.
 *
 * @param {string} text - Text containing the literal.
 * @param {number} start - Index of the opening `"`.
 * @returns {{value: string, end: number}|null} Decoded value and the index
 *   just past the closing quote; `null` when `start` does not point at `"`
 *   or the literal is not terminated within `text`.
 */
function parseJSONStringLiteral(text, start) {
  if (!text || start < 0 || start >= text.length || text[start] !== '"') {
    return null;
  }
  let out = '';
  let i = start + 1;
  while (i < text.length) {
    const ch = text[i];
    if (ch === '"') {
      return { value: out, end: i + 1 };
    }
    if (ch !== '\\') {
      out += ch;
      i += 1;
      continue;
    }
    const next = text[i + 1];
    if (next === undefined) {
      // Dangling backslash at the end of the text: literal is still incomplete.
      return null;
    }
    if (next === 'u') {
      const hex = text.slice(i + 2, i + 6);
      if (/^[0-9a-fA-F]{4}$/.test(hex)) {
        out += String.fromCharCode(Number.parseInt(hex, 16));
        i += 6;
        continue;
      }
      // Malformed \u escape: fall through and keep the letter, as a lenient parser would.
    }
    out += Object.prototype.hasOwnProperty.call(JSON_SIMPLE_ESCAPES, next)
      ? JSON_SIMPLE_ESCAPES[next]
      : next;
    i += 2;
  }
  return null;
}
/**
 * Advance past JSON whitespace (space, tab, newline, carriage return).
 *
 * @param {string} text - Text to scan.
 * @param {number} i - Index to start from.
 * @returns {number} Index of the first non-whitespace character at or after
 *   `i`, or `text.length` when only whitespace remains; `i` itself when it is
 *   already past the end.
 */
function skipSpaces(text, i) {
  const blanks = ' \t\n\r';
  let pos = i;
  while (pos < text.length && blanks.includes(text[pos])) {
    pos += 1;
  }
  return pos;
}
function extractJSONObjectFrom(text, start) {
if (!text || start < 0 || start >= text.length || text[start] !== '{') {
return { ok: false, end: 0 };
@@ -240,7 +632,11 @@ function parseToolCalls(text, toolNames) {
if (!toStringSafe(text)) {
return [];
}
const candidates = buildToolCallCandidates(text);
const sanitized = stripFencedCodeBlocks(text);
if (!toStringSafe(sanitized)) {
return [];
}
const candidates = buildToolCallCandidates(sanitized);
let parsed = [];
for (const c of candidates) {
parsed = parseToolCallsPayload(c);
@@ -251,26 +647,49 @@ function parseToolCalls(text, toolNames) {
if (parsed.length === 0) {
return [];
}
const allowed = new Set((toolNames || []).filter(Boolean));
const out = [];
for (const tc of parsed) {
if (!tc || !tc.name) {
continue;
}
if (allowed.size > 0 && !allowed.has(tc.name)) {
continue;
}
out.push({ name: tc.name, input: tc.input || {} });
return filterToolCalls(parsed, toolNames);
}
function stripFencedCodeBlocks(text) {
const t = typeof text === 'string' ? text : '';
if (!t) {
return '';
}
if (out.length === 0 && parsed.length > 0) {
for (const tc of parsed) {
if (!tc || !tc.name) {
continue;
}
out.push({ name: tc.name, input: tc.input || {} });
return t.replace(/```[\s\S]*?```/g, ' ');
}
function parseStandaloneToolCalls(text, toolNames) {
const trimmed = toStringSafe(text);
if (!trimmed) {
return [];
}
if ((trimmed.startsWith('```') && trimmed.endsWith('```')) || trimmed.includes('```')) {
return [];
}
if (looksLikeToolExampleContext(trimmed)) {
return [];
}
const candidates = [trimmed];
if (trimmed.startsWith('```') && trimmed.endsWith('```')) {
const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
if (m && m[1]) {
candidates.push(toStringSafe(m[1]));
}
}
return out;
for (const candidate of candidates) {
const c = toStringSafe(candidate);
if (!c) {
continue;
}
if (!c.startsWith('{') && !c.startsWith('[')) {
continue;
}
const parsed = parseToolCallsPayload(c);
if (parsed.length > 0) {
return filterToolCalls(parsed, toolNames);
}
}
return [];
}
function buildToolCallCandidates(text) {
@@ -432,6 +851,66 @@ function parseToolCallInput(v) {
return {};
}
/**
 * Normalise parsed tool calls and restrict them to the declared tool names.
 *
 * Entries without a name are dropped and a missing `input` becomes `{}`.
 * When a non-empty allow-list matches none of the named calls, all named
 * calls are returned instead of an empty list.
 *
 * @param {Array} parsed - Parsed tool-call objects (`{ name, input }`).
 * @param {string[]} [toolNames] - Allowed tool names; empty or missing allows all.
 * @returns {Array<{name: string, input: *}>} New objects, in the original order.
 */
function filterToolCalls(parsed, toolNames) {
  const allowList = new Set((toolNames || []).filter(Boolean));
  const named = [];
  for (const call of parsed) {
    if (call && call.name) {
      named.push({ name: call.name, input: call.input || {} });
    }
  }
  if (allowList.size === 0) {
    return named;
  }
  const permitted = named.filter((call) => allowList.has(call.name));
  return permitted.length > 0 ? permitted : named;
}
/**
 * Record emitted text in the state's bounded `recentTextTail`.
 *
 * Does nothing when `state` is missing or `text` has no meaningful content.
 *
 * @param {object} state - Sieve state; `recentTextTail` is updated in place.
 * @param {string} text - Text that was just emitted.
 */
function noteText(state, text) {
  if (state && hasMeaningfulText(text)) {
    const updatedTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT);
    state.recentTextTail = updatedTail;
  }
}
/**
 * Concatenate two strings and keep at most the last `max` characters.
 *
 * @param {*} prev - Existing tail; non-strings count as `''`.
 * @param {*} next - Text to append; non-strings count as `''`.
 * @param {number} max - Maximum length to keep; a non-finite or non-positive
 *   value yields `''`.
 * @returns {string} The trailing part of `prev + next`, no longer than `max`.
 */
function appendTail(prev, next, max) {
  if (!Number.isFinite(max) || max <= 0) {
    return '';
  }
  const head = typeof prev === 'string' ? prev : '';
  const addition = typeof next === 'string' ? next : '';
  const joined = head + addition;
  const overflow = joined.length > max;
  return overflow ? joined.slice(joined.length - max) : joined;
}
/**
 * Decide whether preceding text suggests a following tool payload is only an
 * example. The only signal used is an unclosed ``` code fence.
 *
 * @param {string} text - Recently emitted text.
 * @returns {boolean} Result of `insideCodeFence(text)`.
 */
function looksLikeToolExampleContext(text) {
  const fenceIsOpen = insideCodeFence(text);
  return fenceIsOpen;
}
/**
 * Report whether `text` ends inside an open ``` code fence.
 *
 * @param {*} text - Text to inspect; non-strings and `''` give `false`.
 * @returns {boolean} `true` when the number of ``` markers in `text` is odd.
 */
function insideCodeFence(text) {
  if (typeof text !== 'string' || text === '') {
    return false;
  }
  // Splitting on the marker yields one more piece than there are markers.
  const fenceCount = text.split('```').length - 1;
  return fenceCount % 2 === 1;
}
/**
 * Check whether `text` still has content after `toStringSafe` normalisation.
 *
 * @param {*} text - Candidate text.
 * @returns {boolean} `true` unless `toStringSafe(text)` is the empty string.
 */
function hasMeaningfulText(text) {
  const normalized = toStringSafe(text);
  return normalized !== '';
}
function formatOpenAIStreamToolCalls(calls) {
if (!Array.isArray(calls) || calls.length === 0) {
return [];
@@ -473,5 +952,6 @@ module.exports = {
processToolSieveChunk,
flushToolSieve,
parseToolCalls,
parseStandaloneToolCalls,
formatOpenAIStreamToolCalls,
};

View File

@@ -9,6 +9,7 @@ const {
processToolSieveChunk,
flushToolSieve,
parseToolCalls,
parseStandaloneToolCalls,
} = require('./stream-tool-sieve');
function runSieve(chunks, toolNames) {
@@ -68,9 +69,22 @@ test('parseToolCalls supports fenced json and function.arguments string payload'
'```',
].join('\n');
const calls = parseToolCalls(text, ['read_file']);
assert.equal(calls.length, 1);
assert.equal(calls[0].name, 'read_file');
assert.deepEqual(calls[0].input, { path: 'README.md' });
assert.equal(calls.length, 0);
});
test('parseStandaloneToolCalls only matches standalone payload and ignores mixed prose', () => {
const mixed = '这里是示例:{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]},请勿执行。';
const standalone = '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}';
const mixedCalls = parseStandaloneToolCalls(mixed, ['read_file']);
const standaloneCalls = parseStandaloneToolCalls(standalone, ['read_file']);
assert.equal(mixedCalls.length, 0);
assert.equal(standaloneCalls.length, 1);
});
test('parseStandaloneToolCalls ignores fenced code block tool_call examples', () => {
const fenced = ['```json', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}', '```'].join('\n');
const calls = parseStandaloneToolCalls(fenced, ['read_file']);
assert.equal(calls.length, 0);
});
test('sieve emits tool_calls and does not leak suspicious prefix on late key convergence', () => {
@@ -84,13 +98,14 @@ test('sieve emits tool_calls and does not leak suspicious prefix on late key con
);
const leakedText = collectText(events);
const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0);
assert.equal(hasToolCall, true);
const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0);
assert.equal(hasToolCall || hasToolDelta, true);
assert.equal(leakedText.includes('{'), false);
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
assert.equal(leakedText.includes('后置正文C。'), true);
});
test('sieve drops invalid tool json body while preserving surrounding text', () => {
test('sieve keeps embedded invalid tool-like json as normal text to avoid stream stalls', () => {
const events = runSieve(
[
'前置正文D。',
@@ -104,18 +119,18 @@ test('sieve drops invalid tool json body while preserving surrounding text', ()
assert.equal(hasToolCall, false);
assert.equal(leakedText.includes('前置正文D。'), true);
assert.equal(leakedText.includes('后置正文E。'), true);
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
assert.equal(leakedText.toLowerCase().includes('tool_calls'), true);
});
test('sieve suppresses incomplete captured tool json on stream finalize', () => {
test('sieve flushes incomplete captured tool json as text on stream finalize', () => {
const events = runSieve(
['前置正文F。', '{"tool_calls":[{"name":"read_file"'],
['read_file'],
);
const leakedText = collectText(events);
assert.equal(leakedText.includes('前置正文F。'), true);
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
assert.equal(leakedText.includes('{'), false);
assert.equal(leakedText.toLowerCase().includes('tool_calls'), true);
assert.equal(leakedText.includes('{'), true);
});
test('sieve keeps plain text intact in tool mode when no tool call appears', () => {
@@ -128,3 +143,53 @@ test('sieve keeps plain text intact in tool mode when no tool call appears', ()
assert.equal(hasToolCall, false);
assert.equal(leakedText, '你好,这是普通文本回复。请继续。');
});
// The JSON payload is split in the middle of a string value ("READ" + "ME.MD"),
// so the arguments can only be recovered by joining deltas across chunks.
test('sieve emits incremental tool_call_deltas for split arguments payload', () => {
  const sieve = createToolSieveState();
  const headEvents = processToolSieveChunk(
    sieve,
    '{"tool_calls":[{"name":"read_file","input":{"path":"READ',
    ['read_file'],
  );
  const restEvents = processToolSieveChunk(
    sieve,
    'ME.MD","mode":"head"}}]}',
    ['read_file'],
  );
  const flushed = flushToolSieve(sieve, ['read_file']);

  const onlyDeltaEvents = [...headEvents, ...restEvents, ...flushed].filter(
    (evt) => evt.type === 'tool_call_deltas',
  );
  assert.equal(onlyDeltaEvents.length > 0, true);

  // Walk every delta once, tracking whether the tool name showed up and
  // concatenating the argument fragments in arrival order.
  const allDeltas = onlyDeltaEvents.flatMap((evt) => evt.deltas || []);
  let sawToolName = false;
  let joinedArguments = '';
  for (const delta of allDeltas) {
    if (delta.name === 'read_file') {
      sawToolName = true;
    }
    joinedArguments += delta.arguments || '';
  }

  assert.equal(sawToolName, true);
  assert.equal(joinedArguments.includes('"path":"README.MD"'), true);
  assert.equal(joinedArguments.includes('"mode":"head"'), true);
});
// Plain text arrives first, then a complete tool call with nothing after it:
// the text must pass through and the tool JSON must not leak into it.
test('sieve still intercepts tool call after leading plain text without suffix', () => {
  const carriesToolPayload = (evt) => {
    if (evt.type === 'tool_calls') {
      return evt.calls?.length > 0;
    }
    if (evt.type === 'tool_call_deltas') {
      return evt.deltas?.length > 0;
    }
    return false;
  };

  const events = runSieve(
    ['我将调用工具。', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'],
    ['read_file'],
  );
  const sawTool = events.some(carriesToolPayload);
  const visibleText = collectText(events);

  assert.equal(sawTool, true);
  assert.equal(visibleText.includes('我将调用工具。'), true);
  assert.equal(visibleText.toLowerCase().includes('tool_calls'), false);
});
// The tool call and the follow-up prose share a single chunk: the call must be
// intercepted while the trailing prose is still emitted as text.
test('sieve intercepts tool call and preserves trailing same-chunk text', () => {
  const carriesToolPayload = (evt) => {
    if (evt.type === 'tool_calls') {
      return evt.calls?.length > 0;
    }
    if (evt.type === 'tool_call_deltas') {
      return evt.deltas?.length > 0;
    }
    return false;
  };

  const singleChunk = '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}然后继续解释。';
  const events = runSieve([singleChunk], ['read_file']);
  const sawTool = events.some(carriesToolPayload);
  const visibleText = collectText(events);

  assert.equal(sawTool, true);
  assert.equal(visibleText.includes('然后继续解释。'), true);
  assert.equal(visibleText.toLowerCase().includes('tool_calls'), false);
});

View File

@@ -24,5 +24,27 @@
"password": "your-password-3",
"token": ""
}
]
}
],
"model_aliases": {
"gpt-4o": "deepseek-chat",
"gpt-5-codex": "deepseek-reasoner",
"o3": "deepseek-reasoner"
},
"compat": {
"wide_input_strict_output": true
},
"toolcall": {
"mode": "feature_match",
"early_emit_confidence": "high"
},
"responses": {
"store_ttl_seconds": 900
},
"embeddings": {
"provider": "deterministic"
},
"claude_model_mapping": {
"fast": "deepseek-chat",
"slow": "deepseek-reasoner"
}
}

View File

@@ -0,0 +1,249 @@
package account
import (
"context"
"sync"
"testing"
"time"
"ds2api/internal/config"
)
// ─── Pool edge cases ─────────────────────────────────────────────────
// TestPoolEmptyNoAccounts builds a pool from a config whose account list is
// empty and checks that Acquire fails and Status reports a total of zero.
func TestPoolEmptyNoAccounts(t *testing.T) {
	env := [][2]string{
		{"DS2API_ACCOUNT_MAX_INFLIGHT", "2"},
		{"DS2API_ACCOUNT_CONCURRENCY", ""},
		{"DS2API_ACCOUNT_MAX_QUEUE", ""},
		{"DS2API_ACCOUNT_QUEUE_SIZE", ""},
		{"DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`},
	}
	for _, kv := range env {
		t.Setenv(kv[0], kv[1])
	}

	pool := NewPool(config.LoadStore())
	_, acquired := pool.Acquire("", nil)
	if acquired {
		t.Fatal("expected acquire to fail with no accounts")
	}

	status := pool.Status()
	total, isInt := status["total"].(int)
	if !isInt || total != 0 {
		t.Fatalf("unexpected total: %#v", status["total"])
	}
}
// TestPoolReleaseNonExistentAccount releases an identifier that was never
// acquired; the test passes as long as Release does not panic.
func TestPoolReleaseNonExistentAccount(t *testing.T) {
	const unknownID = "nonexistent@example.com"
	newPoolForTest(t, "2").Release(unknownID)
}
// TestPoolReleaseAlreadyReleased acquires one account and releases it twice;
// the second (redundant) release must not panic.
func TestPoolReleaseAlreadyReleased(t *testing.T) {
	pool := newPoolForTest(t, "2")
	held, acquired := pool.Acquire("", nil)
	if !acquired {
		t.Fatal("expected acquire success")
	}
	for attempt := 0; attempt < 2; attempt++ {
		pool.Release(held.Identifier())
	}
}
// TestPoolAcquireTargetNotFound requests a specific account that the pool does
// not contain and expects Acquire to report failure.
func TestPoolAcquireTargetNotFound(t *testing.T) {
	pool := newPoolForTest(t, "2")
	_, acquired := pool.Acquire("nonexistent@example.com", nil)
	if acquired {
		t.Fatal("expected acquire to fail for non-existent target")
	}
}
// TestPoolAcquireWithExclusionList excludes acc1 and expects the untargeted
// Acquire to hand out acc2 instead.
// NOTE(review): presumably newPoolForTest seeds exactly acc1 and acc2 — confirm against the helper.
func TestPoolAcquireWithExclusionList(t *testing.T) {
	excluded := map[string]bool{"acc1@example.com": true}
	pool := newPoolForTest(t, "2")

	picked, acquired := pool.Acquire("", excluded)
	if !acquired {
		t.Fatal("expected acquire success with exclusion")
	}
	if id := picked.Identifier(); id != "acc2@example.com" {
		t.Fatalf("expected acc2 when acc1 excluded, got %q", id)
	}
	pool.Release(picked.Identifier())
}
// TestPoolAcquireAllExcluded excludes both acc1 and acc2 and expects Acquire
// to fail because no candidate remains.
func TestPoolAcquireAllExcluded(t *testing.T) {
	everyone := map[string]bool{
		"acc1@example.com": true,
		"acc2@example.com": true,
	}
	pool := newPoolForTest(t, "2")
	_, acquired := pool.Acquire("", everyone)
	if acquired {
		t.Fatal("expected acquire to fail when all accounts excluded")
	}
}
// TestPoolStatusFields asserts that Status exposes every key listed below;
// only presence is checked, not the values.
func TestPoolStatusFields(t *testing.T) {
	required := [...]string{
		"total",
		"available",
		"max_inflight_per_account",
		"recommended_concurrency",
		"available_accounts",
		"in_use_accounts",
		"waiting",
		"max_queue_size",
	}
	status := newPoolForTest(t, "2").Status()
	for i := range required {
		if _, present := status[required[i]]; present {
			continue
		}
		t.Fatalf("missing status field: %s", required[i])
	}
}
// TestPoolStatusAccountDetails acquires acc1 and checks that Status lists it
// under "in_use_accounts" and reports an "in_use" count of 1.
func TestPoolStatusAccountDetails(t *testing.T) {
	pool := newPoolForTest(t, "2")
	acc, ok := pool.Acquire("acc1@example.com", nil)
	if !ok {
		// Fail here with a clear message instead of asserting on the status of
		// an account that was never handed out.
		t.Fatal("expected acquire acc1 success")
	}
	// Deferred so the slot is returned even when an assertion below aborts the test.
	defer pool.Release(acc.Identifier())

	status := pool.Status()
	inUseAccounts, ok := status["in_use_accounts"].([]string)
	if !ok {
		t.Fatalf("unexpected in_use_accounts type: %T", status["in_use_accounts"])
	}
	found := false
	for _, id := range inUseAccounts {
		if id == "acc1@example.com" {
			found = true
			break
		}
	}
	if !found {
		t.Fatalf("expected acc1 in in_use_accounts, got %v", inUseAccounts)
	}
	// The comparison only holds when the stored value is an int equal to 1.
	if status["in_use"] != 1 {
		t.Fatalf("expected 1 in_use, got %v", status["in_use"])
	}
}
// TestPoolAcquireWaitContextCancelled occupies the only slot, parks a second
// caller in AcquireWait, cancels that caller's context and expects the wait to
// end without an account.
func TestPoolAcquireWaitContextCancelled(t *testing.T) {
	pool := newSingleAccountPoolForTest(t, "1")

	// Take the single available slot so the next caller has to wait.
	holder, acquired := pool.Acquire("", nil)
	if !acquired {
		t.Fatal("expected first acquire to succeed")
	}

	ctx, cancel := context.WithCancel(context.Background())
	var (
		wg        sync.WaitGroup
		gotOnWait bool
	)
	wg.Add(1)
	go func() {
		defer wg.Done()
		_, gotOnWait = pool.AcquireWait(ctx, "", nil)
	}()

	// Only cancel once the waiter is actually registered, then let it unwind.
	waitForWaitingCount(t, pool, 1)
	cancel()
	wg.Wait() // also orders the goroutine's write to gotOnWait before the read below

	if gotOnWait {
		t.Fatal("expected acquire to fail after context cancellation")
	}
	pool.Release(holder.Identifier())
}
// TestPoolAcquireWaitTargetAccount holds acc1 and then asks AcquireWait for
// acc2 by name; since acc2 is free the call should return it directly.
func TestPoolAcquireWaitTargetAccount(t *testing.T) {
	pool := newPoolForTest(t, "1")

	busy, acquired := pool.Acquire("acc1@example.com", nil)
	if !acquired {
		t.Fatal("expected acquire acc1 success")
	}

	free, acquired := pool.AcquireWait(context.Background(), "acc2@example.com", nil)
	if !acquired {
		t.Fatal("expected acquire acc2 success via AcquireWait")
	}
	if id := free.Identifier(); id != "acc2@example.com" {
		t.Fatalf("expected acc2, got %q", id)
	}

	pool.Release(busy.Identifier())
	pool.Release(free.Identifier())
}
// TestPoolMaxQueueSizeOverride sets DS2API_ACCOUNT_MAX_QUEUE=5 and expects
// Status to report max_queue_size as the int 5.
func TestPoolMaxQueueSizeOverride(t *testing.T) {
	env := [][2]string{
		{"DS2API_ACCOUNT_MAX_INFLIGHT", "1"},
		{"DS2API_ACCOUNT_CONCURRENCY", ""},
		{"DS2API_ACCOUNT_MAX_QUEUE", "5"},
		{"DS2API_ACCOUNT_QUEUE_SIZE", ""},
		{"DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`},
	}
	for _, kv := range env {
		t.Setenv(kv[0], kv[1])
	}

	status := NewPool(config.LoadStore()).Status()
	size, isInt := status["max_queue_size"].(int)
	if !isInt || size != 5 {
		t.Fatalf("expected max_queue_size=5, got %#v", status["max_queue_size"])
	}
}
// TestPoolQueueSizeAliasEnv leaves DS2API_ACCOUNT_MAX_QUEUE empty, sets
// DS2API_ACCOUNT_QUEUE_SIZE=7 and expects max_queue_size to be the int 7.
func TestPoolQueueSizeAliasEnv(t *testing.T) {
	env := [][2]string{
		{"DS2API_ACCOUNT_MAX_INFLIGHT", "1"},
		{"DS2API_ACCOUNT_CONCURRENCY", ""},
		{"DS2API_ACCOUNT_MAX_QUEUE", ""},
		{"DS2API_ACCOUNT_QUEUE_SIZE", "7"},
		{"DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`},
	}
	for _, kv := range env {
		t.Setenv(kv[0], kv[1])
	}

	status := NewPool(config.LoadStore()).Status()
	size, isInt := status["max_queue_size"].(int)
	if !isInt || size != 7 {
		t.Fatalf("expected max_queue_size=7, got %#v", status["max_queue_size"])
	}
}
// TestPoolMultipleAcquireReleaseCycles acquires and releases the single
// account ten times in a row; every acquire must succeed.
func TestPoolMultipleAcquireReleaseCycles(t *testing.T) {
	const cycles = 10
	pool := newSingleAccountPoolForTest(t, "1")
	for cycle := 0; cycle < cycles; cycle++ {
		held, acquired := pool.Acquire("", nil)
		if !acquired {
			t.Fatalf("acquire failed at cycle %d", cycle)
		}
		pool.Release(held.Identifier())
	}
}
// TestPoolConcurrentAcquireWait holds the only slot, starts three concurrent
// AcquireWait callers (each with a 2s timeout), frees the slot and requires
// that at least one of them ends up with the account.
// NOTE(review): the original comments say only one caller can queue and the
// others are rejected by the queue limit — confirm against the pool helper.
func TestPoolConcurrentAcquireWait(t *testing.T) {
	const waiterCount = 3

	pool := newSingleAccountPoolForTest(t, "1")
	holder, acquired := pool.Acquire("", nil)
	if !acquired {
		t.Fatal("expected first acquire success")
	}

	outcomes := make(chan bool, waiterCount)
	waitOnce := func() {
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		_, got := pool.AcquireWait(ctx, "", nil)
		outcomes <- got
	}
	for remaining := waiterCount; remaining > 0; remaining-- {
		go waitOnce()
	}

	// Give the goroutines a moment to reach the pool before freeing the slot.
	time.Sleep(50 * time.Millisecond)
	pool.Release(holder.Identifier())

	var succeeded, failed int
	for collected := 0; collected < waiterCount; collected++ {
		select {
		case got := <-outcomes:
			if !got {
				failed++
				continue
			}
			succeeded++
			// Hand the slot back so a later waiter can obtain it.
			pool.Release("acc1@example.com")
		case <-time.After(3 * time.Second):
			t.Fatal("timed out waiting for results")
		}
	}

	if succeeded < 1 {
		t.Fatalf("expected at least 1 success, got success=%d timeout=%d", succeeded, failed)
	}
}

View File

@@ -0,0 +1,35 @@
package claude
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
// TestWriteClaudeErrorIncludesUnifiedFields writes a 401 through
// writeClaudeError and checks the status code plus the message, type, code and
// param members of the JSON "error" object.
func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) {
	recorder := httptest.NewRecorder()
	writeClaudeError(recorder, http.StatusUnauthorized, "bad token")
	if got := recorder.Code; got != http.StatusUnauthorized {
		t.Fatalf("expected 401, got %d", got)
	}

	var decoded map[string]any
	if err := json.Unmarshal(recorder.Body.Bytes(), &decoded); err != nil {
		t.Fatalf("decode body: %v", err)
	}
	// A failed assertion leaves errObj nil; reads from a nil map are safe and
	// simply make the field checks below fail.
	errObj, _ := decoded["error"].(map[string]any)

	expectations := []struct {
		field string
		want  string
	}{
		{"message", "bad token"},
		{"type", "invalid_request_error"},
		{"code", "authentication_failed"},
	}
	for _, exp := range expectations {
		if errObj[exp.field] != exp.want {
			t.Fatalf("unexpected %s: %v", exp.field, errObj[exp.field])
		}
	}

	// "param" must be present as a key even though its value is null.
	if _, present := errObj["param"]; !present {
		t.Fatal("expected param field")
	}
}

View File

@@ -43,6 +43,9 @@ func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
}
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
r.Header.Set("anthropic-version", "2023-06-01")
}
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
@@ -50,132 +53,79 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeJSON(w, status, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": detail}})
writeClaudeError(w, status, detail)
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "invalid json"}})
writeClaudeError(w, http.StatusBadRequest, "invalid json")
return
}
model, _ := req["model"].(string)
messagesRaw, _ := req["messages"].([]any)
if model == "" || len(messagesRaw) == 0 {
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "Request must include 'model' and 'messages'."}})
norm, err := normalizeClaudeRequest(h.Store, req)
if err != nil {
writeClaudeError(w, http.StatusBadRequest, err.Error())
return
}
normalized := normalizeClaudeMessages(messagesRaw)
payload := cloneMap(req)
payload["messages"] = normalized
toolsRequested, _ := req["tools"].([]any)
if len(toolsRequested) > 0 && !hasSystemMessage(normalized) {
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...)
}
dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store)
dsModel, _ := dsPayload["model"].(string)
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
if !ok {
thinkingEnabled = false
searchEnabled = false
}
finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
stdReq := norm.Standard
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "invalid token."}})
writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get PoW"}})
writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
return
}
requestPayload := map[string]any{
"chat_session_id": sessionID,
"parent_message_id": nil,
"prompt": finalPrompt,
"ref_file_ids": []any{},
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
}
requestPayload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get Claude response."}})
writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
return
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
writeClaudeError(w, http.StatusInternalServerError, string(body))
return
}
toolNames := extractClaudeToolNames(toolsRequested)
if util.ToBool(req["stream"]) {
h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames)
if stdReq.Stream {
h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
result := sse.CollectStream(resp, thinkingEnabled, true)
fullText := result.Text
fullThinking := result.Thinking
detected := util.ParseToolCalls(fullText, toolNames)
content := make([]map[string]any, 0, 4)
if fullThinking != "" {
content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking})
}
stopReason := "end_turn"
if len(detected) > 0 {
stopReason = "tool_use"
for i, tc := range detected {
content = append(content, map[string]any{
"type": "tool_use",
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i),
"name": tc.Name,
"input": tc.Input,
})
}
} else {
if fullText == "" {
fullText = "抱歉,没有生成有效的响应内容。"
}
content = append(content, map[string]any{"type": "text", "text": fullText})
}
writeJSON(w, http.StatusOK, map[string]any{
"id": fmt.Sprintf("msg_%d", time.Now().UnixNano()),
"type": "message",
"role": "assistant",
"model": model,
"content": content,
"stop_reason": stopReason,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)),
"output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText),
},
})
result := sse.CollectStream(resp, stdReq.Thinking, true)
respBody := util.BuildClaudeMessageResponse(
fmt.Sprintf("msg_%d", time.Now().UnixNano()),
stdReq.ResponseModel,
norm.NormalizedMessages,
result.Thinking,
result.Text,
stdReq.ToolNames,
)
writeJSON(w, http.StatusOK, respBody)
}
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
a, err := h.Auth.Determine(r)
if err != nil {
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()})
writeClaudeError(w, http.StatusUnauthorized, err.Error())
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"})
writeClaudeError(w, http.StatusBadRequest, "invalid json")
return
}
model, _ := req["model"].(string)
messages, _ := req["messages"].([]any)
if model == "" || len(messages) == 0 {
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."})
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
return
}
inputTokens := 0
@@ -206,7 +156,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
writeClaudeError(w, http.StatusInternalServerError, string(body))
return
}
@@ -241,6 +191,8 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
"error": map[string]any{
"type": "api_error",
"message": msg,
"code": "internal_error",
"param": nil,
},
})
}
@@ -492,6 +444,28 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
}
}
// writeClaudeError sends a JSON error envelope of the form
// {"error": {"type", "message", "code", "param"}} with the given HTTP status.
// The "code" member is derived from the status (falling back to
// "invalid_request"), "type" is always "invalid_request_error" and "param" is
// always null.
func writeClaudeError(w http.ResponseWriter, status int, message string) {
	codeByStatus := map[int]string{
		http.StatusUnauthorized:        "authentication_failed",
		http.StatusTooManyRequests:     "rate_limit_exceeded",
		http.StatusNotFound:            "not_found",
		http.StatusInternalServerError: "internal_error",
	}
	code, known := codeByStatus[status]
	if !known {
		code = "invalid_request"
	}

	errBody := map[string]any{
		"type":    "invalid_request_error",
		"message": message,
		"code":    code,
		"param":   nil,
	}
	writeJSON(w, status, map[string]any{"error": errBody})
}
func normalizeClaudeMessages(messages []any) []any {
out := make([]any, 0, len(messages))
for _, m := range messages {

View File

@@ -0,0 +1,348 @@
package claude
import (
"testing"
)
// ─── normalizeClaudeMessages ─────────────────────────────────────────
// TestNormalizeClaudeMessagesSimpleString checks that a message whose content
// is already a plain string comes back as one message with the same content.
func TestNormalizeClaudeMessagesSimpleString(t *testing.T) {
	input := []any{map[string]any{"role": "user", "content": "Hello"}}
	out := normalizeClaudeMessages(input)
	if n := len(out); n != 1 {
		t.Fatalf("expected 1 message, got %d", n)
	}
	first := out[0].(map[string]any)
	if content := first["content"]; content != "Hello" {
		t.Fatalf("expected 'Hello', got %v", content)
	}
}
// TestNormalizeClaudeMessagesArrayContent checks that an array of text blocks
// is flattened into a single newline-joined string.
func TestNormalizeClaudeMessagesArrayContent(t *testing.T) {
	blocks := []any{
		map[string]any{"type": "text", "text": "line1"},
		map[string]any{"type": "text", "text": "line2"},
	}
	out := normalizeClaudeMessages([]any{
		map[string]any{"role": "user", "content": blocks},
	})
	first := out[0].(map[string]any)
	if joined := first["content"]; joined != "line1\nline2" {
		t.Fatalf("expected joined text, got %q", joined)
	}
}
// TestNormalizeClaudeMessagesToolResult checks that a tool_result block
// contributes its "content" string to the normalized message.
func TestNormalizeClaudeMessagesToolResult(t *testing.T) {
	toolResult := map[string]any{"type": "tool_result", "content": "tool output"}
	out := normalizeClaudeMessages([]any{
		map[string]any{"role": "user", "content": []any{toolResult}},
	})
	first := out[0].(map[string]any)
	if content := first["content"]; content != "tool output" {
		t.Fatalf("expected 'tool output', got %q", content)
	}
}
// TestNormalizeClaudeMessagesSkipsNonMap feeds entries that are not maps and
// expects all of them to be dropped.
func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) {
	out := normalizeClaudeMessages([]any{"not a map", 42})
	if n := len(out); n != 0 {
		t.Fatalf("expected 0 messages for non-map items, got %d", n)
	}
}
// TestNormalizeClaudeMessagesEmpty checks that nil input yields no messages.
func TestNormalizeClaudeMessagesEmpty(t *testing.T) {
	if n := len(normalizeClaudeMessages(nil)); n != 0 {
		t.Fatalf("expected 0, got %d", n)
	}
}
// TestNormalizeClaudeMessagesPreservesRole checks that the "role" value is
// carried over unchanged.
func TestNormalizeClaudeMessagesPreservesRole(t *testing.T) {
	out := normalizeClaudeMessages([]any{
		map[string]any{"role": "assistant", "content": "response"},
	})
	first := out[0].(map[string]any)
	if role := first["role"]; role != "assistant" {
		t.Fatalf("expected 'assistant', got %q", role)
	}
}
// TestNormalizeClaudeMessagesMixedContentBlocks mixes text and image blocks
// and expects only the text parts to survive, joined by a newline.
func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) {
	blocks := []any{
		map[string]any{"type": "text", "text": "Hello"},
		map[string]any{"type": "image", "source": "data:..."},
		map[string]any{"type": "text", "text": "World"},
	}
	out := normalizeClaudeMessages([]any{
		map[string]any{"role": "user", "content": blocks},
	})
	first := out[0].(map[string]any)
	if content := first["content"]; content != "Hello\nWorld" {
		t.Fatalf("expected only text parts joined, got %q", content)
	}
}
// ─── buildClaudeToolPrompt ───────────────────────────────────────────
// TestBuildClaudeToolPromptSingleTool builds a prompt for one tool and checks
// that it mentions the tool name, its description and the tool_calls keyword.
func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
	schema := map[string]any{
		"type": "object",
		"properties": map[string]any{
			"query": map[string]any{"type": "string"},
		},
	}
	searchTool := map[string]any{
		"name":         "search",
		"description":  "Search the web",
		"input_schema": schema,
	}

	prompt := buildClaudeToolPrompt([]any{searchTool})
	if prompt == "" {
		t.Fatal("expected non-empty prompt")
	}

	// Each expected fragment is paired with the message reported when it is missing.
	checks := []struct {
		needle  string
		failure string
	}{
		{"search", "expected 'search' in prompt"},
		{"Search the web", "expected description in prompt"},
		{"tool_calls", "expected tool_calls instruction in prompt"},
	}
	for _, c := range checks {
		if !containsStr(prompt, c.needle) {
			t.Fatal(c.failure)
		}
	}
}
// TestBuildClaudeToolPromptMultipleTools checks that both tool names appear in
// a prompt built from two tools.
func TestBuildClaudeToolPromptMultipleTools(t *testing.T) {
	prompt := buildClaudeToolPrompt([]any{
		map[string]any{"name": "tool1", "description": "desc1"},
		map[string]any{"name": "tool2", "description": "desc2"},
	})
	mentionsBoth := containsStr(prompt, "tool1") && containsStr(prompt, "tool2")
	if !mentionsBoth {
		t.Fatalf("expected both tools in prompt")
	}
}
// TestBuildClaudeToolPromptSkipsNonMap passes a tool entry that is not a map
// and expects the prompt to still be produced, including its
// "You are Claude" intro.
func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) {
	prompt := buildClaudeToolPrompt([]any{"not a map"})
	if len(prompt) == 0 {
		t.Fatal("expected non-empty prompt even with invalid tools")
	}
	if hasIntro := containsStr(prompt, "You are Claude"); !hasIntro {
		t.Fatalf("expected intro in prompt")
	}
}
// ─── hasSystemMessage ────────────────────────────────────────────────
// TestHasSystemMessageTrue expects detection when a system-role message is present.
func TestHasSystemMessageTrue(t *testing.T) {
	conversation := []any{
		map[string]any{"role": "system", "content": "You are a helper"},
		map[string]any{"role": "user", "content": "Hi"},
	}
	if found := hasSystemMessage(conversation); !found {
		t.Fatal("expected true")
	}
}
// TestHasSystemMessageFalse expects no detection when only user and assistant
// messages are present.
func TestHasSystemMessageFalse(t *testing.T) {
	conversation := []any{
		map[string]any{"role": "user", "content": "Hi"},
		map[string]any{"role": "assistant", "content": "Hello"},
	}
	if found := hasSystemMessage(conversation); found {
		t.Fatal("expected false")
	}
}
// TestHasSystemMessageEmpty expects false for a nil message list.
func TestHasSystemMessageEmpty(t *testing.T) {
	found := hasSystemMessage(nil)
	if found {
		t.Fatal("expected false for nil")
	}
}
// TestHasSystemMessageNonMap expects false when the only entry is not a map.
func TestHasSystemMessageNonMap(t *testing.T) {
	found := hasSystemMessage([]any{"not a map"})
	if found {
		t.Fatal("expected false for non-map")
	}
}
// ─── extractClaudeToolNames ──────────────────────────────────────────
// TestExtractClaudeToolNamesSingle expects exactly ["search"] for one named tool.
func TestExtractClaudeToolNamesSingle(t *testing.T) {
	names := extractClaudeToolNames([]any{
		map[string]any{"name": "search"},
	})
	onlySearch := len(names) == 1 && names[0] == "search"
	if !onlySearch {
		t.Fatalf("expected [search], got %v", names)
	}
}
// TestExtractClaudeToolNamesMultiple expects two names for two named tools
// (only the count is checked).
func TestExtractClaudeToolNamesMultiple(t *testing.T) {
	names := extractClaudeToolNames([]any{
		map[string]any{"name": "search"},
		map[string]any{"name": "calculate"},
	})
	if count := len(names); count != 2 {
		t.Fatalf("expected 2 names, got %v", names)
	}
}
// TestExtractClaudeToolNamesSkipsEmptyName expects a tool with an empty name
// to be dropped, leaving only "valid".
func TestExtractClaudeToolNamesSkipsEmptyName(t *testing.T) {
	names := extractClaudeToolNames([]any{
		map[string]any{"name": ""},
		map[string]any{"name": "valid"},
	})
	onlyValid := len(names) == 1 && names[0] == "valid"
	if !onlyValid {
		t.Fatalf("expected [valid], got %v", names)
	}
}
// TestExtractClaudeToolNamesSkipsNonMap expects no names when no entry is a map.
func TestExtractClaudeToolNamesSkipsNonMap(t *testing.T) {
	names := extractClaudeToolNames([]any{"not a map", 42})
	if count := len(names); count != 0 {
		t.Fatalf("expected 0, got %v", names)
	}
}
// TestExtractClaudeToolNamesNil expects no names for a nil tool list.
func TestExtractClaudeToolNamesNil(t *testing.T) {
	names := extractClaudeToolNames(nil)
	if count := len(names); count != 0 {
		t.Fatalf("expected 0, got %v", names)
	}
}
// ─── toMessageMaps ───────────────────────────────────────────────────
// TestToMessageMapsNormal expects a single map entry to be passed through.
func TestToMessageMapsNormal(t *testing.T) {
	converted := toMessageMaps([]any{
		map[string]any{"role": "user", "content": "Hello"},
	})
	if n := len(converted); n != 1 {
		t.Fatalf("expected 1, got %d", n)
	}
}
// TestToMessageMapsNonSlice expects nil when the argument is not a slice.
func TestToMessageMapsNonSlice(t *testing.T) {
	if converted := toMessageMaps("not a slice"); converted != nil {
		t.Fatalf("expected nil, got %v", converted)
	}
}
// TestToMessageMapsSkipsNonMap expects only the map element of a mixed slice
// to be kept.
func TestToMessageMapsSkipsNonMap(t *testing.T) {
	mixed := []any{"string", map[string]any{"role": "user"}, 42}
	if n := len(toMessageMaps(mixed)); n != 1 {
		t.Fatalf("expected 1 map, got %d", n)
	}
}
// TestToMessageMapsNil expects nil output for nil input.
func TestToMessageMapsNil(t *testing.T) {
	if converted := toMessageMaps(nil); converted != nil {
		t.Fatalf("expected nil, got %v", converted)
	}
}
// ─── extractMessageContent ──────────────────────────────────────────
// TestExtractMessageContentString expects a string to be returned unchanged.
func TestExtractMessageContentString(t *testing.T) {
	text := extractMessageContent("hello")
	if text != "hello" {
		t.Fatalf("expected 'hello', got %q", text)
	}
}
// TestExtractMessageContentArray expects array parts to be joined with a newline.
func TestExtractMessageContentArray(t *testing.T) {
	parts := []any{"part1", "part2"}
	if text := extractMessageContent(parts); text != "part1\npart2" {
		t.Fatalf("expected joined, got %q", text)
	}
}
// TestExtractMessageContentOther expects a non-string, non-array value (42)
// to be rendered as the text "42".
func TestExtractMessageContentOther(t *testing.T) {
	if text := extractMessageContent(42); text != "42" {
		t.Fatalf("expected '42', got %q", text)
	}
}
// TestExtractMessageContentNil pins the current behaviour for nil input: the
// literal text "<nil>" is returned.
func TestExtractMessageContentNil(t *testing.T) {
	if text := extractMessageContent(nil); text != "<nil>" {
		t.Fatalf("expected '<nil>', got %q", text)
	}
}
// ─── cloneMap ────────────────────────────────────────────────────────
// TestCloneMapBasic checks that top-level entries are copied: changing the
// source after cloning must not affect the clone.
func TestCloneMapBasic(t *testing.T) {
	src := map[string]any{"a": 1, "b": "hello"}
	dup := cloneMap(src)
	src["a"] = 999

	if v := dup["a"]; v != 1 {
		t.Fatalf("expected 1, got %v", v)
	}
	if v := dup["b"]; v != "hello" {
		t.Fatalf("expected 'hello', got %v", v)
	}
}
// TestCloneMapEmpty expects the clone of an empty map to be empty as well.
func TestCloneMapEmpty(t *testing.T) {
	dup := cloneMap(map[string]any{})
	if size := len(dup); size != 0 {
		t.Fatalf("expected empty, got %v", dup)
	}
}
// TestCloneMapNested documents that cloneMap is shallow: a nested map is
// shared between source and clone, so a later change to it is visible through
// the clone.
func TestCloneMapNested(t *testing.T) {
	shared := map[string]any{"key": "value"}
	dup := cloneMap(map[string]any{"nested": shared})

	shared["key"] = "modified"

	viaClone := dup["nested"].(map[string]any)["key"]
	if viaClone != "modified" {
		t.Fatal("expected shallow clone to share nested references")
	}
}
// helper
// containsStr reports whether sub occurs anywhere in s. An empty sub matches
// every s, including the empty string.
func containsStr(s, sub string) bool {
	if len(sub) > len(s) {
		return false
	}
	return findSubstring(s, sub)
}

// findSubstring scans s left to right and reports whether some window of
// len(sub) bytes equals sub.
func findSubstring(s, sub string) bool {
	width := len(sub)
	for start := 0; start+width <= len(s); start++ {
		if s[start:start+width] == sub {
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,58 @@
package claude
import (
"fmt"
"strings"
"ds2api/internal/config"
"ds2api/internal/util"
)
// claudeNormalizedRequest is the result of normalizeClaudeRequest: the request
// in util.StandardRequest form plus the normalized message list.
type claudeNormalizedRequest struct {
// Standard carries the resolved model, final prompt, tool names and the
// stream/thinking/search flags filled in by normalizeClaudeRequest.
Standard util.StandardRequest
// NormalizedMessages is the output of normalizeClaudeMessages, without the
// tool system prompt that may be prepended to Standard.Messages.
NormalizedMessages []any
}
// normalizeClaudeRequest validates a decoded Anthropic-style request body and
// converts it into a claudeNormalizedRequest.
//
// It returns an error when "model" is blank or "messages" is missing/empty.
// Otherwise it normalizes the messages, prepends a tool system prompt when
// tools are requested and no system message exists, converts the payload via
// util.ConvertClaudeToDeepSeek and derives the final prompt, tool names and the
// stream/thinking/search flags.
//
// The caller's req map is never modified: the max_tokens default is applied to
// a (shallow) clone of the request.
func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNormalizedRequest, error) {
	model, _ := req["model"].(string)
	messagesRaw, _ := req["messages"].([]any)
	if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
		return claudeNormalizedRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.")
	}

	normalizedMessages := normalizeClaudeMessages(messagesRaw)

	// Work on a clone so defaults and rewritten messages do not leak back into
	// the map owned by the caller.
	payload := cloneMap(req)
	if _, ok := payload["max_tokens"]; !ok {
		payload["max_tokens"] = 8192
	}

	// messages is what gets sent downstream; it differs from normalizedMessages
	// only when a tool system prompt has to be injected.
	messages := normalizedMessages
	toolsRequested, _ := req["tools"].([]any)
	if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) {
		messages = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...)
	}
	payload["messages"] = messages

	dsPayload := util.ConvertClaudeToDeepSeek(payload, store)
	dsModel, _ := dsPayload["model"].(string)
	thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
	if !ok {
		// Unknown model: fall back to both features disabled.
		thinkingEnabled = false
		searchEnabled = false
	}
	finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
	toolNames := extractClaudeToolNames(toolsRequested)

	return claudeNormalizedRequest{
		Standard: util.StandardRequest{
			Surface:        "anthropic_messages",
			RequestedModel: strings.TrimSpace(model),
			ResolvedModel:  dsModel,
			ResponseModel:  strings.TrimSpace(model),
			Messages:       messages,
			FinalPrompt:    finalPrompt,
			ToolNames:      toolNames,
			Stream:         util.ToBool(req["stream"]),
			Thinking:       thinkingEnabled,
			Search:         searchEnabled,
		},
		NormalizedMessages: normalizedMessages,
	}, nil
}

View File

@@ -0,0 +1,38 @@
package claude
import (
"testing"
"ds2api/internal/config"
)
// TestNormalizeClaudeRequest normalizes a streaming request with one tool
// against an empty config and checks that the resolved model, stream flag,
// tool names and final prompt are all populated.
func TestNormalizeClaudeRequest(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{}`)
	store := config.LoadStore()

	userTurn := map[string]any{"role": "user", "content": "hello"}
	searchTool := map[string]any{"name": "search", "description": "Search"}
	request := map[string]any{
		"model":    "claude-opus-4-6",
		"messages": []any{userTurn},
		"stream":   true,
		"tools":    []any{searchTool},
	}

	normalized, err := normalizeClaudeRequest(store, request)
	if err != nil {
		t.Fatalf("normalize failed: %v", err)
	}

	// Cases are evaluated top to bottom, so the first unmet expectation fails the test.
	switch std := normalized.Standard; {
	case std.ResolvedModel == "":
		t.Fatalf("expected resolved model")
	case !std.Stream:
		t.Fatalf("expected stream=true")
	case len(std.ToolNames) == 0:
		t.Fatalf("expected tool names")
	case std.FinalPrompt == "":
		t.Fatalf("expected non-empty final prompt")
	}
}

View File

@@ -0,0 +1,138 @@
package openai
import (
"crypto/sha256"
"encoding/binary"
"encoding/json"
"fmt"
"net/http"
"strings"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/util"
)
// Embeddings handles an embeddings request.
//
// Checks run in this order, each ending the request with an error response:
// authentication (401, or 429 for auth.ErrNoAccount), JSON decoding (400),
// non-blank "model" (400), model resolvable via config.ResolveModel (400),
// non-empty "input" (400), and a configured, supported provider (501).
// On success every input gets a deterministicEmbedding vector and token usage
// is estimated with util.EstimateTokens.
func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) {
	acct, authErr := h.Auth.Determine(r)
	if authErr != nil {
		code := http.StatusUnauthorized
		if authErr == auth.ErrNoAccount {
			code = http.StatusTooManyRequests
		}
		writeOpenAIError(w, code, authErr.Error())
		return
	}
	defer h.Auth.Release(acct)

	var body map[string]any
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		writeOpenAIError(w, http.StatusBadRequest, "invalid json")
		return
	}

	rawModel, _ := body["model"].(string)
	model := strings.TrimSpace(rawModel)
	if model == "" {
		writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.")
		return
	}
	if _, resolvable := config.ResolveModel(h.Store, model); !resolvable {
		writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
		return
	}

	inputs := extractEmbeddingInputs(body["input"])
	if len(inputs) == 0 {
		writeOpenAIError(w, http.StatusBadRequest, "Request must include non-empty 'input'.")
		return
	}

	// The provider name is compared case-insensitively and without surrounding
	// whitespace; a nil Store behaves like an unconfigured provider.
	var provider string
	if h.Store != nil {
		provider = strings.ToLower(strings.TrimSpace(h.Store.EmbeddingsProvider()))
	}
	switch provider {
	case "":
		writeOpenAIError(w, http.StatusNotImplemented, "Embeddings provider is not configured. Set embeddings.provider in config.")
		return
	case "mock", "deterministic", "builtin":
		// All three names select the local deterministic provider.
	default:
		writeOpenAIError(w, http.StatusNotImplemented, fmt.Sprintf("Embeddings provider '%s' is not supported.", provider))
		return
	}

	items := make([]map[string]any, len(inputs))
	tokenCount := 0
	for idx := range inputs {
		tokenCount += util.EstimateTokens(inputs[idx])
		items[idx] = map[string]any{
			"object":    "embedding",
			"index":     idx,
			"embedding": deterministicEmbedding(inputs[idx]),
		}
	}

	usage := map[string]any{
		"prompt_tokens": tokenCount,
		"total_tokens":  tokenCount,
	}
	writeJSON(w, http.StatusOK, map[string]any{
		"object": "list",
		"data":   items,
		"model":  model,
		"usage":  usage,
	})
}
// extractEmbeddingInputs converts the decoded "input" value into a list of
// strings to embed.
//
//   - string: trimmed; a blank string yields nil.
//   - []any:  each element is handled on its own — strings are trimmed and
//     dropped when blank, nested arrays are rendered with %v as-is, and any
//     other value is rendered with %v, trimmed and dropped when blank.
//   - anything else: nil.
func extractEmbeddingInputs(raw any) []string {
	if text, isString := raw.(string); isString {
		if trimmed := strings.TrimSpace(text); trimmed != "" {
			return []string{trimmed}
		}
		return nil
	}

	items, isList := raw.([]any)
	if !isList {
		return nil
	}

	collected := make([]string, 0, len(items))
	for _, item := range items {
		// Nested arrays (token lists) are kept verbatim, without trimming or
		// an emptiness check.
		if tokens, isNested := item.([]any); isNested {
			collected = append(collected, fmt.Sprintf("%v", tokens))
			continue
		}

		var candidate string
		if text, isString := item.(string); isString {
			candidate = strings.TrimSpace(text)
		} else {
			candidate = strings.TrimSpace(fmt.Sprintf("%v", item))
		}
		if candidate != "" {
			collected = append(collected, candidate)
		}
	}
	return collected
}
// deterministicEmbedding derives a fixed-size (64-dimensional) pseudo-embedding
// from input using only SHA-256, so the same input always yields the same
// vector and no external service is needed.
//
// Each 32-byte digest supplies eight big-endian uint32 words. When a digest is
// used up, the next one is the SHA-256 of the previous digest, so every
// dimension depends on the input. (Hashing the exhausted, empty remainder
// instead would make all dimensions after the first eight identical for every
// input.)
//
// Every component lies in [-1, 1].
func deterministicEmbedding(input string) []float64 {
	const (
		dims     = 64
		wordSize = 4 // bytes consumed per dimension
	)
	out := make([]float64, dims)
	digest := sha256.Sum256([]byte(input))
	offset := 0
	for i := range out {
		if offset+wordSize > len(digest) {
			// Current digest exhausted: chain from it to get fresh,
			// input-dependent bytes.
			digest = sha256.Sum256(digest[:])
			offset = 0
		}
		v := binary.BigEndian.Uint32(digest[offset : offset+wordSize])
		offset += wordSize
		// map [0, 2^32) -> [-1, 1]
		out[i] = float64(v)/2147483647.5 - 1.0
	}
	return out
}

View File

@@ -0,0 +1,96 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"ds2api/internal/account"
"ds2api/internal/auth"
"ds2api/internal/config"
)
// newResolverWithConfigJSON is a test helper that exposes cfgJSON through the
// DS2API_CONFIG_JSON environment variable (scoped to the test via t.Setenv),
// loads a config store from it, and returns that store together with an auth
// resolver built on a fresh account pool. The resolver's callback is a stub
// that always returns "unused" and no error.
func newResolverWithConfigJSON(t *testing.T, cfgJSON string) (*config.Store, *auth.Resolver) {
	t.Helper()
	t.Setenv("DS2API_CONFIG_JSON", cfgJSON)

	cfgStore := config.LoadStore()
	stub := func(context.Context, config.Account) (string, error) {
		return "unused", nil
	}
	return cfgStore, auth.NewResolver(cfgStore, account.NewPool(cfgStore), stub)
}
// TestEmbeddingsRouteContract drives POST /v1/embeddings through the router
// with the "deterministic" provider configured. A request without an
// Authorization header must be answered with 401; a request carrying a bearer
// token and two inputs must be answered with 200, object "list" and two
// entries in "data".
func TestEmbeddingsRouteContract(t *testing.T) {
	store, resolver := newResolverWithConfigJSON(t, `{"embeddings":{"provider":"deterministic"}}`)
	router := chi.NewRouter()
	RegisterRoutes(router, &Handler{Store: store, Auth: resolver})

	// send posts a JSON payload to the embeddings route, optionally authenticated.
	send := func(payload string, withAuth bool) *httptest.ResponseRecorder {
		req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", bytes.NewBufferString(payload))
		if withAuth {
			req.Header.Set("Authorization", "Bearer test-token")
		}
		req.Header.Set("Content-Type", "application/json")
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)
		return rec
	}

	t.Run("unauthorized", func(t *testing.T) {
		rec := send(`{"model":"gpt-4o","input":"hello"}`, false)
		if rec.Code != http.StatusUnauthorized {
			t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
		}
	})

	t.Run("ok", func(t *testing.T) {
		rec := send(`{"model":"gpt-4o","input":["a","b"]}`, true)
		if rec.Code != http.StatusOK {
			t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
		}
		var decoded map[string]any
		if err := json.Unmarshal(rec.Body.Bytes(), &decoded); err != nil {
			t.Fatalf("decode response failed: %v", err)
		}
		if decoded["object"] != "list" {
			t.Fatalf("unexpected object: %#v", decoded["object"])
		}
		if items, _ := decoded["data"].([]any); len(items) != 2 {
			t.Fatalf("expected 2 embeddings, got %d", len(items))
		}
	})
}
// TestEmbeddingsRouteProviderMissing checks that, with an empty config (no
// embeddings provider), an authenticated POST /v1/embeddings is answered with
// 501 and an error object that carries both "code" and "param" keys.
func TestEmbeddingsRouteProviderMissing(t *testing.T) {
	store, resolver := newResolverWithConfigJSON(t, `{}`)
	router := chi.NewRouter()
	RegisterRoutes(router, &Handler{Store: store, Auth: resolver})

	payload := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`)
	req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", payload)
	req.Header.Set("Authorization", "Bearer test-token")
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusNotImplemented {
		t.Fatalf("expected 501, got %d body=%s", got, rec.Body.String())
	}

	var out map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
		t.Fatalf("decode response failed: %v", err)
	}
	// A missing or mistyped "error" leaves errObj nil; the lookups below then fail the test.
	errObj, _ := out["error"].(map[string]any)
	_, hasCode := errObj["code"]
	if !hasCode {
		t.Fatalf("expected error.code in response: %#v", out)
	}
	_, hasParam := errObj["param"]
	if !hasParam {
		t.Fatalf("expected error.param in response: %#v", out)
	}
}

View File

@@ -0,0 +1,35 @@
package openai
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
// TestWriteOpenAIErrorIncludesUnifiedFields verifies the error envelope written
// by writeOpenAIError for a 400: the status code is propagated and the "error"
// object carries the given message, type "invalid_request_error", code
// "invalid_request" and a "param" key.
func TestWriteOpenAIErrorIncludesUnifiedFields(t *testing.T) {
	recorder := httptest.NewRecorder()
	writeOpenAIError(recorder, http.StatusBadRequest, "invalid input")

	if recorder.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d", recorder.Code)
	}

	var payload map[string]any
	if err := json.Unmarshal(recorder.Body.Bytes(), &payload); err != nil {
		t.Fatalf("decode body: %v", err)
	}
	errObj, _ := payload["error"].(map[string]any)

	// Checked in order: message, type, code.
	expectations := []struct {
		key     string
		want    string
		failFmt string
	}{
		{"message", "invalid input", "unexpected message: %v"},
		{"type", "invalid_request_error", "unexpected type: %v"},
		{"code", "invalid_request", "unexpected code: %v"},
	}
	for _, exp := range expectations {
		if errObj[exp.key] != exp.want {
			t.Fatalf(exp.failFmt, errObj[exp.key])
		}
	}

	// "param" only has to be present; its value is not inspected.
	if _, present := errObj["param"]; !present {
		t.Fatal("expected param field")
	}
}

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"ds2api/internal/auth"
"ds2api/internal/config"
@@ -30,6 +31,8 @@ type Handler struct {
leaseMu sync.Mutex
streamLeases map[string]streamLease
responsesMu sync.Mutex
responses *responseStore
}
type streamLease struct {
@@ -39,13 +42,27 @@ type streamLease struct {
// RegisterRoutes mounts the handler's OpenAI-style endpoints on r: model
// listing and lookup, chat completions, responses (create and fetch by id),
// and embeddings.
func RegisterRoutes(r chi.Router, h *Handler) {
	routes := []struct {
		method  string
		pattern string
		handler http.HandlerFunc
	}{
		{http.MethodGet, "/v1/models", h.ListModels},
		{http.MethodGet, "/v1/models/{model_id}", h.GetModel},
		{http.MethodPost, "/v1/chat/completions", h.ChatCompletions},
		{http.MethodPost, "/v1/responses", h.Responses},
		{http.MethodGet, "/v1/responses/{response_id}", h.GetResponseByID},
		{http.MethodPost, "/v1/embeddings", h.Embeddings},
	}
	// Registered in the order listed above.
	for _, route := range routes {
		r.MethodFunc(route.method, route.pattern, route.handler)
	}
}
// ListModels answers GET /v1/models with 200 and the value returned by
// config.OpenAIModelsResponse, encoded as JSON. The request is not inspected.
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
	listing := config.OpenAIModelsResponse()
	writeJSON(w, http.StatusOK, listing)
}
// GetModel answers GET /v1/models/{model_id}. The whitespace-trimmed path
// parameter is looked up via config.OpenAIModelByID; a hit is written as JSON
// with 200, a miss as an OpenAI-style 404 error.
func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) {
	requested := strings.TrimSpace(chi.URLParam(r, "model_id"))
	if entry, found := config.OpenAIModelByID(h.Store, requested); found {
		writeJSON(w, http.StatusOK, entry)
		return
	}
	writeOpenAIError(w, http.StatusNotFound, "Model not found.")
}
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
if isVercelStreamReleaseRequest(r) {
h.handleVercelStreamRelease(w, r)
@@ -74,24 +91,11 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
model, _ := req["model"].(string)
messagesRaw, _ := req["messages"].([]any)
if model == "" || len(messagesRaw) == 0 {
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
}
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model)
if !ok {
writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model))
return
}
messages := normalizeMessages(messagesRaw)
toolNames := []string{}
if tools, ok := req["tools"].([]any); ok && len(tools) > 0 {
messages, toolNames = injectToolPrompt(messages, tools)
}
finalPrompt := util.MessagesPrepare(messages)
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
@@ -107,27 +111,20 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := map[string]any{
"chat_session_id": sessionID,
"parent_message_id": nil,
"prompt": finalPrompt,
"ref_file_ids": []any{},
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
}
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
return
}
if util.ToBool(req["stream"]) {
h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
if stdReq.Stream {
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
@@ -139,36 +136,8 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
finalThinking := result.Thinking
finalText := result.Text
detected := util.ParseToolCalls(finalText, toolNames)
finishReason := "stop"
messageObj := map[string]any{"role": "assistant", "content": finalText}
if thinkingEnabled && finalThinking != "" {
messageObj["reasoning_content"] = finalThinking
}
if len(detected) > 0 {
finishReason = "tool_calls"
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
messageObj["content"] = nil
}
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
writeJSON(w, http.StatusOK, map[string]any{
"id": completionID,
"object": "chat.completion",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
"usage": map[string]any{
"prompt_tokens": promptTokens,
"completion_tokens": reasoningTokens + completionTokens,
"total_tokens": promptTokens + reasoningTokens + completionTokens,
"completion_tokens_details": map[string]any{
"reasoning_tokens": reasoningTokens,
},
},
})
respBody := util.BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
writeJSON(w, http.StatusOK, respBody)
}
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
@@ -190,9 +159,11 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
created := time.Now().Unix()
firstChunkSent := false
bufferToolContent := len(toolNames) > 0
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
var toolSieve toolStreamSieveState
toolCallsEmitted := false
streamToolCallIDs := map[int]string{}
initialType := "text"
if thinkingEnabled {
initialType = "thinking"
@@ -235,13 +206,13 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
delta["role"] = "assistant"
firstChunkSent = true
}
sendChunk(map[string]any{
"id": completionID,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": []map[string]any{{"delta": delta, "index": 0}},
})
sendChunk(util.BuildOpenAIChatStreamChunk(
completionID,
created,
model,
[]map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)},
nil,
))
} else if bufferToolContent {
for _, evt := range flushToolSieve(&toolSieve, toolNames) {
if evt.Content == "" {
@@ -254,36 +225,25 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
delta["role"] = "assistant"
firstChunkSent = true
}
sendChunk(map[string]any{
"id": completionID,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": []map[string]any{{"delta": delta, "index": 0}},
})
sendChunk(util.BuildOpenAIChatStreamChunk(
completionID,
created,
model,
[]map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)},
nil,
))
}
}
if len(detected) > 0 || toolCallsEmitted {
finishReason = "tool_calls"
}
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
sendChunk(map[string]any{
"id": completionID,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": []map[string]any{{"delta": map[string]any{}, "index": 0, "finish_reason": finishReason}},
"usage": map[string]any{
"prompt_tokens": promptTokens,
"completion_tokens": reasoningTokens + completionTokens,
"total_tokens": promptTokens + reasoningTokens + completionTokens,
"completion_tokens_details": map[string]any{
"reasoning_tokens": reasoningTokens,
},
},
})
sendChunk(util.BuildOpenAIChatStreamChunk(
completionID,
created,
model,
[]map[string]any{util.BuildOpenAIChatStreamFinishChoice(0, finishReason)},
util.BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText),
))
sendDone()
}
@@ -357,6 +317,21 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
// Keep thinking delta only frame.
}
for _, evt := range events {
if len(evt.ToolCallDeltas) > 0 {
if !emitEarlyToolDeltas {
continue
}
toolCallsEmitted = true
tcDelta := map[string]any{
"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs),
}
if !firstChunkSent {
tcDelta["role"] = "assistant"
firstChunkSent = true
}
newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta))
continue
}
if len(evt.ToolCalls) > 0 {
toolCallsEmitted = true
tcDelta := map[string]any{
@@ -366,10 +341,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
tcDelta["role"] = "assistant"
firstChunkSent = true
}
newChoices = append(newChoices, map[string]any{
"delta": tcDelta,
"index": 0,
})
newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta))
continue
}
if evt.Content != "" {
@@ -380,42 +352,22 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
contentDelta["role"] = "assistant"
firstChunkSent = true
}
newChoices = append(newChoices, map[string]any{
"delta": contentDelta,
"index": 0,
})
newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, contentDelta))
}
}
}
}
if len(delta) > 0 {
newChoices = append(newChoices, map[string]any{"delta": delta, "index": 0})
newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, delta))
}
}
if len(newChoices) > 0 {
sendChunk(map[string]any{
"id": completionID,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": newChoices,
})
sendChunk(util.BuildOpenAIChatStreamChunk(completionID, created, model, newChoices, nil))
}
}
}
}
func normalizeMessages(raw []any) []map[string]any {
out := make([]map[string]any, 0, len(raw))
for _, item := range raw {
m, ok := item.(map[string]any)
if ok {
out = append(out, m)
}
}
return out
}
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
toolSchemas := make([]string, 0, len(tools))
names := make([]string, 0, len(tools))
@@ -444,7 +396,7 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
if len(toolSchemas) == 0 {
return messages, names
}
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT: If calling tools, output ONLY the JSON. The response must start with { and end with }"
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error."
for i := range messages {
if messages[i]["role"] == "system" {
@@ -457,11 +409,47 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
return messages, names
}
// formatIncrementalStreamToolCallDeltas converts incremental tool-call deltas
// into "tool_calls" entries for a streamed chat-completion chunk.
//
// ids maps a tool-call index to the call id already issued for it. The first
// time an index is seen a new "call_<uuid without dashes>" id is generated and
// stored in ids, so later deltas for the same index reuse it. Deltas with
// neither a name nor arguments are skipped. Returns nil for an empty input.
func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
	if len(deltas) == 0 {
		return nil
	}
	formatted := make([]map[string]any, 0, len(deltas))
	for _, delta := range deltas {
		hasName, hasArgs := delta.Name != "", delta.Arguments != ""
		if !hasName && !hasArgs {
			continue
		}

		// A missing key and an empty stored id both mean "no id issued yet".
		id := ids[delta.Index]
		if id == "" {
			id = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
			ids[delta.Index] = id
		}

		// At least one of name/arguments is set here, so function is never empty.
		function := map[string]any{}
		if hasName {
			function["name"] = delta.Name
		}
		if hasArgs {
			function["arguments"] = delta.Arguments
		}

		formatted = append(formatted, map[string]any{
			"index":    delta.Index,
			"id":       id,
			"type":     "function",
			"function": function,
		})
	}
	return formatted
}
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, map[string]any{
"error": map[string]any{
"message": message,
"type": openAIErrorType(status),
"code": openAIErrorCode(status),
"param": nil,
},
})
}
@@ -485,3 +473,47 @@ func openAIErrorType(status int) string {
return "invalid_request_error"
}
}
func openAIErrorCode(status int) string {
switch status {
case http.StatusBadRequest:
return "invalid_request"
case http.StatusUnauthorized:
return "authentication_failed"
case http.StatusForbidden:
return "forbidden"
case http.StatusTooManyRequests:
return "rate_limit_exceeded"
case http.StatusNotFound:
return "not_found"
case http.StatusServiceUnavailable:
return "service_unavailable"
default:
if status >= 500 {
return "internal_error"
}
return "invalid_request"
}
}
// applyOpenAIChatPassThrough copies every key/value pair produced by
// collectOpenAIChatPassThrough(req) into payload, overwriting existing keys.
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
	passThrough := collectOpenAIChatPassThrough(req)
	for key, value := range passThrough {
		payload[key] = value
	}
}
// toolcallFeatureMatchEnabled reports whether the configured tool-call mode is
// "feature_match". An unset (blank) mode counts as enabled, as does a nil
// handler or a handler without a store. The mode is compared
// case-insensitively after trimming whitespace.
func (h *Handler) toolcallFeatureMatchEnabled() bool {
	if h != nil && h.Store != nil {
		switch strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode())) {
		case "", "feature_match":
			return true
		default:
			return false
		}
	}
	return true
}
// toolcallEarlyEmitHighConfidence reports whether the configured early-emit
// confidence level is "high". An unset (blank) level counts as high, as does a
// nil handler or a handler without a store. The level is compared
// case-insensitively after trimming whitespace.
func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
	if h != nil && h.Store != nil {
		switch strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence())) {
		case "", "high":
			return true
		default:
			return false
		}
	}
	return true
}

View File

@@ -100,6 +100,26 @@ func streamFinishReason(frames []map[string]any) string {
return ""
}
// streamToolCallArgumentChunks walks decoded stream frames and collects, in
// order, every non-empty choices[].delta.tool_calls[].function.arguments
// string. Entries with a missing or differently-typed field along that path
// contribute nothing.
func streamToolCallArgumentChunks(frames []map[string]any) []string {
	chunks := make([]string, 0, 4)
	for i := range frames {
		choiceList, _ := frames[i]["choices"].([]any)
		for _, rawChoice := range choiceList {
			// Failed assertions yield nil maps; indexing a nil map is safe and
			// gives the zero value, so no explicit ok-checks are needed here.
			choiceMap, _ := rawChoice.(map[string]any)
			deltaMap, _ := choiceMap["delta"].(map[string]any)
			calls, _ := deltaMap["tool_calls"].([]any)
			for _, rawCall := range calls {
				callMap, _ := rawCall.(map[string]any)
				function, _ := callMap["function"].(map[string]any)
				arguments, isString := function["arguments"].(string)
				if !isString || arguments == "" {
					continue
				}
				chunks = append(chunks, arguments)
			}
		}
	}
	return chunks
}
func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
@@ -108,7 +128,7 @@ func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, false, []string{"search"})
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
@@ -141,7 +161,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, false, []string{"search"})
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
@@ -169,7 +189,7 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, false, []string{"search"})
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
@@ -190,6 +210,66 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
}
}
// TestHandleNonStreamEmbeddedToolCallExampleIntercepted feeds handleNonStream
// an upstream reply in which a tool_calls JSON object sits between two prose
// segments. The response must report finish_reason "tool_calls", carry a
// non-empty message.tool_calls, and have nil message.content.
func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
	handler := &Handler{}
	upstream := makeSSEHTTPResponse(
		`data: {"p":"response/content","v":"下面是示例:"}`,
		`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
		`data: {"p":"response/content","v":"请勿执行。"}`,
		`data: [DONE]`,
	)
	recorder := httptest.NewRecorder()

	handler.handleNonStream(recorder, context.Background(), upstream, "cid2c", "deepseek-chat", "prompt", false, []string{"search"})
	if status := recorder.Code; status != http.StatusOK {
		t.Fatalf("unexpected status: %d", status)
	}

	decoded := decodeJSONBody(t, recorder.Body.String())
	choiceList, _ := decoded["choices"].([]any)
	first, _ := choiceList[0].(map[string]any)
	if reason := first["finish_reason"]; reason != "tool_calls" {
		t.Fatalf("expected finish_reason=tool_calls, got %#v", reason)
	}

	message, _ := first["message"].(map[string]any)
	if calls, _ := message["tool_calls"].([]any); len(calls) == 0 {
		t.Fatalf("expected tool_calls field for embedded example: %#v", message["tool_calls"])
	}
	if message["content"] != nil {
		t.Fatalf("expected content nil when tool_calls detected, got %#v", message["content"])
	}
}
// TestHandleNonStreamFencedToolCallExampleNotIntercepted feeds handleNonStream
// an upstream reply whose tool_calls JSON is wrapped in a ```json fence. The
// response must finish with "stop", have no message.tool_calls key, and pass
// the fenced example through unchanged as message.content.
func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) {
	handler := &Handler{}
	upstream := makeSSEHTTPResponse(
		"data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}",
		`data: [DONE]`,
	)
	recorder := httptest.NewRecorder()

	handler.handleNonStream(recorder, context.Background(), upstream, "cid2d", "deepseek-chat", "prompt", false, []string{"search"})
	if status := recorder.Code; status != http.StatusOK {
		t.Fatalf("unexpected status: %d", status)
	}

	decoded := decodeJSONBody(t, recorder.Body.String())
	choiceList, _ := decoded["choices"].([]any)
	first, _ := choiceList[0].(map[string]any)
	if reason := first["finish_reason"]; reason != "stop" {
		t.Fatalf("expected finish_reason=stop, got %#v", reason)
	}

	message, _ := first["message"].(map[string]any)
	if _, present := message["tool_calls"]; present {
		t.Fatalf("did not expect tool_calls field for fenced example: %#v", message["tool_calls"])
	}
	text, _ := message["content"].(string)
	hasFence := strings.Contains(text, "```json")
	hasToolKey := strings.Contains(text, `"tool_calls"`)
	if !(hasFence && hasToolKey) {
		t.Fatalf("expected fenced tool example to pass through as text, got %q", text)
	}
}
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
@@ -377,9 +457,9 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"前置正文A。"}`,
`data: {"p":"response/content","v":"下面是示例:"}`,
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
`data: {"p":"response/content","v":"后置正文B。"}`,
`data: {"p":"response/content","v":"请勿执行。"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
@@ -392,10 +472,7 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta in mixed stream, body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("raw tool_calls JSON leaked in mixed stream: %s", rec.Body.String())
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
@@ -409,9 +486,95 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
}
}
got := content.String()
if !strings.Contains(got, "前置正文A。") || !strings.Contains(got, "后置正文B。") {
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
}
if strings.Contains(strings.ToLower(got), `"tool_calls"`) {
t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got)
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
}
}
// TestHandleStreamToolCallAfterLeadingTextStillIntercepted streams a prose
// segment followed by a tool_calls JSON object. The stream must end with
// [DONE], contain a tool_calls delta, keep the leading prose in the content
// deltas without any raw "tool_calls" text, and finish with "tool_calls".
func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
	handler := &Handler{}
	upstream := makeSSEHTTPResponse(
		`data: {"p":"response/content","v":"我将调用工具。"}`,
		`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
		`data: [DONE]`,
	)
	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)

	handler.handleStream(recorder, request, upstream, "cid7b", "deepseek-chat", "prompt", false, false, []string{"search"})

	frames, sawDone := parseSSEDataFrames(t, recorder.Body.String())
	if !sawDone {
		t.Fatalf("expected [DONE], body=%s", recorder.Body.String())
	}
	if !streamHasToolCallsDelta(frames) {
		t.Fatalf("expected tool_calls delta, body=%s", recorder.Body.String())
	}

	// Concatenate every string content delta across all frames and choices.
	var text strings.Builder
	for _, frame := range frames {
		choiceList, _ := frame["choices"].([]any)
		for _, rawChoice := range choiceList {
			choiceMap, _ := rawChoice.(map[string]any)
			deltaMap, _ := choiceMap["delta"].(map[string]any)
			if piece, isString := deltaMap["content"].(string); isString {
				text.WriteString(piece)
			}
		}
	}
	streamed := text.String()

	if !strings.Contains(streamed, "我将调用工具。") {
		t.Fatalf("expected leading text to keep streaming, got=%q", streamed)
	}
	if strings.Contains(strings.ToLower(streamed), "tool_calls") {
		t.Fatalf("unexpected raw tool json leak, got=%q", streamed)
	}
	if streamFinishReason(frames) != "tool_calls" {
		t.Fatalf("expected finish_reason=tool_calls, body=%s", recorder.Body.String())
	}
}
func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid7c", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
if c, ok := delta["content"].(string); ok {
content.WriteString(c)
}
}
}
got := content.String()
if !strings.Contains(got, "接下来我会继续说明。") {
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
}
if strings.Contains(strings.ToLower(got), "tool_calls") {
t.Fatalf("unexpected raw tool json leak, got=%q", got)
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
@@ -495,16 +658,16 @@ func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) {
}
}
}
got := strings.ToLower(content.String())
if strings.Contains(got, "tool_calls") {
t.Fatalf("unexpected raw tool_calls leak in content: %q", content.String())
}
if !strings.Contains(content.String(), "前置正文D。") || !strings.Contains(content.String(), "后置正文E。") {
got := content.String()
if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") {
t.Fatalf("expected pre/post plain text to remain, got=%q", content.String())
}
if !strings.Contains(strings.ToLower(got), "tool_calls") {
t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got)
}
}
func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.T) {
func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
@@ -533,7 +696,42 @@ func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.
}
}
}
if strings.Contains(strings.ToLower(content.String()), "tool_calls") || strings.Contains(content.String(), "{") {
t.Fatalf("unexpected incomplete tool json leak in content: %q", content.String())
if !strings.Contains(strings.ToLower(content.String()), "tool_calls") || !strings.Contains(content.String(), "{") {
t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String())
}
}
// TestHandleStreamToolCallArgumentsEmitIncrementally splits one tool_calls JSON
// object across two upstream frames. The stream must end with [DONE], contain
// tool_calls deltas, leak no raw tool JSON as content, deliver the arguments in
// at least two chunks that join to contain "q":"golang" and "page":1, and
// finish with "tool_calls".
func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
	handler := &Handler{}
	upstream := makeSSEHTTPResponse(
		`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
		`data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`,
		`data: [DONE]`,
	)
	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)

	handler.handleStream(recorder, request, upstream, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"})

	frames, sawDone := parseSSEDataFrames(t, recorder.Body.String())
	if !sawDone {
		t.Fatalf("expected [DONE], body=%s", recorder.Body.String())
	}
	if !streamHasToolCallsDelta(frames) {
		t.Fatalf("expected tool_calls delta, body=%s", recorder.Body.String())
	}
	if streamHasRawToolJSONContent(frames) {
		t.Fatalf("raw tool_calls JSON leaked in content delta: %s", recorder.Body.String())
	}

	pieces := streamToolCallArgumentChunks(frames)
	if len(pieces) < 2 {
		t.Fatalf("expected incremental arguments chunks, got=%v body=%s", pieces, recorder.Body.String())
	}
	merged := strings.Join(pieces, "")
	hasQuery := strings.Contains(merged, `"q":"golang"`)
	hasPage := strings.Contains(merged, `"page":1`)
	if !(hasQuery && hasPage) {
		t.Fatalf("unexpected merged arguments stream: %q", merged)
	}

	if streamFinishReason(frames) != "tool_calls" {
		t.Fatalf("expected finish_reason=tool_calls, body=%s", recorder.Body.String())
	}
}

View File

@@ -0,0 +1,192 @@
package openai
import (
"encoding/json"
"fmt"
"strings"
)
// normalizeOpenAIMessagesForPrompt flattens OpenAI chat messages into
// {"role", "content"} maps with string content.
//
// Per role (lower-cased and trimmed):
//   - "tool" / "function": re-emitted as a "user" message holding the
//     formatted tool result.
//   - "user" / "system": kept, with content normalized to text.
//   - "assistant": text content and formatted tool calls are joined; the
//     message is dropped when both are empty.
//   - anything else: dropped when the content is empty; an empty role
//     becomes "user".
//
// Entries that are not JSON objects are skipped.
func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
	normalized := make([]map[string]any, 0, len(raw))
	push := func(role, content string) {
		normalized = append(normalized, map[string]any{
			"role":    role,
			"content": content,
		})
	}

	for _, entry := range raw {
		msg, isObject := entry.(map[string]any)
		if !isObject {
			continue
		}
		role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))

		// Tool results are formatted from the whole message, not just its content.
		if role == "tool" || role == "function" {
			push("user", formatToolResultForPrompt(msg))
			continue
		}

		text := normalizeOpenAIContentForPrompt(msg["content"])
		switch role {
		case "user", "system":
			push(role, text)
		case "assistant":
			if merged := joinNonEmpty(text, formatAssistantToolCallsForPrompt(msg)); merged != "" {
				push("assistant", merged)
			}
		default:
			if text == "" {
				continue
			}
			if role == "" {
				role = "user"
			}
			push(role, text)
		}
	}
	return normalized
}
// formatAssistantToolCallsForPrompt renders an assistant message's tool calls
// as text blocks separated by blank lines.
//
// Each object in msg["tool_calls"] becomes one "Tool call:" block. The id
// defaults to "call_<position>" (1-based). The name is read from call["name"],
// then function.name, then "unknown". The arguments are read from
// function.arguments, then call["arguments"], then call["input"], then "{}";
// each later source is consulted only when the previous result is the empty
// string. A legacy msg["function_call"] object adds one more block with the
// fixed id "call_legacy". Returns "" when there is nothing to render.
func formatAssistantToolCallsForPrompt(msg map[string]any) string {
	const blockFormat = "Tool call:\n- tool_call_id: %s\n- function.name: %s\n- function.arguments: %s"
	var blocks []string

	if calls, isList := msg["tool_calls"].([]any); isList {
		for position, raw := range calls {
			call, isObject := raw.(map[string]any)
			if !isObject {
				continue
			}

			callID := strings.TrimSpace(asString(call["id"]))
			if callID == "" {
				callID = fmt.Sprintf("call_%d", position+1)
			}

			toolName := strings.TrimSpace(asString(call["name"]))
			arguments := ""
			if function, hasFunction := call["function"].(map[string]any); hasFunction {
				if toolName == "" {
					toolName = strings.TrimSpace(asString(function["name"]))
				}
				arguments = normalizeOpenAIArgumentsForPrompt(function["arguments"])
			}
			if toolName == "" {
				toolName = "unknown"
			}

			// Walk the fallback sources until one yields a non-empty rendering.
			for _, alternative := range []any{call["arguments"], call["input"]} {
				if arguments != "" {
					break
				}
				arguments = normalizeOpenAIArgumentsForPrompt(alternative)
			}
			if arguments == "" {
				arguments = "{}"
			}

			blocks = append(blocks, fmt.Sprintf(blockFormat, callID, toolName, arguments))
		}
	}

	if legacy, hasLegacy := msg["function_call"].(map[string]any); hasLegacy {
		toolName := strings.TrimSpace(asString(legacy["name"]))
		if toolName == "" {
			toolName = "unknown"
		}
		arguments := normalizeOpenAIArgumentsForPrompt(legacy["arguments"])
		if arguments == "" {
			arguments = "{}"
		}
		blocks = append(blocks, fmt.Sprintf(blockFormat, "call_legacy", toolName, arguments))
	}

	return strings.Join(blocks, "\n\n")
}
// formatToolResultForPrompt renders a tool/function-role message as a
// plain-text "Tool result" section.
//
// The call id is taken from "tool_call_id", then "id"; missing id or name
// become "unknown", and empty content becomes "null".
func formatToolResultForPrompt(msg map[string]any) string {
	callID := "unknown"
	for _, key := range []string{"tool_call_id", "id"} {
		if candidate := strings.TrimSpace(asString(msg[key])); candidate != "" {
			callID = candidate
			break
		}
	}

	toolName := strings.TrimSpace(asString(msg["name"]))
	if toolName == "" {
		toolName = "unknown"
	}

	body := normalizeOpenAIContentForPrompt(msg["content"])
	if body == "" {
		body = "null"
	}

	return fmt.Sprintf("Tool result:\n- tool_call_id: %s\n- name: %s\n- content: %s", callID, toolName, body)
}
// normalizeOpenAIContentForPrompt flattens an OpenAI message "content" value
// into prompt text.
//
//   - nil (absent or JSON null content) yields "", so callers' empty-content
//     checks work and the literal word "null" is not injected into the prompt.
//   - A string is returned unchanged.
//   - An array of content blocks yields the text of every "text",
//     "output_text" or "input_text" block (from its "text" field, falling back
//     to "content"), joined with newlines; other block types are skipped.
//   - Any other value is serialized with marshalToPromptString.
func normalizeOpenAIContentForPrompt(v any) string {
	switch x := v.(type) {
	case nil:
		// Assistant messages that only carry tool calls have null content.
		return ""
	case string:
		return x
	case []any:
		parts := make([]string, 0, len(x))
		for _, item := range x {
			m, ok := item.(map[string]any)
			if !ok {
				continue
			}
			t := strings.ToLower(strings.TrimSpace(asString(m["type"])))
			if t != "text" && t != "output_text" && t != "input_text" {
				continue
			}
			if text := asString(m["text"]); text != "" {
				parts = append(parts, text)
				continue
			}
			if text := asString(m["content"]); text != "" {
				parts = append(parts, text)
			}
		}
		return strings.Join(parts, "\n")
	default:
		return marshalToPromptString(v)
	}
}
// normalizeOpenAIArgumentsForPrompt converts a tool-call "arguments" value
// into prompt text.
//
// A string is returned trimmed; any other non-nil value is serialized with
// marshalToPromptString. A nil value (missing key or JSON null) yields "" so
// that callers can detect "no arguments here" and fall back to another field
// or to "{}" — serializing nil would produce the non-empty text "null" and
// silently disable those fallbacks.
func normalizeOpenAIArgumentsForPrompt(v any) string {
	switch x := v.(type) {
	case nil:
		return ""
	case string:
		return strings.TrimSpace(x)
	default:
		return marshalToPromptString(v)
	}
}
// marshalToPromptString serializes v as compact JSON for embedding in a
// prompt. Values that JSON cannot encode fall back to their trimmed default
// fmt representation.
func marshalToPromptString(v any) string {
	if encoded, err := json.Marshal(v); err == nil {
		return string(encoded)
	}
	return strings.TrimSpace(fmt.Sprint(v))
}
// asString returns v when it holds a string and "" for every other type,
// including nil.
func asString(v any) string {
	str, _ := v.(string)
	return str
}
// joinNonEmpty concatenates the parts that contain non-whitespace text,
// separated by a blank line. Kept parts are emitted untrimmed.
func joinNonEmpty(parts ...string) string {
	var joined strings.Builder
	wroteAny := false
	for _, part := range parts {
		if strings.TrimSpace(part) == "" {
			continue
		}
		if wroteAny {
			joined.WriteString("\n\n")
		}
		joined.WriteString(part)
		wroteAny = true
	}
	return joined.String()
}

View File

@@ -0,0 +1,121 @@
package openai
import (
"strings"
"testing"
"ds2api/internal/util"
)
// TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult feeds a
// system/user/assistant(tool_calls)/tool conversation through the normalizer
// and checks that the tool call and its result survive as text, both in the
// normalized messages and in the prompt built by util.MessagesPrepare.
func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *testing.T) {
raw := []any{
map[string]any{"role": "system", "content": "You are helpful"},
map[string]any{"role": "user", "content": "查北京天气"},
// Assistant turn with null content and a single function tool call.
map[string]any{
"role": "assistant",
"content": nil,
"tool_calls": []any{
map[string]any{
"id": "call_1",
"type": "function",
"function": map[string]any{
"name": "get_weather",
"arguments": "{\"city\":\"beijing\"}",
},
},
},
},
// Matching tool result for call_1.
map[string]any{
"role": "tool",
"tool_call_id": "call_1",
"name": "get_weather",
"content": "{\"temp\":18}",
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw)
if len(normalized) != 4 {
t.Fatalf("expected 4 normalized messages, got %d", len(normalized))
}
assistantContent, _ := normalized[2]["content"].(string)
if !strings.Contains(assistantContent, "tool_call_id: call_1") ||
!strings.Contains(assistantContent, "function.name: get_weather") ||
!strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") {
t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent)
}
toolContent, _ := normalized[3]["content"].(string)
if !strings.Contains(toolContent, "Tool result:") || !strings.Contains(toolContent, "name: get_weather") {
t.Fatalf("tool result not serialized correctly: %q", toolContent)
}
// The final prompt must still carry both halves of the round trip.
prompt := util.MessagesPrepare(normalized)
if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "Tool result:") {
t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt)
}
}
// TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved checks that
// a tool message whose content is a JSON object is serialized into the
// normalized text rather than dropped.
func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.T) {
	raw := []any{
		map[string]any{
			"role":         "tool",
			"tool_call_id": "call_2",
			"name":         "get_weather",
			"content": map[string]any{
				"temp":      18,
				"condition": "sunny",
			},
		},
	}
	normalized := normalizeOpenAIMessagesForPrompt(raw)
	// Guard the index below so a regression fails cleanly instead of panicking.
	if len(normalized) != 1 {
		t.Fatalf("expected one normalized message, got %d", len(normalized))
	}
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) {
		t.Fatalf("expected serialized object in tool content, got %q", got)
	}
}
// TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined checks that the
// text blocks of an array-valued tool content are joined with newlines and
// that non-text blocks (here an image_url block) do not break the join.
func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) {
	raw := []any{
		map[string]any{
			"role":         "tool",
			"tool_call_id": "call_3",
			"name":         "read_file",
			"content": []any{
				map[string]any{"type": "input_text", "text": "line-1"},
				map[string]any{"type": "output_text", "text": "line-2"},
				map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"},
			},
		},
	}
	normalized := normalizeOpenAIMessagesForPrompt(raw)
	// Guard the index below so a regression fails cleanly instead of panicking.
	if len(normalized) != 1 {
		t.Fatalf("expected one normalized message, got %d", len(normalized))
	}
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, "line-1\nline-2") {
		t.Fatalf("expected joined text blocks, got %q", got)
	}
}
// TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible checks that the
// legacy "function" role is handled like "tool": it is mapped to a user
// message whose text carries the tool name and the serialized content.
func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
raw := []any{
map[string]any{
"role": "function",
"tool_call_id": "call_4",
"name": "legacy_tool",
"content": map[string]any{
"ok": true,
},
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw)
if len(normalized) != 1 {
t.Fatalf("expected one normalized message, got %d", len(normalized))
}
if normalized[0]["role"] != "user" {
t.Fatalf("expected function role mapped to user, got %#v", normalized[0]["role"])
}
got, _ := normalized[0]["content"].(string)
if !strings.Contains(got, "name: legacy_tool") || !strings.Contains(got, `"ok":true`) {
t.Fatalf("unexpected normalized function-role content: %q", got)
}
}

View File

@@ -0,0 +1,46 @@
package openai
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
)
// TestGetModelRouteDirectAndAlias checks that GET /v1/models/{id} answers 200
// both for a native model id and for an alias name.
func TestGetModelRouteDirectAndAlias(t *testing.T) {
// A zero-value Handler is enough here: the test expects the route to work without Store or Auth.
h := &Handler{}
r := chi.NewRouter()
RegisterRoutes(r, h)
t.Run("direct", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/models/deepseek-chat", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
})
t.Run("alias", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/models/gpt-4.1", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200 for alias, got %d body=%s", rec.Code, rec.Body.String())
}
})
}
// TestGetModelRouteNotFound checks that GET /v1/models/{id} answers 404 for a
// model name that is neither a known model nor an alias.
func TestGetModelRouteNotFound(t *testing.T) {
h := &Handler{}
r := chi.NewRouter()
RegisterRoutes(r, h)
req := httptest.NewRequest(http.MethodGet, "/v1/models/not-exists", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String())
}
}

View File

@@ -0,0 +1,12 @@
package openai
import "ds2api/internal/util"
// buildOpenAIFinalPrompt turns raw OpenAI-style messages (and optional tool
// definitions) into the final prompt string.
//
// Messages are normalized first; when toolsRaw is a non-empty list the tool
// prompt is injected and the declared tool names are returned alongside the
// prompt. Without tools the returned name slice is empty but non-nil.
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) {
	normalized := normalizeOpenAIMessagesForPrompt(messagesRaw)

	tools, isList := toolsRaw.([]any)
	if !isList || len(tools) == 0 {
		return util.MessagesPrepare(normalized), []string{}
	}

	withTools, names := injectToolPrompt(normalized, tools)
	return util.MessagesPrepare(withTools), names
}

View File

@@ -0,0 +1,80 @@
package openai
import (
"strings"
"testing"
)
// TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics builds
// a final prompt from a user/assistant(tool_calls)/tool conversation plus one
// tool definition and checks that the tool name is reported and that the call
// id, function name, result marker and result payload all appear in the prompt.
func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *testing.T) {
messages := []any{
map[string]any{"role": "user", "content": "查北京天气"},
map[string]any{
"role": "assistant",
"tool_calls": []any{
map[string]any{
"id": "call_1",
"function": map[string]any{
"name": "get_weather",
"arguments": "{\"city\":\"beijing\"}",
},
},
},
},
// Tool result delivered as a JSON object rather than a string.
map[string]any{
"role": "tool",
"tool_call_id": "call_1",
"name": "get_weather",
"content": map[string]any{"temp": 18, "condition": "sunny"},
},
}
tools := []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "get_weather",
"description": "Get weather",
"parameters": map[string]any{
"type": "object",
},
},
},
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools)
if len(toolNames) != 1 || toolNames[0] != "get_weather" {
t.Fatalf("unexpected tool names: %#v", toolNames)
}
if !strings.Contains(finalPrompt, "tool_call_id: call_1") ||
!strings.Contains(finalPrompt, "function.name: get_weather") ||
!strings.Contains(finalPrompt, "Tool result:") ||
!strings.Contains(finalPrompt, `"condition":"sunny"`) {
t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt)
}
}
// TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction
// checks that, when tools are declared, the final prompt contains the two
// guidance sentences about using tool results and limiting repeat tool calls.
func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *testing.T) {
messages := []any{
map[string]any{"role": "system", "content": "You are helpful"},
map[string]any{"role": "user", "content": "请调用工具"},
}
tools := []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "search",
"description": "search docs",
"parameters": map[string]any{
"type": "object",
},
},
},
}
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools)
if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") {
t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt)
}
if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") {
t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt)
}
}

View File

@@ -0,0 +1,109 @@
package openai
import (
"sync"
"time"
"ds2api/internal/auth"
)
// storedResponse is one cached response object together with the owner it
// was stored for and the time after which it is swept.
type storedResponse struct {
// Owner is the owner string passed to put; get compares it again before returning the value.
Owner string
// Value is the copy of the response object taken by put.
Value map[string]any
// ExpiresAt is the store time plus the store's ttl.
ExpiresAt time.Time
}
// responseStore is an in-memory, TTL-bounded cache of response objects keyed
// by (owner, id). put and get take mu for the whole operation.
type responseStore struct {
// mu guards items.
mu sync.Mutex
// ttl is the lifetime given to each entry when it is stored.
ttl time.Duration
// items maps responseStoreKey(owner, id) to the cached entry.
items map[string]storedResponse
}
// newResponseStore creates an empty store whose entries live for ttl.
// A zero or negative ttl selects the 15-minute default.
func newResponseStore(ttl time.Duration) *responseStore {
	store := &responseStore{
		ttl:   15 * time.Minute,
		items: map[string]storedResponse{},
	}
	if ttl > 0 {
		store.ttl = ttl
	}
	return store
}
// responseStoreKey builds the map key for an (owner, id) pair by joining the
// two with a NUL byte.
func responseStoreKey(owner, id string) string {
	const separator = "\x00"
	return owner + separator + id
}
// responseStoreOwner returns the owner identity used to scope stored
// responses: the auth's CallerID, or "" when no auth is available.
func responseStoreOwner(a *auth.RequestAuth) string {
	if a != nil {
		return a.CallerID
	}
	return ""
}
// put stores a copy of value under (owner, id), replacing any previous entry
// and giving it a fresh expiry of now + ttl. Expired entries are swept first.
// The call is a no-op on a nil store, an empty owner or id, or a nil value.
func (s *responseStore) put(owner, id string, value map[string]any) {
	if s == nil || value == nil || owner == "" || id == "" {
		return
	}

	storedAt := time.Now()

	s.mu.Lock()
	defer s.mu.Unlock()

	s.sweepLocked(storedAt)
	entry := storedResponse{
		Owner:     owner,
		Value:     cloneAnyMap(value),
		ExpiresAt: storedAt.Add(s.ttl),
	}
	s.items[responseStoreKey(owner, id)] = entry
}
// get returns a copy of the response stored under (owner, id). The second
// result is false when the store is nil, owner or id is empty, the entry is
// absent or already expired, or the recorded owner does not match.
func (s *responseStore) get(owner, id string) (map[string]any, bool) {
	if s == nil || owner == "" || id == "" {
		return nil, false
	}

	lookupAt := time.Now()

	s.mu.Lock()
	defer s.mu.Unlock()

	// Sweeping first means an expired entry is never returned.
	s.sweepLocked(lookupAt)

	entry, found := s.items[responseStoreKey(owner, id)]
	if !found || entry.Owner != owner {
		return nil, false
	}
	return cloneAnyMap(entry.Value), true
}
// sweepLocked removes every entry whose expiry lies strictly before now.
// The caller must hold s.mu.
func (s *responseStore) sweepLocked(now time.Time) {
	for key := range s.items {
		if expiry := s.items[key].ExpiresAt; expiry.Before(now) {
			delete(s.items, key)
		}
	}
}
// cloneAnyMap returns a deep copy of a JSON-like map.
//
// Nested map[string]any and []any values are copied recursively, so mutating
// the result (or the input, after cloning) never affects the other side.
// Values of any other type are copied by assignment. A nil input yields nil,
// and nil nested maps/slices stay nil.
func cloneAnyMap(in map[string]any) map[string]any {
	if in == nil {
		return nil
	}

	// Declared as a variable first so the closure can recurse.
	var cloneValue func(v any) any
	cloneValue = func(v any) any {
		switch x := v.(type) {
		case map[string]any:
			return cloneAnyMap(x)
		case []any:
			if x == nil {
				return x
			}
			dup := make([]any, len(x))
			for i, elem := range x {
				dup[i] = cloneValue(elem)
			}
			return dup
		default:
			return v
		}
	}

	out := make(map[string]any, len(in))
	for k, v := range in {
		out[k] = cloneValue(v)
	}
	return out
}
// getResponseStore returns the handler's response store, creating it on first
// use under responsesMu. The TTL comes from the config store when one is set
// and defaults to 15 minutes otherwise. A nil handler yields a nil store.
func (h *Handler) getResponseStore() *responseStore {
	if h == nil {
		return nil
	}

	h.responsesMu.Lock()
	defer h.responsesMu.Unlock()

	if h.responses != nil {
		return h.responses
	}

	ttl := 15 * time.Minute
	if h.Store != nil {
		ttl = time.Duration(h.Store.ResponsesStoreTTLSeconds()) * time.Second
	}
	h.responses = newResponseStore(ttl)
	return h.responses
}

View File

@@ -0,0 +1,73 @@
package openai
import (
"testing"
"time"
)
// TestNormalizeResponsesInputAsMessagesString checks that a plain string
// input becomes a single user message carrying that string.
func TestNormalizeResponsesInputAsMessagesString(t *testing.T) {
msgs := normalizeResponsesInputAsMessages("hello")
if len(msgs) != 1 {
t.Fatalf("expected one message, got %d", len(msgs))
}
m, _ := msgs[0].(map[string]any)
if m["role"] != "user" || m["content"] != "hello" {
t.Fatalf("unexpected message: %#v", m)
}
}
// TestResponsesMessagesFromRequestWithInstructions checks that "instructions"
// is prepended as a system message ahead of the message derived from "input".
func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) {
req := map[string]any{
"model": "gpt-4.1",
"input": "ping",
"instructions": "system text",
}
msgs := responsesMessagesFromRequest(req)
if len(msgs) != 2 {
t.Fatalf("expected two messages, got %d", len(msgs))
}
sys, _ := msgs[0].(map[string]any)
if sys["role"] != "system" {
t.Fatalf("unexpected first message: %#v", sys)
}
}
// TestExtractEmbeddingInputs checks that an array of strings is returned as
// the same strings in the same order.
func TestExtractEmbeddingInputs(t *testing.T) {
got := extractEmbeddingInputs([]any{"a", "b"})
if len(got) != 2 || got[0] != "a" || got[1] != "b" {
t.Fatalf("unexpected inputs: %#v", got)
}
}
// TestDeterministicEmbeddingStable checks that embedding the same text twice
// yields two 64-dimensional vectors that are equal element by element.
func TestDeterministicEmbeddingStable(t *testing.T) {
a := deterministicEmbedding("hello")
b := deterministicEmbedding("hello")
if len(a) != 64 || len(b) != 64 {
t.Fatalf("expected 64 dims, got %d and %d", len(a), len(b))
}
for i := range a {
if a[i] != b[i] {
t.Fatalf("expected stable embedding at %d: %v != %v", i, a[i], b[i])
}
}
}
// TestResponseStorePutGet checks that a value stored for an owner can be read
// back by the same owner and id before it expires.
func TestResponseStorePutGet(t *testing.T) {
st := newResponseStore(100 * time.Millisecond)
st.put("owner_1", "resp_1", map[string]any{"id": "resp_1"})
got, ok := st.get("owner_1", "resp_1")
if !ok {
t.Fatal("expected stored response")
}
if got["id"] != "resp_1" {
t.Fatalf("unexpected response payload: %#v", got)
}
}
// TestResponseStoreTenantIsolation checks that a response stored by one owner
// cannot be read by a different owner using the same id.
func TestResponseStoreTenantIsolation(t *testing.T) {
st := newResponseStore(100 * time.Millisecond)
st.put("owner_a", "resp_1", map[string]any{"id": "resp_1"})
if _, ok := st.get("owner_b", "resp_1"); ok {
t.Fatal("expected owner_b to be isolated from owner_a response")
}
}

View File

@@ -0,0 +1,308 @@
package openai
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"ds2api/internal/auth"
"ds2api/internal/sse"
"ds2api/internal/util"
)
// GetResponseByID serves a previously stored response object by its id.
//
// The caller is identified with Auth.DetermineCaller and may only read
// responses stored under its own owner identity. Replies: 401 when the caller
// cannot be determined or has no owner identity, 400 when the response_id
// path parameter is blank, 404 when nothing is stored for (owner, id), and
// 200 with the stored object otherwise.
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
	caller, err := h.Auth.DetermineCaller(r)
	if err != nil {
		writeOpenAIError(w, http.StatusUnauthorized, err.Error())
		return
	}

	responseID := strings.TrimSpace(chi.URLParam(r, "response_id"))
	if responseID == "" {
		writeOpenAIError(w, http.StatusBadRequest, "response_id is required.")
		return
	}

	owner := responseStoreOwner(caller)
	if owner == "" {
		writeOpenAIError(w, http.StatusUnauthorized, "unauthorized")
		return
	}

	if stored, found := h.getResponseStore().get(owner, responseID); found {
		writeJSON(w, http.StatusOK, stored)
		return
	}
	writeOpenAIError(w, http.StatusNotFound, "Response not found.")
}
// Responses handles a Responses-API creation request.
//
// Steps, in order: resolve auth (429 when no account is available, otherwise
// 401 on failure), decode and normalize the JSON body (400 on failure),
// create an upstream session, fetch the PoW, call the upstream completion,
// then hand the upstream response to the streaming or non-streaming writer
// depending on the normalized request's Stream flag.
func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
// Running out of accounts is reported as rate limiting rather than an auth failure.
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeOpenAIError(w, status, detail)
return
}
// Release whatever Determine acquired once the request is finished.
defer h.Auth.Release(a)
r = r.WithContext(auth.WithAuth(r.Context(), a))
// Stored responses are scoped to this owner; an empty owner cannot store anything.
owner := responseStoreOwner(a)
if owner == "" {
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized")
return
}
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req)
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
}
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
// The hint differs depending on whether the token came from a configured account or from the caller.
if a.UseConfigToken {
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
} else {
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
}
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
return
}
// Response ids are "resp_" followed by a UUID with the dashes removed.
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if stdReq.Stream {
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
// handleResponsesNonStream consumes the whole upstream stream, builds a single
// response object from it, stores that object under (owner, responseID) and
// writes it as JSON with status 200.
//
// A non-200 upstream response is relayed as an OpenAI-style error carrying the
// upstream status code and the trimmed upstream body. The upstream body is
// always closed.
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		upstreamBody, _ := io.ReadAll(resp.Body)
		writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(upstreamBody)))
		return
	}

	collected := sse.CollectStream(resp, thinkingEnabled, true)
	responseObj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, collected.Thinking, collected.Text, toolNames)

	store := h.getResponseStore()
	store.put(owner, responseID, responseObj)

	writeJSON(w, http.StatusOK, responseObj)
}
// handleResponsesStream relays an upstream completion to the client as
// Responses-API server-sent events.
//
// It emits "response.created" first, then reasoning, text and tool-call
// events as parsed parts arrive. When the upstream stream ends (channel
// closed, stop, content filter or error message) it builds the final response
// object, stores it under (owner, responseID), emits "response.completed" and
// a trailing "data: [DONE]" line. A non-200 upstream response is relayed as an
// OpenAI-style error instead. If the request context is cancelled the function
// returns immediately without storing or emitting the completed event.
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body)))
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache, no-transform")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
// Probe once whether the writer supports flushing; events are flushed individually only if it does.
canFlush := rc.Flush() == nil
// sendEvent writes one SSE frame ("event:" line, "data:" line with the JSON payload, blank line).
// Marshal and write errors are deliberately ignored.
sendEvent := func(event string, payload map[string]any) {
b, _ := json.Marshal(payload)
_, _ = w.Write([]byte("event: " + event + "\n"))
_, _ = w.Write([]byte("data: "))
_, _ = w.Write(b)
_, _ = w.Write([]byte("\n\n"))
if canFlush {
_ = rc.Flush()
}
}
sendEvent("response.created", util.BuildOpenAIResponsesCreatedPayload(responseID, model))
initialType := "text"
if thinkingEnabled {
initialType = "thinking"
}
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
// Text is routed through the tool sieve only when tools were declared and feature matching is enabled.
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
var sieve toolStreamSieveState
// thinking and text accumulate the full reasoning and answer text for the final response object.
thinking := strings.Builder{}
text := strings.Builder{}
toolCallsEmitted := false
streamToolCallIDs := map[int]string{}
// finalize flushes whatever the sieve still holds, then stores and emits the completed response.
finalize := func() {
finalThinking := thinking.String()
finalText := text.String()
if bufferToolContent {
for _, evt := range flushToolSieve(&sieve, toolNames) {
if evt.Content != "" {
sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content))
}
if len(evt.ToolCalls) > 0 {
toolCallsEmitted = true
sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
}
}
}
obj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText, toolNames)
if toolCallsEmitted {
obj["status"] = "completed"
}
h.getResponseStore().put(owner, responseID, obj)
sendEvent("response.completed", util.BuildOpenAIResponsesCompletedPayload(obj))
_, _ = w.Write([]byte("data: [DONE]\n\n"))
if canFlush {
_ = rc.Flush()
}
}
for {
select {
case <-r.Context().Done():
// Client went away: stop without finalizing.
return
case parsed, ok := <-parsedLines:
if !ok {
// Channel closed: wait for the pump's done signal before finalizing.
_ = <-done
finalize()
return
}
if !parsed.Parsed {
continue
}
// NOTE(review): content-filter and error terminations finalize exactly like a normal stop;
// parsed.ErrorMessage is not forwarded to the client here — confirm that is intended.
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
finalize()
return
}
for _, p := range parsed.Parts {
if p.Text == "" {
continue
}
// With search enabled, non-thinking parts recognised as citations are dropped.
if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) {
continue
}
if p.Type == "thinking" {
if !thinkingEnabled {
continue
}
thinking.WriteString(p.Text)
sendEvent("response.reasoning.delta", util.BuildOpenAIResponsesReasoningDeltaPayload(responseID, p.Text))
continue
}
text.WriteString(p.Text)
if !bufferToolContent {
sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, p.Text))
continue
}
for _, evt := range processToolSieveChunk(&sieve, p.Text, toolNames) {
if evt.Content != "" {
sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content))
}
if len(evt.ToolCallDeltas) > 0 {
// NOTE(review): this continue also skips the evt.ToolCalls check below for the same event —
// confirm a sieve event never carries both deltas and completed calls.
if !emitEarlyToolDeltas {
continue
}
toolCallsEmitted = true
sendEvent("response.output_tool_call.delta", util.BuildOpenAIResponsesToolCallDeltaPayload(responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs)))
}
if len(evt.ToolCalls) > 0 {
toolCallsEmitted = true
sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
}
}
}
}
}
}
// responsesMessagesFromRequest extracts the conversation from a Responses-API
// request body.
//
// A non-empty "messages" array takes precedence; otherwise "input" is
// normalized into messages. In both cases a non-blank "instructions" string
// is prepended as a system message. Returns nil when neither source yields
// any message.
func responsesMessagesFromRequest(req map[string]any) []any {
	instructions := req["instructions"]

	if explicit, ok := req["messages"].([]any); ok && len(explicit) > 0 {
		return prependInstructionMessage(explicit, instructions)
	}

	rawInput, hasInput := req["input"]
	if !hasInput {
		return nil
	}
	fromInput := normalizeResponsesInputAsMessages(rawInput)
	if len(fromInput) == 0 {
		return nil
	}
	return prependInstructionMessage(fromInput, instructions)
}
// prependInstructionMessage returns messages with a leading system message
// holding the trimmed instructions text. When instructions is not a string or
// is blank, messages is returned unchanged. The input slice is never modified.
func prependInstructionMessage(messages []any, instructions any) []any {
	text, _ := instructions.(string)
	if text = strings.TrimSpace(text); text == "" {
		return messages
	}

	systemMessage := map[string]any{"role": "system", "content": text}
	return append([]any{systemMessage}, messages...)
}
// normalizeResponsesInputAsMessages converts a Responses-API "input" value
// into chat-style messages.
//
//   - A non-blank string becomes one user message.
//   - An array whose first element is an object with a "role" key is returned
//     as-is (the caller already supplied messages).
//   - Any other array is flattened into one user message, one line per item:
//     text blocks ("input_text", "text", "output_text") contribute their
//     "text"; other objects contribute their compact JSON; strings and scalars
//     contribute their trimmed text. nil items and blank text blocks are
//     skipped so no "<nil>" or Go "map[...]" debug text reaches the prompt.
//   - An object contributes its non-blank "text", or else its non-blank
//     string "content", as one user message.
//
// Returns nil when no message can be derived.
func normalizeResponsesInputAsMessages(input any) []any {
	switch v := input.(type) {
	case string:
		if strings.TrimSpace(v) == "" {
			return nil
		}
		return []any{map[string]any{"role": "user", "content": v}}
	case []any:
		if len(v) == 0 {
			return nil
		}
		// If caller already provides role-shaped items, keep as-is.
		if first, ok := v[0].(map[string]any); ok {
			if _, hasRole := first["role"]; hasRole {
				return v
			}
		}
		parts := make([]string, 0, len(v))
		for _, item := range v {
			switch x := item.(type) {
			case nil:
				// JSON null carries no content.
				continue
			case map[string]any:
				kind, _ := x["type"].(string)
				kind = strings.ToLower(strings.TrimSpace(kind))
				if kind == "input_text" || kind == "text" || kind == "output_text" {
					if txt, _ := x["text"].(string); strings.TrimSpace(txt) != "" {
						parts = append(parts, txt)
					}
					continue
				}
				// Non-text items are kept, but as JSON rather than Go's %v map syntax.
				if encoded, err := json.Marshal(x); err == nil {
					parts = append(parts, string(encoded))
					continue
				}
				if s := strings.TrimSpace(fmt.Sprintf("%v", x)); s != "" {
					parts = append(parts, s)
				}
			default:
				if s := strings.TrimSpace(fmt.Sprintf("%v", x)); s != "" {
					parts = append(parts, s)
				}
			}
		}
		if len(parts) == 0 {
			return nil
		}
		return []any{map[string]any{"role": "user", "content": strings.Join(parts, "\n")}}
	case map[string]any:
		if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
			return []any{map[string]any{"role": "user", "content": txt}}
		}
		if content, ok := v["content"].(string); ok && strings.TrimSpace(content) != "" {
			return []any{map[string]any{"role": "user", "content": content}}
		}
	}
	return nil
}

View File

@@ -0,0 +1,176 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"ds2api/internal/account"
"ds2api/internal/auth"
"ds2api/internal/config"
)
// newDirectTokenResolver builds a config store and auth resolver from a
// configuration with no keys and no accounts, so the tests exercise bearer
// tokens that are not managed keys. The login callback always returns "unused".
func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) {
t.Helper()
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`)
store := config.LoadStore()
pool := account.NewPool(store)
resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
return "unused", nil
})
return store, resolver
}
// newManagedKeyResolver builds a config store and auth resolver with one
// managed key and one account. The inflight limit is set to 1 and the queue
// to 0 via environment variables so that a test can put the account pool
// under pressure by holding a single acquisition.
func newManagedKeyResolver(t *testing.T) (*config.Store, *auth.Resolver) {
t.Helper()
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[{"email":"acc@example.com","password":"pwd","token":"account-token"}]
}`)
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "0")
store := config.LoadStore()
pool := account.NewPool(store)
resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
return "unused", nil
})
return store, resolver
}
// authForToken resolves the auth for a GET request carrying the given bearer
// token and fails the test if resolution errors.
func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth {
t.Helper()
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
req.Header.Set("Authorization", "Bearer "+token)
a, err := resolver.Determine(req)
if err != nil {
t.Fatalf("determine auth failed: %v", err)
}
return a
}
// TestGetResponseByIDRequiresAuthAndIsTenantIsolated stores a response for
// the caller behind "token-a" and checks three cases on GET
// /v1/responses/resp_test: no credentials gives 401, a different token gives
// 404, and the owning token gives 200 with the stored body.
func TestGetResponseByIDRequiresAuthAndIsTenantIsolated(t *testing.T) {
store, resolver := newDirectTokenResolver(t)
h := &Handler{Store: store, Auth: resolver}
r := chi.NewRouter()
RegisterRoutes(r, h)
// Seed the store directly under token-a's owner identity.
ownerA := responseStoreOwner(authForToken(t, resolver, "token-a"))
h.getResponseStore().put(ownerA, "resp_test", map[string]any{
"id": "resp_test",
"object": "response",
})
t.Run("unauthorized", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
}
})
t.Run("cross-tenant-not-found", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
req.Header.Set("Authorization", "Bearer token-b")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String())
}
})
t.Run("same-tenant-ok", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
req.Header.Set("Authorization", "Bearer token-a")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("decode body failed: %v", err)
}
if body["id"] != "resp_test" {
t.Fatalf("unexpected body: %#v", body)
}
})
}
// TestResponsesRouteValidationContract posts invalid bodies to /v1/responses
// (missing model; missing input and messages) and checks that each is
// rejected with 400 and an error object containing "code" and "param" keys.
func TestResponsesRouteValidationContract(t *testing.T) {
store, resolver := newDirectTokenResolver(t)
h := &Handler{Store: store, Auth: resolver}
r := chi.NewRouter()
RegisterRoutes(r, h)
tests := []struct {
name string
body string
}{
{name: "missing_model", body: `{"input":"hello"}`},
{name: "missing_input_and_messages", body: `{"model":"gpt-4o"}`},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewBufferString(tc.body))
req.Header.Set("Authorization", "Bearer token-a")
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
}
var out map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
t.Fatalf("decode response failed: %v", err)
}
// Only the presence of the keys is checked, not their values.
errObj, _ := out["error"].(map[string]any)
if _, ok := errObj["code"]; !ok {
t.Fatalf("expected error.code: %#v", out)
}
if _, ok := errObj["param"]; !ok {
t.Fatalf("expected error.param: %#v", out)
}
})
}
}
// TestGetResponseByIDManagedKeySkipsAccountPoolPressure checks that reading a
// stored response with a managed key still succeeds while the only account is
// held by another request.
func TestGetResponseByIDManagedKeySkipsAccountPoolPressure(t *testing.T) {
store, resolver := newManagedKeyResolver(t)
h := &Handler{Store: store, Auth: resolver}
r := chi.NewRouter()
RegisterRoutes(r, h)
// Seed the store under the managed key's caller identity.
ownerReq := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
ownerReq.Header.Set("Authorization", "Bearer managed-key")
ownerAuth, err := resolver.DetermineCaller(ownerReq)
if err != nil {
t.Fatalf("determine caller failed: %v", err)
}
owner := responseStoreOwner(ownerAuth)
h.getResponseStore().put(owner, "resp_test", map[string]any{
"id": "resp_test",
"object": "response",
})
// Hold the single account for the rest of the test (inflight limit 1, queue 0).
occupyReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
occupyReq.Header.Set("Authorization", "Bearer managed-key")
occupied, err := resolver.Determine(occupyReq)
if err != nil {
t.Fatalf("expected first acquire to succeed: %v", err)
}
defer resolver.Release(occupied)
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
req.Header.Set("Authorization", "Bearer managed-key")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200 under pool pressure, got %d body=%s", rec.Code, rec.Body.String())
}
}

View File

@@ -0,0 +1,122 @@
package openai
import (
"bufio"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted streams a
// single chunk containing a complete tool_calls JSON object and checks the
// response.completed event: output_text must be empty, and the first output
// entry must be a structured tool_calls item naming "read_file".
func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
// sseLine wraps a text fragment in the upstream "response/content" data-line shape.
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
rawToolJSON := `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`
streamBody := sseLine(rawToolJSON) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
if !ok {
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
}
responseObj, _ := completed["response"].(map[string]any)
outputText, _ := responseObj["output_text"].(string)
if outputText != "" {
t.Fatalf("expected empty output_text for tool_calls response, got output_text=%q", outputText)
}
output, _ := responseObj["output"].([]any)
if len(output) == 0 {
t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "tool_calls" {
t.Fatalf("expected first output type tool_calls, got %#v", first["type"])
}
toolCalls, _ := first["tool_calls"].([]any)
if len(toolCalls) == 0 {
t.Fatalf("expected at least one tool_call in output, got %#v", first["tool_calls"])
}
call0, _ := toolCalls[0].(map[string]any)
if call0["name"] != "read_file" {
t.Fatalf("unexpected tool call name: %#v", call0["name"])
}
if strings.Contains(outputText, `"tool_calls"`) {
t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
}
}
// TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText
// streams ordinary text followed by a truncated tool_calls JSON fragment and
// checks that the fragment appears at most once in the completed event's
// output_text.
func TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
// sseLine wraps a text fragment in the upstream "response/content" data-line shape.
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
// The JSON is cut off mid-object, so it can never parse as a tool call.
tail := `{"tool_calls":[{"name":"read_file","input":`
streamBody := sseLine("Before ") + sseLine(tail) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
if !ok {
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
}
responseObj, _ := completed["response"].(map[string]any)
outputText, _ := responseObj["output_text"].(string)
if strings.Count(outputText, tail) > 1 {
t.Fatalf("expected incomplete tail not to be duplicated, got output_text=%q", outputText)
}
}
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
scanner := bufio.NewScanner(strings.NewReader(body))
matched := false
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "event: ") {
evt := strings.TrimSpace(strings.TrimPrefix(line, "event: "))
matched = evt == targetEvent
continue
}
if !matched || !strings.HasPrefix(line, "data: ") {
continue
}
raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if raw == "" || raw == "[DONE]" {
continue
}
var payload map[string]any
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
return nil, false
}
return payload, true
}
return nil, false
}

View File

@@ -0,0 +1,104 @@
package openai
import (
"fmt"
"strings"
"ds2api/internal/config"
"ds2api/internal/util"
)
// normalizeOpenAIChatRequest validates a chat-completions request body and
// converts it into the surface-independent util.StandardRequest.
//
// The model name is trimmed before validation and resolution, matching
// normalizeOpenAIResponsesRequest, so a name with stray whitespace resolves
// the same way on both surfaces. An error is returned when "model" or
// "messages" is missing/empty, or when the model cannot be resolved. The
// response echoes the (trimmed) model name the caller asked for rather than
// the resolved one.
func normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) {
	model, _ := req["model"].(string)
	model = strings.TrimSpace(model)
	messagesRaw, _ := req["messages"].([]any)
	if model == "" || len(messagesRaw) == 0 {
		return util.StandardRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.")
	}
	resolvedModel, ok := config.ResolveModel(store, model)
	if !ok {
		return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model)
	}
	thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
	finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
	passThrough := collectOpenAIChatPassThrough(req)
	return util.StandardRequest{
		Surface:        "openai_chat",
		RequestedModel: model,
		ResolvedModel:  resolvedModel,
		// model is known to be non-empty here, so it is always what we echo back.
		ResponseModel: model,
		Messages:      messagesRaw,
		FinalPrompt:   finalPrompt,
		ToolNames:     toolNames,
		Stream:        util.ToBool(req["stream"]),
		Thinking:      thinkingEnabled,
		Search:        searchEnabled,
		PassThrough:   passThrough,
	}, nil
}
// normalizeOpenAIResponsesRequest validates an OpenAI "responses" request body
// and converts it into a util.StandardRequest.
//
// The trimmed "model" must be non-empty and resolvable. Messages come from
// responsesMessagesFromRequest when wide input is allowed, otherwise only from
// a non-empty "messages" array. An error with a client-facing message is
// returned when the model is missing/unknown or no messages can be derived.
func normalizeOpenAIResponsesRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) {
	var none util.StandardRequest

	requested, _ := req["model"].(string)
	requested = strings.TrimSpace(requested)
	if requested == "" {
		return none, fmt.Errorf("Request must include 'model'.")
	}

	resolved, found := config.ResolveModel(store, requested)
	if !found {
		return none, fmt.Errorf("Model '%s' is not available.", requested)
	}
	thinking, search, _ := config.GetModelConfig(resolved)

	// Width control is kept as an explicit policy hook: without a store it is
	// on, otherwise the store decides.
	wideInput := store == nil || store.CompatWideInputStrictOutput()

	var msgs []any
	if wideInput {
		msgs = responsesMessagesFromRequest(req)
	} else if direct, isList := req["messages"].([]any); isList && len(direct) > 0 {
		msgs = direct
	}
	if len(msgs) == 0 {
		return none, fmt.Errorf("Request must include 'input' or 'messages'.")
	}

	prompt, tools := buildOpenAIFinalPrompt(msgs, req["tools"])
	return util.StandardRequest{
		Surface:        "openai_responses",
		RequestedModel: requested,
		ResolvedModel:  resolved,
		ResponseModel:  requested,
		Messages:       msgs,
		FinalPrompt:    prompt,
		ToolNames:      tools,
		Stream:         util.ToBool(req["stream"]),
		Thinking:       thinking,
		Search:         search,
		PassThrough:    collectOpenAIChatPassThrough(req),
	}, nil
}
// collectOpenAIChatPassThrough copies the sampling/limit parameters that are
// forwarded as-is out of the request body. Only keys present in req are
// copied (including keys whose value is nil); the result is never nil.
func collectOpenAIChatPassThrough(req map[string]any) map[string]any {
	forwarded := [...]string{
		"temperature",
		"top_p",
		"max_tokens",
		"max_completion_tokens",
		"presence_penalty",
		"frequency_penalty",
		"stop",
	}
	picked := make(map[string]any)
	for i := 0; i < len(forwarded); i++ {
		name := forwarded[i]
		value, present := req[name]
		if !present {
			continue
		}
		picked[name] = value
	}
	return picked
}

View File

@@ -0,0 +1,60 @@
package openai
import (
"testing"
"ds2api/internal/config"
)
// newEmptyStoreForNormalizeTest sets DS2API_CONFIG_JSON to an empty JSON object
// for the duration of the test and returns the store produced by
// config.LoadStore (presumably built from that variable; confirm in config).
func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store {
t.Helper()
t.Setenv("DS2API_CONFIG_JSON", `{}`)
return config.LoadStore()
}
// TestNormalizeOpenAIChatRequest checks that a minimal streaming chat request
// for "gpt-5-codex" normalizes without error, resolves to "deepseek-reasoner",
// keeps stream=true, forwards "temperature" and yields a non-empty prompt.
func TestNormalizeOpenAIChatRequest(t *testing.T) {
store := newEmptyStoreForNormalizeTest(t)
req := map[string]any{
"model": "gpt-5-codex",
"messages": []any{
map[string]any{"role": "user", "content": "hello"},
},
"temperature": 0.3,
"stream": true,
}
n, err := normalizeOpenAIChatRequest(store, req)
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
if n.ResolvedModel != "deepseek-reasoner" {
t.Fatalf("unexpected resolved model: %s", n.ResolvedModel)
}
if !n.Stream {
t.Fatalf("expected stream=true")
}
if _, ok := n.PassThrough["temperature"]; !ok {
t.Fatalf("expected temperature passthrough")
}
if n.FinalPrompt == "" {
t.Fatalf("expected non-empty final prompt")
}
}
// TestNormalizeOpenAIResponsesRequestInput checks that a responses request
// carrying "input" plus "instructions" for "gpt-4o" resolves to
// "deepseek-chat" and is normalized into exactly two messages.
func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
store := newEmptyStoreForNormalizeTest(t)
req := map[string]any{
"model": "gpt-4o",
"input": "ping",
"instructions": "system",
}
n, err := normalizeOpenAIResponsesRequest(store, req)
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
if n.ResolvedModel != "deepseek-chat" {
t.Fatalf("unexpected resolved model: %s", n.ResolvedModel)
}
if len(n.Messages) != 2 {
t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages))
}
}

View File

@@ -7,14 +7,40 @@ import (
)
type toolStreamSieveState struct {
pending strings.Builder
capture strings.Builder
capturing bool
pending strings.Builder
capture strings.Builder
capturing bool
recentTextTail string
toolNameSent bool
toolName string
toolArgsStart int
toolArgsSent int
toolArgsString bool
toolArgsDone bool
}
type toolStreamEvent struct {
Content string
ToolCalls []util.ParsedToolCall
Content string
ToolCalls []util.ParsedToolCall
ToolCallDeltas []toolCallDelta
}
// toolCallDelta is one incremental piece of a streamed tool call.
// buildIncrementalToolDeltas emits a delta with only Name set, followed by
// deltas with only Arguments set (a fragment of the raw argument text).
type toolCallDelta struct {
// Index of the tool call within the tool_calls array.
Index int
// Name of the tool; empty on argument-only deltas.
Name string
// Arguments holds the next fragment of argument text; empty on name-only deltas.
Arguments string
}
// toolSieveCaptureLimit is the capture size (in bytes) above which a capture
// that is still not ready is flushed as plain content instead of being held.
const toolSieveCaptureLimit = 8 * 1024
// toolSieveContextTailLimit is the maximum number of bytes of recently emitted
// text kept in toolStreamSieveState.recentTextTail.
const toolSieveContextTailLimit = 256
// resetIncrementalToolState clears the progress tracked for incremental
// tool-call emission so the next capture starts from scratch.
func (s *toolStreamSieveState) resetIncrementalToolState() {
s.toolNameSent = false
s.toolName = ""
// -1 means "arguments value not located yet"; buildIncrementalToolDeltas
// tests toolArgsStart < 0 for this.
s.toolArgsStart = -1
s.toolArgsSent = -1
s.toolArgsString = false
s.toolArgsDone = false
}
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
@@ -32,13 +58,27 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
state.capture.WriteString(state.pending.String())
state.pending.Reset()
}
prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), toolNames)
if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
}
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
if !ready {
if state.capture.Len() > toolSieveCaptureLimit {
content := state.capture.String()
state.capture.Reset()
state.capturing = false
state.resetIncrementalToolState()
state.noteText(content)
events = append(events, toolStreamEvent{Content: content})
continue
}
break
}
state.capture.Reset()
state.capturing = false
state.resetIncrementalToolState()
if prefix != "" {
state.noteText(prefix)
events = append(events, toolStreamEvent{Content: prefix})
}
if len(calls) > 0 {
@@ -58,11 +98,13 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
if start >= 0 {
prefix := pending[:start]
if prefix != "" {
state.noteText(prefix)
events = append(events, toolStreamEvent{Content: prefix})
}
state.pending.Reset()
state.capture.WriteString(pending[start:])
state.capturing = true
state.resetIncrementalToolState()
continue
}
@@ -72,6 +114,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
}
state.pending.Reset()
state.pending.WriteString(hold)
state.noteText(safe)
events = append(events, toolStreamEvent{Content: safe})
}
@@ -84,25 +127,34 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
}
events := processToolSieveChunk(state, "", toolNames)
if state.capturing {
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state.capture.String(), toolNames)
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
if ready {
if consumedPrefix != "" {
state.noteText(consumedPrefix)
events = append(events, toolStreamEvent{Content: consumedPrefix})
}
if len(consumedCalls) > 0 {
events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
}
if consumedSuffix != "" {
state.noteText(consumedSuffix)
events = append(events, toolStreamEvent{Content: consumedSuffix})
}
} else {
// Incomplete captured tool JSON at stream end: flush the raw capture as plain content.
content := state.capture.String()
if content != "" {
state.noteText(content)
events = append(events, toolStreamEvent{Content: content})
}
}
state.capture.Reset()
state.capturing = false
state.resetIncrementalToolState()
}
if state.pending.Len() > 0 {
events = append(events, toolStreamEvent{Content: state.pending.String()})
content := state.pending.String()
state.noteText(content)
events = append(events, toolStreamEvent{Content: content})
state.pending.Reset()
}
return events
@@ -144,17 +196,26 @@ func findToolSegmentStart(s string) int {
return -1
}
lower := strings.ToLower(s)
keyIdx := strings.Index(lower, "tool_calls")
if keyIdx < 0 {
return -1
offset := 0
for {
keyRel := strings.Index(lower[offset:], "tool_calls")
if keyRel < 0 {
return -1
}
keyIdx := offset + keyRel
start := strings.LastIndex(s[:keyIdx], "{")
if start < 0 {
start = keyIdx
}
if !insideCodeFence(s[:start]) {
return start
}
offset = keyIdx + len("tool_calls")
}
if start := strings.LastIndex(s[:keyIdx], "{"); start >= 0 {
return start
}
return keyIdx
}
func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
captured := state.capture.String()
if captured == "" {
return "", nil, "", false
}
@@ -171,13 +232,25 @@ func consumeToolCapture(captured string, toolNames []string) (prefix string, cal
if !ok {
return "", nil, "", false
}
parsed := util.ParseToolCalls(obj, toolNames)
if len(parsed) == 0 {
// `tool_calls` key exists but strict JSON parse failed.
// Drop the captured object body to avoid leaking raw tool JSON.
return captured[:start], nil, captured[end:], true
prefixPart := captured[:start]
suffixPart := captured[end:]
if insideCodeFence(state.recentTextTail + prefixPart) {
return captured, nil, "", true
}
return captured[:start], parsed, captured[end:], true
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
if len(parsed) == 0 {
if state.toolNameSent {
return prefixPart, nil, suffixPart, true
}
return captured, nil, "", true
}
if state.toolNameSent {
if len(parsed) > 1 {
return prefixPart, parsed[1:], suffixPart, true
}
return prefixPart, nil, suffixPart, true
}
return prefixPart, parsed, suffixPart, true
}
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
@@ -221,3 +294,352 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) {
}
return "", 0, false
}
// buildIncrementalToolDeltas inspects the text captured so far and returns the
// tool-call deltas that can be emitted now for the first entry of the
// "tool_calls" array: one delta carrying the tool name (emitted once, and only
// after the start of the arguments value has been located), followed by a
// delta carrying any argument text that has not been sent yet.
//
// Progress is remembered in state (toolName, toolArgsStart, toolArgsSent,
// toolArgsString, toolNameSent, toolArgsDone). The result is nil or empty when
// nothing new can be emitted, including when the capture appears to sit inside
// a Markdown code fence. Every delta uses Index 0.
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
captured := state.capture.String()
if captured == "" {
return nil
}
// NOTE(review): keyIdx is found in the lower-cased copy but used to index
// captured; strings.ToLower can change byte lengths for some non-ASCII
// runes, so this presumably assumes ASCII-compatible text before the key.
lower := strings.ToLower(captured)
keyIdx := strings.Index(lower, "tool_calls")
if keyIdx < 0 {
return nil
}
// The tool JSON object is taken to start at the last '{' before the key.
start := strings.LastIndex(captured[:keyIdx], "{")
if start < 0 {
return nil
}
// Text inside an open ``` fence is treated as an example, not a real call.
if insideCodeFence(state.recentTextTail + captured[:start]) {
return nil
}
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
if !ok {
return nil
}
deltas := make([]toolCallDelta, 0, 2)
// Resolve the tool name once; until it is readable nothing is emitted.
if state.toolName == "" {
name, ok := extractToolCallName(captured, callStart)
if !ok || name == "" {
return nil
}
state.toolName = name
}
// Locate the arguments value once and remember where its payload begins.
if state.toolArgsStart < 0 {
argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
if ok {
state.toolArgsString = stringMode
if stringMode {
// Skip the opening quote so only the string's content is streamed.
state.toolArgsStart = argsStart + 1
} else {
state.toolArgsStart = argsStart
}
state.toolArgsSent = state.toolArgsStart
}
}
// The name delta is held back until the arguments value has been found.
if !state.toolNameSent {
if state.toolArgsStart < 0 {
return nil
}
state.toolNameSent = true
deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
}
if state.toolArgsStart < 0 || state.toolArgsDone {
return deltas
}
end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
if !ok {
return deltas
}
// Emit only the bytes between what was already sent and the current end.
if end > state.toolArgsSent {
deltas = append(deltas, toolCallDelta{
Index: 0,
Arguments: captured[state.toolArgsSent:end],
})
state.toolArgsSent = end
}
if complete {
state.toolArgsDone = true
}
return deltas
}
// findFirstToolCallObjectStart returns the index of the '{' that opens the
// first element of the tool_calls array whose key starts at keyIdx. It reports
// (-1, false) when the array has not started yet or its first non-space byte
// is not '{'.
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
	bracket, found := findToolCallsArrayStart(text, keyIdx)
	if !found {
		return -1, false
	}
	if pos := skipSpaces(text, bracket+1); pos < len(text) && text[pos] == '{' {
		return pos, true
	}
	return -1, false
}
// findToolCallsArrayStart returns the index of the '[' that follows the
// "tool_calls" key located at keyIdx: it walks to the next ':' after the key,
// skips whitespace and expects '['. It reports (-1, false) otherwise.
func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
	colon := keyIdx + len("tool_calls")
	for ; colon < len(text); colon++ {
		if text[colon] == ':' {
			break
		}
	}
	if colon >= len(text) {
		return -1, false
	}
	open := skipSpaces(text, colon+1)
	if open < len(text) && text[open] == '[' {
		return open, true
	}
	return -1, false
}
// extractToolCallName reads the tool name of the call object that starts at
// callStart. It first looks for a string-valued "name" field directly on the
// call object and, failing that, inside its "function" object. It reports
// ("", false) when no complete string value is available.
func extractToolCallName(text string, callStart int) (string, bool) {
	nameKey := []string{"name"}
	startsString := func(pos int, found bool) bool {
		return found && pos < len(text) && text[pos] == '"'
	}

	pos, found := findObjectFieldValueStart(text, callStart, nameKey)
	if !startsString(pos, found) {
		fnStart, hasFn := findFunctionObjectStart(text, callStart)
		if !hasFn {
			return "", false
		}
		pos, found = findObjectFieldValueStart(text, fnStart, nameKey)
		if !startsString(pos, found) {
			return "", false
		}
	}

	if name, _, parsed := parseJSONStringLiteral(text, pos); parsed {
		return name, true
	}
	return "", false
}
// findToolCallArgsStart locates the value of the arguments field ("input",
// "arguments", "args", "parameters" or "params") of the call object starting
// at callStart, looking on the call object first and then inside its
// "function" object.
//
// It returns the index of the value's first byte, whether the value is a JSON
// string (true) or an object/array (false), and whether a usable value was
// found at all.
func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
	argKeys := []string{"input", "arguments", "args", "parameters", "params"}

	pos, found := findObjectFieldValueStart(text, callStart, argKeys)
	if !found {
		fnStart, hasFn := findFunctionObjectStart(text, callStart)
		if !hasFn {
			return -1, false, false
		}
		if pos, found = findObjectFieldValueStart(text, fnStart, argKeys); !found {
			return -1, false, false
		}
	}
	if pos >= len(text) {
		return -1, false, false
	}

	switch text[pos] {
	case '{', '[':
		return pos, false, true
	case '"':
		return pos, true, true
	default:
		return -1, false, false
	}
}
// scanToolCallArgsProgress reports how far the arguments value beginning at
// start extends within text.
//
// In string mode, start is the first byte after the opening quote; the scan
// stops at the first unescaped '"' and returns its index. Otherwise start must
// point at '{' or '[' and the scan stops just past the matching closer,
// ignoring brackets inside quoted strings.
//
// Returns (end, complete, ok): end is the exclusive end of the argument text
// available so far, complete tells whether the value is closed, and ok is
// false when start is out of range or does not point at a valid opener.
func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
	if start < 0 || start > len(text) {
		return 0, false, false
	}

	if stringMode {
		skipNext := false
		for pos := start; pos < len(text); pos++ {
			switch {
			case skipNext:
				skipNext = false
			case text[pos] == '\\':
				skipNext = true
			case text[pos] == '"':
				return pos, true, true
			}
		}
		return len(text), false, true
	}

	if start >= len(text) {
		return start, false, false
	}
	if opener := text[start]; opener != '{' && opener != '[' {
		return 0, false, false
	}

	var (
		nesting  int
		inQuote  byte
		skipNext bool
	)
	for pos := start; pos < len(text); pos++ {
		c := text[pos]
		if inQuote != 0 {
			// Inside a quoted run only escapes and the closing quote matter.
			switch {
			case skipNext:
				skipNext = false
			case c == '\\':
				skipNext = true
			case c == inQuote:
				inQuote = 0
			}
			continue
		}
		switch c {
		case '"', '\'':
			inQuote = c
		case '{', '[':
			nesting++
		case '}', ']':
			nesting--
			if nesting == 0 {
				return pos + 1, true, true
			}
		}
	}
	return len(text), false, true
}
// findObjectFieldValueStart scans the JSON object that opens at objStart and
// returns the index of the first non-space byte of the value belonging to the
// first top-level key that is listed in keys.
//
// It reports (0, false) when objStart does not point at '{', when the object
// closes (or the text ends) before a listed key with a value is seen, or when
// a top-level string is still incomplete.
func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
return 0, false
}
depth := 0
quote := byte(0)
escaped := false
for i := objStart; i < len(text); i++ {
ch := text[i]
// Inside a nested quoted string: skip until the matching closing quote.
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
// Strings directly inside the target object are either keys or values.
if depth == 1 {
// parseJSONStringLiteral only accepts '"', so a single quote here
// (or an unterminated string) ends the search.
key, end, ok := parseJSONStringLiteral(text, i)
if !ok {
return 0, false
}
j := skipSpaces(text, end)
// No ':' after the string: it was a value, resume right after it.
if j >= len(text) || text[j] != ':' {
i = end - 1
continue
}
j = skipSpaces(text, j+1)
if j >= len(text) {
return 0, false
}
if containsKey(keys, key) {
return j, true
}
// Not a wanted key: resume scanning at its value (loop adds 1).
i = j - 1
continue
}
quote = ch
continue
}
// Only braces change depth; '[' and ']' are not tracked here.
if ch == '{' {
depth++
continue
}
if ch == '}' {
depth--
if depth == 0 {
// The target object is closed; leave the loop.
break
}
}
}
return 0, false
}
// findFunctionObjectStart returns the index of the '{' that opens the value of
// the "function" field of the call object starting at callStart, or
// (-1, false) when that field is missing or its value is not an object.
func findFunctionObjectStart(text string, callStart int) (int, bool) {
	pos, found := findObjectFieldValueStart(text, callStart, []string{"function"})
	if found && pos < len(text) && text[pos] == '{' {
		return pos, true
	}
	return -1, false
}
// parseJSONStringLiteral decodes the double-quoted JSON string that starts at
// text[start].
//
// It returns the decoded value, the index just past the closing quote, and
// true on success. It reports ("", 0, false) when start is out of range, does
// not point at '"', or the string is not terminated yet.
//
// JSON escapes are decoded (\b \f \n \r \t \" \\ \/ and \uXXXX, including
// surrogate pairs). Unknown escapes and malformed \u sequences stay lenient:
// the character following the backslash is kept literally.
func parseJSONStringLiteral(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '"' {
		return "", 0, false
	}

	// hex4 parses the four hex digits beginning at pos.
	hex4 := func(pos int) (rune, bool) {
		if pos < 0 || pos+4 > len(text) {
			return 0, false
		}
		var r rune
		for k := pos; k < pos+4; k++ {
			c := text[k]
			switch {
			case c >= '0' && c <= '9':
				r = r<<4 | rune(c-'0')
			case c >= 'a' && c <= 'f':
				r = r<<4 | rune(c-'a'+10)
			case c >= 'A' && c <= 'F':
				r = r<<4 | rune(c-'A'+10)
			default:
				return 0, false
			}
		}
		return r, true
	}

	var b strings.Builder
	for i := start + 1; i < len(text); i++ {
		ch := text[i]
		if ch == '"' {
			return b.String(), i + 1, true
		}
		if ch != '\\' {
			b.WriteByte(ch)
			continue
		}
		// Backslash: the string is incomplete if nothing follows it.
		if i+1 >= len(text) {
			return "", 0, false
		}
		i++
		switch esc := text[i]; esc {
		case 'b':
			b.WriteByte('\b')
		case 'f':
			b.WriteByte('\f')
		case 'n':
			b.WriteByte('\n')
		case 'r':
			b.WriteByte('\r')
		case 't':
			b.WriteByte('\t')
		case 'u':
			r, valid := hex4(i + 1)
			if !valid {
				// Malformed \u: keep the 'u' and continue with what follows.
				b.WriteByte('u')
			} else {
				i += 4
				// A high surrogate followed by "\uDCxx" forms one code point.
				if r >= 0xD800 && r < 0xDC00 && i+2 < len(text) && text[i+1] == '\\' && text[i+2] == 'u' {
					if lo, loOK := hex4(i + 3); loOK && lo >= 0xDC00 && lo < 0xE000 {
						r = 0x10000 + (r-0xD800)<<10 + (lo - 0xDC00)
						i += 6
					}
				}
				// WriteRune substitutes U+FFFD for a lone surrogate.
				b.WriteRune(r)
			}
		default:
			// Covers \" \\ \/ and any unknown escape.
			b.WriteByte(esc)
		}
	}
	return "", 0, false
}
// containsKey reports whether value equals one of the entries in keys.
func containsKey(keys []string, value string) bool {
	found := false
	for i := 0; i < len(keys) && !found; i++ {
		found = keys[i] == value
	}
	return found
}
// skipSpaces returns the index of the first byte at or after i that is not a
// space, tab, newline or carriage return; it returns i unchanged when i is
// already at or past the end of text.
func skipSpaces(text string, i int) int {
	for ; i < len(text); i++ {
		if c := text[i]; c != ' ' && c != '\t' && c != '\n' && c != '\r' {
			break
		}
	}
	return i
}
// noteText records emitted text in recentTextTail (bounded by
// toolSieveContextTailLimit). Whitespace-only content is ignored.
func (s *toolStreamSieveState) noteText(content string) {
	if strings.TrimSpace(content) != "" {
		s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
	}
}
// appendTail concatenates prev and next and keeps at most the last max bytes
// of the result. A non-positive max yields the empty string.
func appendTail(prev, next string, max int) string {
	if max <= 0 {
		return ""
	}
	joined := prev + next
	if overflow := len(joined) - max; overflow > 0 {
		return joined[overflow:]
	}
	return joined
}
// looksLikeToolExampleContext reports whether text ends inside an open
// Markdown code fence; it simply delegates to insideCodeFence.
func looksLikeToolExampleContext(text string) bool {
return insideCodeFence(text)
}
// insideCodeFence reports whether text contains an odd number of
// non-overlapping "```" markers, i.e. whether a code fence is still open at
// the end of text. Empty text is never inside a fence.
func insideCodeFence(text string) bool {
	const fence = "```"
	open := false
	for rest := text; ; {
		at := strings.Index(rest, fence)
		if at < 0 {
			return open
		}
		open = !open
		rest = rest[at+len(fence):]
	}
}

View File

@@ -56,24 +56,16 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
return
}
model, _ := req["model"].(string)
messagesRaw, _ := req["messages"].([]any)
if model == "" || len(messagesRaw) == 0 {
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
}
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model)
if !ok {
writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model))
if !stdReq.Stream {
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
return
}
messages := normalizeMessages(messagesRaw)
if tools, ok := req["tools"].([]any); ok && len(tools) > 0 {
messages, _ = injectToolPrompt(messages, tools)
}
finalPrompt := util.MessagesPrepare(messages)
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
if a.UseConfigToken {
@@ -93,14 +85,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
return
}
payload := map[string]any{
"chat_session_id": sessionID,
"parent_message_id": nil,
"prompt": finalPrompt,
"ref_file_ids": []any{},
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
}
payload := stdReq.CompletionPayload(sessionID)
leaseID := h.holdStreamLease(a)
if leaseID == "" {
writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease")
@@ -108,15 +93,18 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
}
leased = true
writeJSON(w, http.StatusOK, map[string]any{
"session_id": sessionID,
"lease_id": leaseID,
"model": model,
"final_prompt": finalPrompt,
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
"deepseek_token": a.DeepSeekToken,
"pow_header": powHeader,
"payload": payload,
"session_id": sessionID,
"lease_id": leaseID,
"model": stdReq.ResponseModel,
"final_prompt": stdReq.FinalPrompt,
"thinking_enabled": stdReq.Thinking,
"search_enabled": stdReq.Search,
"tool_names": stdReq.ToolNames,
"toolcall_feature_match": h.toolcallFeatureMatchEnabled(),
"toolcall_early_emit_high": h.toolcallEarlyEmitHighConfidence(),
"deepseek_token": a.DeepSeekToken,
"pow_header": powHeader,
"payload": payload,
})
}

View File

@@ -56,7 +56,14 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
preview = token
}
}
items = append(items, map[string]any{"email": acc.Email, "mobile": acc.Mobile, "has_password": acc.Password != "", "has_token": token != "", "token_preview": preview})
items = append(items, map[string]any{
"identifier": acc.Identifier(),
"email": acc.Email,
"mobile": acc.Mobile,
"has_password": acc.Password != "",
"has_token": token != "",
"token_preview": preview,
})
}
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
}
@@ -94,7 +101,7 @@ func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
err := h.Store.Update(func(c *config.Config) error {
idx := -1
for i, a := range c.Accounts {
if a.Email == identifier || a.Mobile == identifier {
if accountMatchesIdentifier(a, identifier) {
idx = i
break
}
@@ -122,10 +129,10 @@ func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
_ = json.NewDecoder(r.Body).Decode(&req)
identifier, _ := req["identifier"].(string)
if strings.TrimSpace(identifier) == "" {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识email mobile"})
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile"})
return
}
acc, ok := h.Store.FindAccount(identifier)
acc, ok := findAccountByIdentifier(h.Store, identifier)
if !ok {
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
return

View File

@@ -0,0 +1,138 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/go-chi/chi/v5"
"ds2api/internal/account"
"ds2api/internal/config"
)
// newAdminTestHandler builds a Handler for tests: it sets DS2API_CONFIG_JSON
// to raw, blanks CONFIG_JSON, loads a store via config.LoadStore and wires an
// account pool created from that store.
func newAdminTestHandler(t *testing.T, raw string) *Handler {
t.Helper()
t.Setenv("DS2API_CONFIG_JSON", raw)
// Presumably a legacy/alternate variable; cleared so it cannot interfere.
t.Setenv("CONFIG_JSON", "")
store := config.LoadStore()
return &Handler{
Store: store,
Pool: account.NewPool(store),
}
}
// TestListAccountsIncludesTokenOnlyIdentifier checks that listAccounts returns
// a non-empty "identifier" starting with "token:" for an account that has only
// a token (no email or mobile).
func TestListAccountsIncludesTokenOnlyIdentifier(t *testing.T) {
h := newAdminTestHandler(t, `{
"accounts":[{"token":"token-only-account"}]
}`)
req := httptest.NewRequest(http.MethodGet, "/admin/accounts?page=1&page_size=10", nil)
rec := httptest.NewRecorder()
h.listAccounts(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
}
var payload map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("decode response failed: %v", err)
}
items, _ := payload["items"].([]any)
if len(items) != 1 {
t.Fatalf("expected 1 item, got %d", len(items))
}
first, _ := items[0].(map[string]any)
identifier, _ := first["identifier"].(string)
if identifier == "" {
t.Fatalf("expected non-empty identifier: %#v", first)
}
if !strings.HasPrefix(identifier, "token:") {
t.Fatalf("expected token synthetic identifier, got %q", identifier)
}
}
// TestDeleteAccountSupportsTokenOnlyIdentifier checks that a token-only
// account can be deleted through the DELETE route by its synthetic identifier.
func TestDeleteAccountSupportsTokenOnlyIdentifier(t *testing.T) {
h := newAdminTestHandler(t, `{
"accounts":[{"token":"token-only-account"}]
}`)
accounts := h.Store.Accounts()
if len(accounts) != 1 {
t.Fatalf("expected 1 account, got %d", len(accounts))
}
id := accounts[0].Identifier()
if id == "" {
t.Fatal("expected token-only synthetic identifier")
}
r := chi.NewRouter()
r.Delete("/admin/accounts/{identifier}", h.deleteAccount)
// The identifier is path-escaped before being placed in the URL.
req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape(id), nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
}
if got := len(h.Store.Accounts()); got != 0 {
t.Fatalf("expected account removed, remaining=%d", got)
}
}
// TestDeleteAccountSupportsMobileAlias checks that an account having both an
// email and a mobile number can be deleted by its mobile number.
func TestDeleteAccountSupportsMobileAlias(t *testing.T) {
h := newAdminTestHandler(t, `{
"accounts":[{"email":"u@example.com","mobile":"13800138000","password":"pwd"}]
}`)
r := chi.NewRouter()
r.Delete("/admin/accounts/{identifier}", h.deleteAccount)
req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/13800138000", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
}
if got := len(h.Store.Accounts()); got != 0 {
t.Fatalf("expected account removed, remaining=%d", got)
}
}
// TestFindAccountByIdentifierSupportsMobileAndTokenOnly checks that
// findAccountByIdentifier resolves an account by its mobile number and a
// token-only account by the value of its Identifier method.
func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
h := newAdminTestHandler(t, `{
"accounts":[
{"email":"u@example.com","mobile":"13800138000","password":"pwd"},
{"token":"token-only-account"}
]
}`)
accByMobile, ok := findAccountByIdentifier(h.Store, "13800138000")
if !ok {
t.Fatal("expected find by mobile")
}
if accByMobile.Email != "u@example.com" {
t.Fatalf("unexpected account by mobile: %#v", accByMobile)
}
// Pick the account that has neither email nor mobile and take its identifier.
tokenOnlyID := ""
for _, acc := range h.Store.Accounts() {
if strings.TrimSpace(acc.Email) == "" && strings.TrimSpace(acc.Mobile) == "" {
tokenOnlyID = acc.Identifier()
break
}
}
if tokenOnlyID == "" {
t.Fatal("expected token-only account identifier")
}
accByTokenOnly, ok := findAccountByIdentifier(h.Store, tokenOnlyID)
if !ok {
t.Fatalf("expected find by token-only id=%q", tokenOnlyID)
}
if accByTokenOnly.Token != "token-only-account" {
t.Fatalf("unexpected token-only account: %#v", accByTokenOnly)
}
}

View File

@@ -37,6 +37,7 @@ func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
}
}
accounts = append(accounts, map[string]any{
"identifier": acc.Identifier(),
"email": acc.Email,
"mobile": acc.Mobile,
"has_password": strings.TrimSpace(acc.Password) != "",

View File

@@ -81,3 +81,34 @@ func statusOr(v int, d int) int {
}
return v
}
// accountMatchesIdentifier reports whether the trimmed identifier equals the
// account's trimmed email, its trimmed mobile number, or the value returned by
// acc.Identifier(). A blank identifier never matches.
func accountMatchesIdentifier(acc config.Account, identifier string) bool {
	want := strings.TrimSpace(identifier)
	switch {
	case want == "":
		return false
	case strings.TrimSpace(acc.Email) == want:
		return true
	case strings.TrimSpace(acc.Mobile) == want:
		return true
	default:
		return acc.Identifier() == want
	}
}
// findAccountByIdentifier looks an account up by identifier. It first asks
// store.FindAccount with the trimmed identifier and, when that finds nothing,
// scans the accounts of store.Snapshot() using accountMatchesIdentifier.
// A blank identifier yields (config.Account{}, false).
func findAccountByIdentifier(store *config.Store, identifier string) (config.Account, bool) {
	want := strings.TrimSpace(identifier)
	if want == "" {
		return config.Account{}, false
	}
	if direct, hit := store.FindAccount(want); hit {
		return direct, true
	}
	for _, candidate := range store.Snapshot().Accounts {
		if !accountMatchesIdentifier(candidate, want) {
			continue
		}
		return candidate, true
	}
	return config.Account{}, false
}

View File

@@ -0,0 +1,240 @@
package admin
import (
"net/http"
"net/http/httptest"
"testing"
"ds2api/internal/config"
)
// ─── reverseAccounts ─────────────────────────────────────────────────
// TestReverseAccountsEmpty checks that reversing an empty slice leaves it empty.
func TestReverseAccountsEmpty(t *testing.T) {
a := []config.Account{}
reverseAccounts(a)
if len(a) != 0 {
t.Fatal("expected empty")
}
}
// TestReverseAccountsTwoElements checks that a two-element slice is swapped in place.
func TestReverseAccountsTwoElements(t *testing.T) {
a := []config.Account{
{Email: "a@test.com"},
{Email: "b@test.com"},
}
reverseAccounts(a)
if a[0].Email != "b@test.com" || a[1].Email != "a@test.com" {
t.Fatalf("unexpected order after reverse: %v", a)
}
}
// TestReverseAccountsThreeElements checks in-place reversal of an odd-length
// slice, where the middle element stays where it is.
func TestReverseAccountsThreeElements(t *testing.T) {
a := []config.Account{
{Email: "1@test.com"},
{Email: "2@test.com"},
{Email: "3@test.com"},
}
reverseAccounts(a)
if a[0].Email != "3@test.com" || a[1].Email != "2@test.com" || a[2].Email != "1@test.com" {
t.Fatalf("unexpected order: %v", a)
}
}
// ─── intFromQuery edge cases ─────────────────────────────────────────
// TestIntFromQueryPresent checks that a numeric query value is returned
// instead of the default.
func TestIntFromQueryPresent(t *testing.T) {
req := httptest.NewRequest("GET", "/?limit=5", nil)
if got := intFromQuery(req, "limit", 10); got != 5 {
t.Fatalf("expected 5, got %d", got)
}
}
// TestIntFromQueryMissing checks that the default is returned when the query
// parameter is absent.
func TestIntFromQueryMissing(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
if got := intFromQuery(req, "limit", 10); got != 10 {
t.Fatalf("expected default 10, got %d", got)
}
}
// TestIntFromQueryInvalid checks that a non-numeric value falls back to the default.
func TestIntFromQueryInvalid(t *testing.T) {
req := httptest.NewRequest("GET", "/?limit=abc", nil)
if got := intFromQuery(req, "limit", 10); got != 10 {
t.Fatalf("expected default 10 for invalid, got %d", got)
}
}
// TestIntFromQueryNegative checks that a negative value is returned as is
// (intFromQuery does not clamp).
func TestIntFromQueryNegative(t *testing.T) {
req := httptest.NewRequest("GET", "/?limit=-3", nil)
if got := intFromQuery(req, "limit", 10); got != -3 {
t.Fatalf("expected -3, got %d", got)
}
}
// TestIntFromQueryZero checks that an explicit zero is returned rather than
// being replaced by the default.
func TestIntFromQueryZero(t *testing.T) {
req := httptest.NewRequest("GET", "/?limit=0", nil)
if got := intFromQuery(req, "limit", 10); got != 0 {
t.Fatalf("expected 0, got %d", got)
}
}
// ─── nilIfEmpty ──────────────────────────────────────────────────────
// TestNilIfEmptyEmpty checks that the empty string maps to nil.
func TestNilIfEmptyEmpty(t *testing.T) {
if nilIfEmpty("") != nil {
t.Fatal("expected nil for empty string")
}
}
// TestNilIfEmptyNonEmpty checks that a non-empty string is returned unchanged.
func TestNilIfEmptyNonEmpty(t *testing.T) {
if nilIfEmpty("hello") != "hello" {
t.Fatal("expected 'hello'")
}
}
// ─── nilIfZero ───────────────────────────────────────────────────────
// TestNilIfZeroZero checks that zero maps to nil.
func TestNilIfZeroZero(t *testing.T) {
if nilIfZero(0) != nil {
t.Fatal("expected nil for zero")
}
}
// TestNilIfZeroNonZero checks that a positive value comes back as int64.
func TestNilIfZeroNonZero(t *testing.T) {
if nilIfZero(42) != int64(42) {
t.Fatal("expected 42")
}
}
// TestNilIfZeroNegative checks that a negative value is kept (only zero maps to nil).
func TestNilIfZeroNegative(t *testing.T) {
if nilIfZero(-1) != int64(-1) {
t.Fatal("expected -1")
}
}
// ─── toStringSlice ───────────────────────────────────────────────────
// TestToStringSliceFromAnySlice checks conversion of a []any of strings,
// preserving order.
func TestToStringSliceFromAnySlice(t *testing.T) {
input := []any{"a", "b", "c"}
got, ok := toStringSlice(input)
if !ok || len(got) != 3 {
t.Fatalf("expected 3 strings, got %#v ok=%v", got, ok)
}
if got[0] != "a" || got[1] != "b" || got[2] != "c" {
t.Fatalf("unexpected values: %#v", got)
}
}
// TestToStringSliceFromMixed checks that non-string elements are rendered as
// strings ("42", "true").
func TestToStringSliceFromMixed(t *testing.T) {
input := []any{"hello", 42, true}
got, ok := toStringSlice(input)
if !ok {
t.Fatal("expected ok for mixed types")
}
if got[0] != "hello" || got[1] != "42" || got[2] != "true" {
t.Fatalf("unexpected values: %#v", got)
}
}
// TestToStringSliceFromNonSlice checks that a plain string input is rejected.
func TestToStringSliceFromNonSlice(t *testing.T) {
_, ok := toStringSlice("not a slice")
if ok {
t.Fatal("expected not ok for string input")
}
}
// TestToStringSliceFromNil checks that nil input is rejected.
func TestToStringSliceFromNil(t *testing.T) {
_, ok := toStringSlice(nil)
if ok {
t.Fatal("expected not ok for nil input")
}
}
// TestToStringSliceEmpty checks that an empty []any is accepted and yields an
// empty result.
func TestToStringSliceEmpty(t *testing.T) {
got, ok := toStringSlice([]any{})
if !ok {
t.Fatal("expected ok for empty slice")
}
if len(got) != 0 {
t.Fatalf("expected empty result, got %#v", got)
}
}
// TestToStringSliceTrimsWhitespace checks that surrounding whitespace is
// trimmed from each element.
func TestToStringSliceTrimsWhitespace(t *testing.T) {
got, ok := toStringSlice([]any{" hello ", " world "})
if !ok {
t.Fatal("expected ok")
}
if got[0] != "hello" || got[1] != "world" {
t.Fatalf("expected trimmed values, got %#v", got)
}
}
// ─── toAccount edge cases ────────────────────────────────────────────
// TestToAccountAllFields checks that toAccount maps the "email", "mobile",
// "password" and "token" keys onto the corresponding Account fields.
func TestToAccountAllFields(t *testing.T) {
acc := toAccount(map[string]any{
"email": "user@test.com",
"mobile": "13800138000",
"password": "secret",
"token": "tok123",
})
if acc.Email != "user@test.com" {
t.Fatalf("unexpected email: %q", acc.Email)
}
if acc.Mobile != "13800138000" {
t.Fatalf("unexpected mobile: %q", acc.Mobile)
}
if acc.Password != "secret" {
t.Fatalf("unexpected password: %q", acc.Password)
}
if acc.Token != "tok123" {
t.Fatalf("unexpected token: %q", acc.Token)
}
}
// TestToAccountNumericValues checks that a numeric field value is converted to
// its string form.
func TestToAccountNumericValues(t *testing.T) {
acc := toAccount(map[string]any{
"email": 12345,
})
if acc.Email != "12345" {
t.Fatalf("expected numeric converted to string, got %q", acc.Email)
}
}
// ─── fieldString edge cases ──────────────────────────────────────────
// TestFieldStringNonString checks that an int value is rendered as "42".
func TestFieldStringNonString(t *testing.T) {
got := fieldString(map[string]any{"key": 42}, "key")
if got != "42" {
t.Fatalf("expected '42' for int, got %q", got)
}
}
// TestFieldStringBool checks that a bool value is rendered as "true".
func TestFieldStringBool(t *testing.T) {
got := fieldString(map[string]any{"key": true}, "key")
if got != "true" {
t.Fatalf("expected 'true', got %q", got)
}
}
// TestFieldStringWhitespace checks that surrounding whitespace is trimmed from
// string values.
func TestFieldStringWhitespace(t *testing.T) {
got := fieldString(map[string]any{"key": " hello "}, "key")
if got != "hello" {
t.Fatalf("expected trimmed 'hello', got %q", got)
}
}
// ─── statusOr ────────────────────────────────────────────────────────
// TestStatusOrZeroReturnsDefault checks that a zero status falls back to the default.
func TestStatusOrZeroReturnsDefault(t *testing.T) {
if got := statusOr(0, http.StatusOK); got != http.StatusOK {
t.Fatalf("expected %d, got %d", http.StatusOK, got)
}
}
// TestStatusOrNonZeroReturnsValue checks that a non-zero status wins over the default.
func TestStatusOrNonZeroReturnsValue(t *testing.T) {
if got := statusOr(http.StatusBadRequest, http.StatusOK); got != http.StatusBadRequest {
t.Fatalf("expected %d, got %d", http.StatusBadRequest, got)
}
}

View File

@@ -0,0 +1,375 @@
package auth
import (
"context"
"errors"
"net/http"
"testing"
"ds2api/internal/account"
"ds2api/internal/config"
)
// ─── extractCallerToken edge cases ───────────────────────────────────
// TestExtractCallerTokenBearerPrefix checks that the token after "Bearer " in
// the Authorization header is returned.
func TestExtractCallerTokenBearerPrefix(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Bearer my-token")
if got := extractCallerToken(req); got != "my-token" {
t.Fatalf("expected my-token, got %q", got)
}
}
// TestExtractCallerTokenBearerCaseInsensitive checks that the "Bearer" scheme
// is matched case-insensitively while the token keeps its original case.
func TestExtractCallerTokenBearerCaseInsensitive(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "BEARER My-Token")
if got := extractCallerToken(req); got != "My-Token" {
t.Fatalf("expected My-Token, got %q", got)
}
}
// TestExtractCallerTokenBearerEmpty checks that "Bearer " with no token yields
// an empty result.
func TestExtractCallerTokenBearerEmpty(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Bearer ")
if got := extractCallerToken(req); got != "" {
t.Fatalf("expected empty for 'Bearer ', got %q", got)
}
}
// TestExtractCallerTokenXAPIKey checks that the x-api-key header is used when
// no Authorization header is present.
func TestExtractCallerTokenXAPIKey(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("x-api-key", "x-api-key-token")
if got := extractCallerToken(req); got != "x-api-key-token" {
t.Fatalf("expected x-api-key-token, got %q", got)
}
}
// TestExtractCallerTokenBearerPreferredOverXAPIKey checks that the bearer
// token wins when both Authorization and x-api-key are set.
func TestExtractCallerTokenBearerPreferredOverXAPIKey(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Bearer bearer-token")
req.Header.Set("x-api-key", "x-api-key-token")
if got := extractCallerToken(req); got != "bearer-token" {
t.Fatalf("expected bearer-token, got %q", got)
}
}
// TestExtractCallerTokenMissingHeaders checks that a request without auth
// headers yields an empty token.
func TestExtractCallerTokenMissingHeaders(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
if got := extractCallerToken(req); got != "" {
t.Fatalf("expected empty for missing headers, got %q", got)
}
}
// TestExtractCallerTokenNonBearerAuth checks that a non-Bearer Authorization
// scheme (Basic) yields an empty token.
func TestExtractCallerTokenNonBearerAuth(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Basic abc123")
if got := extractCallerToken(req); got != "" {
t.Fatalf("expected empty for Basic auth, got %q", got)
}
}
// ─── Context helpers ─────────────────────────────────────────────────
// TestWithAuthAndFromContext checks that a RequestAuth stored with WithAuth
// can be read back with FromContext.
func TestWithAuthAndFromContext(t *testing.T) {
	a := &RequestAuth{DeepSeekToken: "test-token"}
	ctx := WithAuth(context.Background(), a)

	got, ok := FromContext(ctx)
	// Check ok on its own first: when the lookup fails, touching got's fields
	// while building the failure message could panic and hide the real cause.
	if !ok {
		t.Fatal("expected auth to be present in context")
	}
	if got.DeepSeekToken != "test-token" {
		t.Fatalf("expected token from context, got ok=%v token=%q", ok, got.DeepSeekToken)
	}
}
// TestFromContextMissing checks that FromContext reports "not found" for a
// context that never had auth attached.
func TestFromContextMissing(t *testing.T) {
	if _, found := FromContext(context.Background()); found {
		t.Fatal("expected not ok from empty context")
	}
}
// ─── RefreshToken edge cases ─────────────────────────────────────────
// TestRefreshTokenNotConfigToken checks that RefreshToken returns false for
// auth that does not use a config-managed token.
func TestRefreshTokenNotConfigToken(t *testing.T) {
	resolver := newTestResolver(t)
	auth := &RequestAuth{UseConfigToken: false, resolver: resolver}

	refreshed := resolver.RefreshToken(context.Background(), auth)
	if refreshed {
		t.Fatal("expected false for non-config token")
	}
}
// TestRefreshTokenEmptyAccountID checks that RefreshToken returns false when
// the config-token auth carries no account identifier.
func TestRefreshTokenEmptyAccountID(t *testing.T) {
	resolver := newTestResolver(t)
	auth := &RequestAuth{UseConfigToken: true, AccountID: "", resolver: resolver}

	refreshed := resolver.RefreshToken(context.Background(), auth)
	if refreshed {
		t.Fatal("expected false for empty account ID")
	}
}
// TestRefreshTokenSuccess acquires an account through Determine and checks
// that RefreshToken succeeds and leaves "fresh-token" on the auth.
// The expected value presumably comes from the login stub that
// newTestResolver installs (defined outside this file).
func TestRefreshTokenSuccess(t *testing.T) {
	resolver := newTestResolver(t)

	// Lease an account first so there is something to refresh.
	request, _ := http.NewRequest("POST", "/", nil)
	request.Header.Set("Authorization", "Bearer managed-key")
	auth, err := resolver.Determine(request)
	if err != nil {
		t.Fatalf("determine failed: %v", err)
	}
	defer resolver.Release(auth)

	refreshed := resolver.RefreshToken(context.Background(), auth)
	if !refreshed {
		t.Fatal("expected refresh to succeed")
	}
	if token := auth.DeepSeekToken; token != "fresh-token" {
		t.Fatalf("expected fresh-token after refresh, got %q", token)
	}
}
// ─── MarkTokenInvalid edge cases ─────────────────────────────────────
// TestMarkTokenInvalidNotConfigToken is a smoke test: MarkTokenInvalid must
// not panic when handed auth that does not use a config-managed token.
// It deliberately asserts nothing about DeepSeekToken afterwards, because the
// original comments here disagreed on whether the token is cleared.
func TestMarkTokenInvalidNotConfigToken(t *testing.T) {
	r := newTestResolver(t)
	a := &RequestAuth{UseConfigToken: false, DeepSeekToken: "direct", resolver: r}
	r.MarkTokenInvalid(a)
}
// TestMarkTokenInvalidEmptyAccountID is a smoke test: MarkTokenInvalid must
// not panic for config-token auth that has no account identifier.
func TestMarkTokenInvalidEmptyAccountID(t *testing.T) {
	resolver := newTestResolver(t)
	auth := &RequestAuth{
		UseConfigToken: true,
		AccountID:      "",
		DeepSeekToken:  "tok",
		resolver:       resolver,
	}
	resolver.MarkTokenInvalid(auth)
}
// TestMarkTokenInvalidClearsToken checks that invalidating a leased account
// blanks both the auth's DeepSeekToken and the embedded Account.Token.
func TestMarkTokenInvalidClearsToken(t *testing.T) {
	resolver := newTestResolver(t)

	request, _ := http.NewRequest("POST", "/", nil)
	request.Header.Set("Authorization", "Bearer managed-key")
	auth, err := resolver.Determine(request)
	if err != nil {
		t.Fatalf("determine failed: %v", err)
	}
	defer resolver.Release(auth)

	resolver.MarkTokenInvalid(auth)

	if token := auth.DeepSeekToken; token != "" {
		t.Fatalf("expected empty token after invalidation, got %q", token)
	}
	if token := auth.Account.Token; token != "" {
		t.Fatalf("expected empty account token after invalidation, got %q", token)
	}
}
// ─── SwitchAccount edge cases ────────────────────────────────────────
// TestSwitchAccountNotConfigToken checks that SwitchAccount returns false for
// auth that does not use a config-managed token.
func TestSwitchAccountNotConfigToken(t *testing.T) {
	resolver := newTestResolver(t)
	auth := &RequestAuth{UseConfigToken: false, resolver: resolver}

	switched := resolver.SwitchAccount(context.Background(), auth)
	if switched {
		t.Fatal("expected false for non-config token")
	}
}
// TestSwitchAccountNilTriedAccounts checks that SwitchAccount copes with a nil
// TriedAccounts map and moves the auth onto a different account.
func TestSwitchAccountNilTriedAccounts(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[
{"email":"acc1@test.com","token":"t1"},
{"email":"acc2@test.com","token":"t2"}
]
}`)
	store := config.LoadStore()
	pool := account.NewPool(store)
	r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
		return "new-token", nil
	})

	// First acquire.
	req, _ := http.NewRequest("POST", "/", nil)
	req.Header.Set("Authorization", "Bearer managed-key")
	a, err := r.Determine(req)
	if err != nil {
		t.Fatalf("determine failed: %v", err)
	}
	// Deferred so the lease is returned even when an assertion below fails;
	// this matches the other tests in this file. The same *RequestAuth is
	// updated in place by SwitchAccount, so this releases whichever account
	// it holds at exit.
	defer r.Release(a)

	oldID := a.AccountID
	a.TriedAccounts = nil // exercise nil-map initialization in SwitchAccount
	if !r.SwitchAccount(context.Background(), a) {
		t.Fatal("expected switch to succeed")
	}
	if a.AccountID == oldID {
		t.Fatalf("expected different account after switch")
	}
}
// ─── Release edge cases ─────────────────────────────────────────────
// TestReleaseNilAuth is a smoke test: Release(nil) must not panic.
func TestReleaseNilAuth(t *testing.T) {
	resolver := newTestResolver(t)
	resolver.Release(nil)
}
// TestReleaseNonConfigToken is a smoke test: releasing auth that does not use
// a config-managed token must not panic.
func TestReleaseNonConfigToken(t *testing.T) {
	resolver := newTestResolver(t)
	auth := &RequestAuth{UseConfigToken: false}
	resolver.Release(auth)
}
// TestReleaseEmptyAccountID is a smoke test: releasing config-token auth with
// an empty account identifier must not panic.
func TestReleaseEmptyAccountID(t *testing.T) {
	resolver := newTestResolver(t)
	auth := &RequestAuth{UseConfigToken: true, AccountID: ""}
	resolver.Release(auth)
}
// ─── JWT edge cases ──────────────────────────────────────────────────
// TestVerifyJWTInvalidFormat checks that a string with no JWT structure is
// rejected.
func TestVerifyJWTInvalidFormat(t *testing.T) {
	if _, err := VerifyJWT("not-a-jwt"); err == nil {
		t.Fatal("expected error for invalid JWT format")
	}
}
// TestVerifyJWTInvalidSignature checks that a token whose signature segment
// has been replaced is rejected by VerifyJWT.
func TestVerifyJWTInvalidSignature(t *testing.T) {
	token, err := CreateJWT(1)
	if err != nil {
		t.Fatalf("create jwt failed: %v", err)
	}

	// Fail loudly on an unexpected token shape; silently skipping the
	// assertion would let the test pass without testing anything.
	parts := splitJWT(token)
	if len(parts) != 3 {
		t.Fatalf("expected 3 JWT segments, got %d", len(parts))
	}

	// Keep header and payload, swap in a bogus signature.
	tampered := parts[0] + "." + parts[1] + ".invalid_signature"
	if _, err := VerifyJWT(tampered); err == nil {
		t.Fatal("expected error for tampered signature")
	}
}
// TestVerifyJWTExpired feeds VerifyJWT a hand-built token and expects an error.
// The segments decode to header {"alg":"HS256"} and payload {"exp":1}, and the
// signature segment is the literal text "invalid".
// NOTE(review): despite the name, this cannot tell an expiry rejection apart
// from a signature rejection; it only asserts that the token is refused.
func TestVerifyJWTExpired(t *testing.T) {
// CreateJWT(0) falls back to a default expiry (see TestCreateJWTDefaultExpiry),
// so an already-expired token cannot be minted through the public API here.
// A hand-crafted token with a bad signature is used instead.
_, err := VerifyJWT("eyJhbGciOiJIUzI1NiJ9.eyJleHAiOjF9.invalid")
if err == nil {
t.Fatal("expected error for expired/invalid JWT")
}
}
// TestCreateJWTDefaultExpiry checks that CreateJWT(0) still produces a token
// that VerifyJWT accepts.
func TestCreateJWTDefaultExpiry(t *testing.T) {
	token, createErr := CreateJWT(0) // zero is expected to select the default expiry
	if createErr != nil {
		t.Fatalf("create jwt failed: %v", createErr)
	}
	if _, verifyErr := VerifyJWT(token); verifyErr != nil {
		t.Fatalf("verify jwt failed: %v", verifyErr)
	}
}
// ─── VerifyAdminRequest edge cases ───────────────────────────────────
// TestVerifyAdminRequestNoHeader checks that an admin request without any
// Authorization header is rejected.
func TestVerifyAdminRequestNoHeader(t *testing.T) {
	request, _ := http.NewRequest("GET", "/admin/config", nil)

	err := VerifyAdminRequest(request)
	if err == nil {
		t.Fatal("expected error for missing auth")
	}
}
// TestVerifyAdminRequestEmptyBearer checks that a Bearer header with no token
// is rejected.
func TestVerifyAdminRequestEmptyBearer(t *testing.T) {
	request, _ := http.NewRequest("GET", "/admin/config", nil)
	request.Header.Set("Authorization", "Bearer ")

	err := VerifyAdminRequest(request)
	if err == nil {
		t.Fatal("expected error for empty bearer")
	}
}
// TestVerifyAdminRequestWithAdminKey checks that a Bearer token equal to
// DS2API_ADMIN_KEY is accepted.
func TestVerifyAdminRequestWithAdminKey(t *testing.T) {
	t.Setenv("DS2API_ADMIN_KEY", "test-admin-key")

	request, _ := http.NewRequest("GET", "/admin/config", nil)
	request.Header.Set("Authorization", "Bearer test-admin-key")

	err := VerifyAdminRequest(request)
	if err != nil {
		t.Fatalf("expected admin key accepted: %v", err)
	}
}
// TestVerifyAdminRequestInvalidCredentials checks that a Bearer token that
// differs from DS2API_ADMIN_KEY is rejected.
func TestVerifyAdminRequestInvalidCredentials(t *testing.T) {
	t.Setenv("DS2API_ADMIN_KEY", "correct-key")

	request, _ := http.NewRequest("GET", "/admin/config", nil)
	request.Header.Set("Authorization", "Bearer wrong-key")

	err := VerifyAdminRequest(request)
	if err == nil {
		t.Fatal("expected error for wrong key")
	}
}
// TestVerifyAdminRequestBasicAuth checks that a Basic Authorization header is
// rejected for admin requests.
func TestVerifyAdminRequestBasicAuth(t *testing.T) {
	request, _ := http.NewRequest("GET", "/admin/config", nil)
	request.Header.Set("Authorization", "Basic abc123")

	err := VerifyAdminRequest(request)
	if err == nil {
		t.Fatal("expected error for Basic auth")
	}
}
// ─── Determine with login failure ────────────────────────────────────
// TestDetermineWithLoginFailure configures one password-only account and a
// login callback that always fails, then checks that Determine reports an
// error.
func TestDetermineWithLoginFailure(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[{"email":"acc@test.com","password":"pwd"}]
}`)
	store := config.LoadStore()
	pool := account.NewPool(store)
	failingLogin := func(_ context.Context, _ config.Account) (string, error) {
		return "", errors.New("login failed")
	}
	resolver := NewResolver(store, pool, failingLogin)

	request, _ := http.NewRequest("POST", "/", nil)
	request.Header.Set("Authorization", "Bearer managed-key")

	if _, err := resolver.Determine(request); err == nil {
		t.Fatal("expected error when login fails")
	}
}
// ─── Determine with target account ───────────────────────────────────
// TestDetermineWithTargetAccount checks that the X-Ds2-Target-Account header
// pins Determine to the named account.
func TestDetermineWithTargetAccount(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[
{"email":"acc1@test.com","token":"t1"},
{"email":"acc2@test.com","token":"t2"}
]
}`)
	store := config.LoadStore()
	pool := account.NewPool(store)
	login := func(_ context.Context, _ config.Account) (string, error) {
		return "fresh-token", nil
	}
	resolver := NewResolver(store, pool, login)

	request, _ := http.NewRequest("POST", "/", nil)
	request.Header.Set("Authorization", "Bearer managed-key")
	request.Header.Set("X-Ds2-Target-Account", "acc2@test.com")

	auth, err := resolver.Determine(request)
	if err != nil {
		t.Fatalf("determine failed: %v", err)
	}
	defer resolver.Release(auth)

	if id := auth.AccountID; id != "acc2@test.com" {
		t.Fatalf("expected target account acc2, got %q", id)
	}
}
// helper
// splitJWT splits token on every '.' and returns the segments in order.
// A token without dots (including the empty string) yields a single segment.
func splitJWT(token string) []string {
	parts := make([]string, 0, 3) // a well-formed JWT has exactly three segments
	start := 0
	for i := 0; i < len(token); i++ {
		if token[i] == '.' {
			parts = append(parts, token[start:i])
			start = i + 1
		}
	}
	// Whatever follows the last dot (or the whole token) is the final segment.
	return append(parts, token[start:])
}

View File

@@ -2,6 +2,8 @@ package auth
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"net/http"
"strings"
@@ -22,6 +24,7 @@ var (
type RequestAuth struct {
UseConfigToken bool
DeepSeekToken string
CallerID string
AccountID string
Account config.Account
TriedAccounts map[string]bool
@@ -45,9 +48,16 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
if callerKey == "" {
return nil, ErrUnauthorized
}
callerID := callerTokenID(callerKey)
ctx := req.Context()
if !r.Store.HasAPIKey(callerKey) {
return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil
return &RequestAuth{
UseConfigToken: false,
DeepSeekToken: callerKey,
CallerID: callerID,
resolver: r,
TriedAccounts: map[string]bool{},
}, nil
}
target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account"))
acc, ok := r.Pool.AcquireWait(ctx, target, nil)
@@ -56,6 +66,7 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
}
a := &RequestAuth{
UseConfigToken: true,
CallerID: callerID,
AccountID: acc.Identifier(),
Account: acc,
TriedAccounts: map[string]bool{},
@@ -72,6 +83,26 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
return a, nil
}
// DetermineCaller identifies the caller from the request credentials without
// leasing an account from the pool. It is meant for local-cache lookup routes
// that only need tenant isolation.
//
// It returns ErrUnauthorized when the request carries no caller token.
// Otherwise the result always has UseConfigToken=false and a CallerID derived
// from the token. DeepSeekToken is set to the caller token only when that
// token is not a configured API key (or when no store is available to check).
func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) {
	key := extractCallerToken(req)
	if key == "" {
		return nil, ErrUnauthorized
	}

	auth := &RequestAuth{
		UseConfigToken: false,
		CallerID:       callerTokenID(key),
		resolver:       r,
		TriedAccounts:  map[string]bool{},
	}

	// A managed key is a credential for this service only, so it is not
	// copied into DeepSeekToken.
	managed := r != nil && r.Store != nil && r.Store.HasAPIKey(key)
	if !managed {
		auth.DeepSeekToken = key
	}
	return auth, nil
}
// WithAuth returns a copy of ctx that carries a under the package's
// authCtxKey.
func WithAuth(ctx context.Context, a *RequestAuth) context.Context {
return context.WithValue(ctx, authCtxKey, a)
}
@@ -158,3 +189,12 @@ func extractCallerToken(req *http.Request) string {
}
return strings.TrimSpace(req.Header.Get("x-api-key"))
}
// callerTokenID derives a stable identifier from a caller token:
// "caller:" followed by the hex encoding of the first 8 bytes of the SHA-256
// digest of the whitespace-trimmed token. A blank token yields "".
func callerTokenID(token string) string {
	trimmed := strings.TrimSpace(token)
	if trimmed == "" {
		return ""
	}
	digest := sha256.Sum256([]byte(trimmed))
	prefix := digest[:8]
	return "caller:" + hex.EncodeToString(prefix)
}

View File

@@ -37,6 +37,9 @@ func TestDetermineWithXAPIKeyUsesDirectToken(t *testing.T) {
if auth.DeepSeekToken != "direct-token" {
t.Fatalf("unexpected token: %q", auth.DeepSeekToken)
}
if auth.CallerID == "" {
t.Fatalf("expected caller id to be populated")
}
}
func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
@@ -58,6 +61,44 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
if auth.DeepSeekToken != "account-token" {
t.Fatalf("unexpected account token: %q", auth.DeepSeekToken)
}
if auth.CallerID == "" {
t.Fatalf("expected caller id to be populated")
}
}
// TestDetermineCallerWithManagedKeySkipsAccountAcquire checks that
// DetermineCaller with a managed key fills CallerID but leaves the auth
// without a config-token lease or account identifier.
func TestDetermineCallerWithManagedKeySkipsAccountAcquire(t *testing.T) {
	resolver := newTestResolver(t)

	request, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil)
	request.Header.Set("x-api-key", "managed-key")

	auth, err := resolver.DetermineCaller(request)
	if err != nil {
		t.Fatalf("determine caller failed: %v", err)
	}

	switch {
	case auth.CallerID == "":
		t.Fatalf("expected caller id to be populated")
	case auth.UseConfigToken:
		t.Fatalf("expected no config-token lease for caller-only auth")
	case auth.AccountID != "":
		t.Fatalf("expected empty account id, got %q", auth.AccountID)
	}
}
// TestCallerTokenIDStable checks that callerTokenID is non-empty, repeatable
// for the same token, and different for different tokens.
func TestCallerTokenIDStable(t *testing.T) {
	first := callerTokenID("token-a")
	repeat := callerTokenID("token-a")
	other := callerTokenID("token-b")

	switch {
	case first == "" || repeat == "" || other == "":
		t.Fatalf("expected non-empty caller ids")
	case first != repeat:
		t.Fatalf("expected stable caller id, got %q and %q", first, repeat)
	case first == other:
		t.Fatalf("expected different caller id for different tokens")
	}
}
func TestDetermineMissingToken(t *testing.T) {
@@ -72,3 +113,16 @@ func TestDetermineMissingToken(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
}
// TestDetermineCallerMissingToken checks that DetermineCaller returns
// ErrUnauthorized for a request without credentials.
func TestDetermineCallerMissingToken(t *testing.T) {
	resolver := newTestResolver(t)
	request, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil)

	_, err := resolver.DetermineCaller(request)
	switch {
	case err == nil:
		t.Fatal("expected unauthorized error")
	case err != ErrUnauthorized:
		t.Fatalf("unexpected error: %v", err)
	}
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
@@ -61,11 +62,33 @@ type Config struct {
Accounts []Account `json:"accounts,omitempty"`
ClaudeMapping map[string]string `json:"claude_mapping,omitempty"`
ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"`
ModelAliases map[string]string `json:"model_aliases,omitempty"`
Compat CompatConfig `json:"compat,omitempty"`
Toolcall ToolcallConfig `json:"toolcall,omitempty"`
Responses ResponsesConfig `json:"responses,omitempty"`
Embeddings EmbeddingsConfig `json:"embeddings,omitempty"`
VercelSyncHash string `json:"_vercel_sync_hash,omitempty"`
VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"`
AdditionalFields map[string]any `json:"-"`
}
// CompatConfig is the "compat" section of the configuration file.
type CompatConfig struct {
// WideInputStrictOutput is a pointer so that "unset" (nil) can be told apart
// from an explicit false; Store.CompatWideInputStrictOutput treats nil as true.
WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"`
}
// ToolcallConfig is the "toolcall" section of the configuration file.
type ToolcallConfig struct {
// Mode is read through Store.ToolcallMode, which falls back to
// "feature_match" when this is blank.
Mode string `json:"mode,omitempty"`
// EarlyEmitConfidence is read through Store.ToolcallEarlyEmitConfidence,
// which falls back to "high" when this is blank.
EarlyEmitConfidence string `json:"early_emit_confidence,omitempty"`
}
// ResponsesConfig is the "responses" section of the configuration file.
type ResponsesConfig struct {
// StoreTTLSeconds is read through Store.ResponsesStoreTTLSeconds, which
// substitutes 900 when this is zero or negative.
StoreTTLSeconds int `json:"store_ttl_seconds,omitempty"`
}
// EmbeddingsConfig is the "embeddings" section of the configuration file.
type EmbeddingsConfig struct {
// Provider is returned whitespace-trimmed by Store.EmbeddingsProvider;
// no default is applied there.
Provider string `json:"provider,omitempty"`
}
func (c Config) MarshalJSON() ([]byte, error) {
m := map[string]any{}
for k, v := range c.AdditionalFields {
@@ -83,6 +106,21 @@ func (c Config) MarshalJSON() ([]byte, error) {
if len(c.ClaudeModelMap) > 0 {
m["claude_model_mapping"] = c.ClaudeModelMap
}
if len(c.ModelAliases) > 0 {
m["model_aliases"] = c.ModelAliases
}
if c.Compat.WideInputStrictOutput != nil {
m["compat"] = c.Compat
}
if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" {
m["toolcall"] = c.Toolcall
}
if c.Responses.StoreTTLSeconds > 0 {
m["responses"] = c.Responses
}
if strings.TrimSpace(c.Embeddings.Provider) != "" {
m["embeddings"] = c.Embeddings
}
if c.VercelSyncHash != "" {
m["_vercel_sync_hash"] = c.VercelSyncHash
}
@@ -101,17 +139,49 @@ func (c *Config) UnmarshalJSON(b []byte) error {
for k, v := range raw {
switch k {
case "keys":
_ = json.Unmarshal(v, &c.Keys)
if err := json.Unmarshal(v, &c.Keys); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "accounts":
_ = json.Unmarshal(v, &c.Accounts)
if err := json.Unmarshal(v, &c.Accounts); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "claude_mapping":
_ = json.Unmarshal(v, &c.ClaudeMapping)
if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "claude_model_mapping":
_ = json.Unmarshal(v, &c.ClaudeModelMap)
if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "model_aliases":
if err := json.Unmarshal(v, &c.ModelAliases); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "compat":
if err := json.Unmarshal(v, &c.Compat); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "toolcall":
if err := json.Unmarshal(v, &c.Toolcall); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "responses":
if err := json.Unmarshal(v, &c.Responses); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "embeddings":
if err := json.Unmarshal(v, &c.Embeddings); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "_vercel_sync_hash":
_ = json.Unmarshal(v, &c.VercelSyncHash)
if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "_vercel_sync_time":
_ = json.Unmarshal(v, &c.VercelSyncTime)
if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
default:
var anyVal any
if err := json.Unmarshal(v, &anyVal); err == nil {
@@ -124,10 +194,17 @@ func (c *Config) UnmarshalJSON(b []byte) error {
func (c Config) Clone() Config {
clone := Config{
Keys: slices.Clone(c.Keys),
Accounts: slices.Clone(c.Accounts),
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
Keys: slices.Clone(c.Keys),
Accounts: slices.Clone(c.Accounts),
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
ModelAliases: cloneStringMap(c.ModelAliases),
Compat: CompatConfig{
WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
},
Toolcall: c.Toolcall,
Responses: c.Responses,
Embeddings: c.Embeddings,
VercelSyncHash: c.VercelSyncHash,
VercelSyncTime: c.VercelSyncTime,
AdditionalFields: map[string]any{},
@@ -149,6 +226,14 @@ func cloneStringMap(in map[string]string) map[string]string {
return out
}
// cloneBoolPtr returns a pointer to a fresh copy of *in, or nil when in is
// nil, so the clone never aliases the original value.
func cloneBoolPtr(in *bool) *bool {
	if in != nil {
		copied := *in
		return &copied
	}
	return nil
}
type Store struct {
mu sync.RWMutex
cfg Config
@@ -233,30 +318,94 @@ func loadConfig() (Config, bool, error) {
content, err := os.ReadFile(ConfigPath())
if err != nil {
if IsVercel() {
// Vercel one-click deploy may start without a writable/present config file.
// Keep an in-memory config so users can bootstrap via WebUI then sync env.
return Config{}, true, nil
}
return Config{}, false, err
}
var cfg Config
if err := json.Unmarshal(content, &cfg); err != nil {
return Config{}, false, err
}
if IsVercel() {
// Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors.
return cfg, true, nil
}
return cfg, false, nil
}
func parseConfigString(raw string) (Config, error) {
var cfg Config
if err := json.Unmarshal([]byte(raw), &cfg); err == nil {
return cfg, nil
candidates := []string{raw}
if normalized := normalizeConfigInput(raw); normalized != raw {
candidates = append(candidates, normalized)
}
decoded, err := base64.StdEncoding.DecodeString(raw)
for _, candidate := range candidates {
if err := json.Unmarshal([]byte(candidate), &cfg); err == nil {
return cfg, nil
}
}
base64Input := candidates[len(candidates)-1]
decoded, err := decodeConfigBase64(base64Input)
if err != nil {
return Config{}, err
return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON: %w", err)
}
if err := json.Unmarshal(decoded, &cfg); err != nil {
return Config{}, err
return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON decoded JSON: %w", err)
}
return cfg, nil
}
// normalizeConfigInput strips decoration that often surrounds a config value
// pasted into an environment variable: surrounding whitespace, one or more
// layers of matching single or double quotes, and a case-insensitive
// "base64:" prefix. Layers are peeled repeatedly until nothing changes.
func normalizeConfigInput(raw string) string {
	const prefix = "base64:"

	out := strings.TrimSpace(raw)
	if out == "" {
		return out
	}
	for again := true; again; {
		again = false
		// Peel one layer of matching quotes.
		if n := len(out); n >= 2 {
			head, tail := out[0], out[n-1]
			if head == tail && (head == '"' || head == '\'') {
				out = strings.TrimSpace(out[1 : n-1])
				again = true
			}
		}
		// Drop the prefix; it may itself be followed by another quoted layer.
		if strings.HasPrefix(strings.ToLower(out), prefix) {
			out = strings.TrimSpace(out[len(prefix):])
			again = true
		}
	}
	return strings.TrimSpace(out)
}
// decodeConfigBase64 decodes raw by trying, in order, the standard, unpadded
// standard, URL-safe and unpadded URL-safe base64 alphabets. The first
// successful decode wins; if every variant fails, the error from the last
// attempt is returned.
func decodeConfigBase64(raw string) ([]byte, error) {
	var lastErr error
	for _, enc := range [...]*base64.Encoding{
		base64.StdEncoding,
		base64.RawStdEncoding,
		base64.URLEncoding,
		base64.RawURLEncoding,
	} {
		decoded, err := enc.DecodeString(raw)
		if err != nil {
			lastErr = err
			continue
		}
		return decoded, nil
	}
	if lastErr == nil {
		return nil, errors.New("base64 decode failed")
	}
	return nil, lastErr
}
func (s *Store) Snapshot() Config {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -413,3 +562,62 @@ func (s *Store) ClaudeMapping() map[string]string {
}
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
}
// ModelAliases returns the alias table: the result of DefaultModelAliases()
// overlaid with the configured model_aliases. Configured keys and values are
// passed through lower() and trimmed; pairs whose key or value ends up empty
// are skipped.
// NOTE(review): this writes into the map returned by DefaultModelAliases();
// confirm that function hands out a fresh map on every call.
func (s *Store) ModelAliases() map[string]string {
	s.mu.RLock()
	defer s.mu.RUnlock()

	aliases := DefaultModelAliases()
	for rawAlias, rawTarget := range s.cfg.ModelAliases {
		alias := strings.TrimSpace(lower(rawAlias))
		target := strings.TrimSpace(lower(rawTarget))
		if alias != "" && target != "" {
			aliases[alias] = target
		}
	}
	return aliases
}
// CompatWideInputStrictOutput reports the compat.wide_input_strict_output
// setting, defaulting to true when it is not configured.
func (s *Store) CompatWideInputStrictOutput() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if flag := s.cfg.Compat.WideInputStrictOutput; flag != nil {
		return *flag
	}
	return true
}
// ToolcallMode returns the configured toolcall.mode, lower-cased and trimmed,
// or "feature_match" when it is blank.
func (s *Store) ToolcallMode() string {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if configured := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode)); configured != "" {
		return configured
	}
	return "feature_match"
}
// ToolcallEarlyEmitConfidence returns the configured
// toolcall.early_emit_confidence, lower-cased and trimmed, or "high" when it
// is blank.
func (s *Store) ToolcallEarlyEmitConfidence() string {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if configured := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence)); configured != "" {
		return configured
	}
	return "high"
}
// ResponsesStoreTTLSeconds returns the configured responses.store_ttl_seconds,
// or 900 when the configured value is zero or negative.
func (s *Store) ResponsesStoreTTLSeconds() int {
	s.mu.RLock()
	defer s.mu.RUnlock()

	ttl := s.cfg.Responses.StoreTTLSeconds
	if ttl <= 0 {
		return 900
	}
	return ttl
}
// EmbeddingsProvider returns the configured embeddings.provider with
// surrounding whitespace removed; it is empty when nothing is configured.
func (s *Store) EmbeddingsProvider() string {
	s.mu.RLock()
	defer s.mu.RUnlock()

	provider := s.cfg.Embeddings.Provider
	return strings.TrimSpace(provider)
}

View File

@@ -0,0 +1,478 @@
package config
import (
"encoding/base64"
"encoding/json"
"strings"
"testing"
)
// ─── GetModelConfig edge cases ───────────────────────────────────────
// TestGetModelConfigDeepSeekChat checks that deepseek-chat is known and has
// neither thinking nor search enabled.
func TestGetModelConfigDeepSeekChat(t *testing.T) {
	thinking, search, known := GetModelConfig("deepseek-chat")
	switch {
	case !known:
		t.Fatal("expected ok for deepseek-chat")
	case thinking || search:
		t.Fatalf("expected no thinking/search for deepseek-chat, got thinking=%v search=%v", thinking, search)
	}
}
// TestGetModelConfigDeepSeekReasoner checks that deepseek-reasoner is known
// and enables thinking but not search.
func TestGetModelConfigDeepSeekReasoner(t *testing.T) {
	thinking, search, known := GetModelConfig("deepseek-reasoner")
	switch {
	case !known:
		t.Fatal("expected ok for deepseek-reasoner")
	case !thinking || search:
		t.Fatalf("expected thinking=true search=false, got thinking=%v search=%v", thinking, search)
	}
}
// TestGetModelConfigDeepSeekChatSearch checks that deepseek-chat-search is
// known and enables search but not thinking.
func TestGetModelConfigDeepSeekChatSearch(t *testing.T) {
	thinking, search, known := GetModelConfig("deepseek-chat-search")
	switch {
	case !known:
		t.Fatal("expected ok for deepseek-chat-search")
	case thinking || !search:
		t.Fatalf("expected thinking=false search=true, got thinking=%v search=%v", thinking, search)
	}
}
// TestGetModelConfigDeepSeekReasonerSearch checks that
// deepseek-reasoner-search is known and enables both thinking and search.
func TestGetModelConfigDeepSeekReasonerSearch(t *testing.T) {
	thinking, search, known := GetModelConfig("deepseek-reasoner-search")
	switch {
	case !known:
		t.Fatal("expected ok for deepseek-reasoner-search")
	case !thinking || !search:
		t.Fatalf("expected both true, got thinking=%v search=%v", thinking, search)
	}
}
// TestGetModelConfigCaseInsensitive checks that a mixed-case model name
// resolves the same way as deepseek-chat.
func TestGetModelConfigCaseInsensitive(t *testing.T) {
	thinking, search, known := GetModelConfig("DeepSeek-Chat")
	switch {
	case !known:
		t.Fatal("expected ok for case-insensitive deepseek-chat")
	case thinking || search:
		t.Fatalf("expected no thinking/search for case-insensitive deepseek-chat")
	}
}
// TestGetModelConfigUnknownModel checks that an unrecognized model name is
// reported as not ok.
func TestGetModelConfigUnknownModel(t *testing.T) {
	if _, _, known := GetModelConfig("gpt-4"); known {
		t.Fatal("expected not ok for unknown model")
	}
}
// TestGetModelConfigEmpty checks that an empty model name is reported as not
// ok.
func TestGetModelConfigEmpty(t *testing.T) {
	if _, _, known := GetModelConfig(""); known {
		t.Fatal("expected not ok for empty model")
	}
}
// ─── lower function ──────────────────────────────────────────────────
// TestLowerFunction runs lower() over a table of inputs and reports every
// mismatch rather than stopping at the first.
func TestLowerFunction(t *testing.T) {
	type testCase struct {
		in   string
		want string
	}
	table := []testCase{
		{in: "Hello", want: "hello"},
		{in: "ALLCAPS", want: "allcaps"},
		{in: "already-lower", want: "already-lower"},
		{in: "Mixed-CASE-123", want: "mixed-case-123"},
		{in: "", want: ""},
	}
	for i := range table {
		entry := table[i]
		if actual := lower(entry.in); actual != entry.want {
			t.Errorf("lower(%q) = %q, want %q", entry.in, actual, entry.want)
		}
	}
}
// ─── Config.MarshalJSON / UnmarshalJSON roundtrip ────────────────────
// TestConfigJSONRoundtrip marshals a populated Config, unmarshals the bytes
// into a fresh Config, and checks that keys, accounts, the Claude mapping,
// the Vercel sync hash and an unknown custom field are all still there.
func TestConfigJSONRoundtrip(t *testing.T) {
	original := Config{
		Keys:     []string{"key1", "key2"},
		Accounts: []Account{{Email: "user@example.com", Password: "pass", Token: "tok"}},
		ClaudeMapping: map[string]string{
			"fast": "deepseek-chat",
			"slow": "deepseek-reasoner",
		},
		VercelSyncHash: "hash123",
		VercelSyncTime: 1234567890,
		AdditionalFields: map[string]any{
			"custom_field": "custom_value",
		},
	}

	encoded, err := original.MarshalJSON()
	if err != nil {
		t.Fatalf("marshal error: %v", err)
	}

	var roundtripped Config
	if err := json.Unmarshal(encoded, &roundtripped); err != nil {
		t.Fatalf("unmarshal error: %v", err)
	}

	if keys := roundtripped.Keys; len(keys) != 2 || keys[0] != "key1" {
		t.Fatalf("unexpected keys: %#v", keys)
	}
	if accounts := roundtripped.Accounts; len(accounts) != 1 || accounts[0].Email != "user@example.com" {
		t.Fatalf("unexpected accounts: %#v", accounts)
	}
	if roundtripped.ClaudeMapping["fast"] != "deepseek-chat" {
		t.Fatalf("unexpected claude mapping: %#v", roundtripped.ClaudeMapping)
	}
	if hash := roundtripped.VercelSyncHash; hash != "hash123" {
		t.Fatalf("unexpected vercel sync hash: %q", hash)
	}
	if roundtripped.AdditionalFields["custom_field"] != "custom_value" {
		t.Fatalf("unexpected additional fields: %#v", roundtripped.AdditionalFields)
	}
}
// TestConfigUnmarshalJSONPreservesUnknownFields checks that keys the Config
// struct does not model end up in AdditionalFields with their decoded values.
func TestConfigUnmarshalJSONPreservesUnknownFields(t *testing.T) {
	const input = `{"keys":["k1"],"accounts":[],"my_custom_field":"hello","number_field":42}`

	var parsed Config
	if err := json.Unmarshal([]byte(input), &parsed); err != nil {
		t.Fatalf("unmarshal error: %v", err)
	}

	extras := parsed.AdditionalFields
	if extras["my_custom_field"] != "hello" {
		t.Fatalf("expected custom field preserved, got %#v", parsed.AdditionalFields)
	}
	// JSON numbers decode into `any` as float64.
	if extras["number_field"] != float64(42) {
		t.Fatalf("expected number field preserved, got %#v", parsed.AdditionalFields["number_field"])
	}
}
// ─── Config.Clone ────────────────────────────────────────────────────
// TestConfigCloneIsDeepCopy mutates the original Config after cloning and
// checks that the clone's keys, accounts and Claude mapping are unaffected.
func TestConfigCloneIsDeepCopy(t *testing.T) {
	original := Config{
		Keys:     []string{"key1"},
		Accounts: []Account{{Email: "user@test.com", Token: "token"}},
		ClaudeMapping: map[string]string{
			"fast": "deepseek-chat",
		},
		AdditionalFields: map[string]any{"custom": "value"},
	}
	snapshot := original.Clone()

	// Mutate every container the clone could be sharing.
	original.Keys[0] = "modified"
	original.Accounts[0].Email = "modified@test.com"
	original.ClaudeMapping["fast"] = "modified-model"

	if snapshot.Keys[0] != "key1" {
		t.Fatalf("clone keys was affected by original change: %#v", snapshot.Keys)
	}
	if snapshot.Accounts[0].Email != "user@test.com" {
		t.Fatalf("clone accounts was affected: %#v", snapshot.Accounts)
	}
	if snapshot.ClaudeMapping["fast"] != "deepseek-chat" {
		t.Fatalf("clone claude mapping was affected: %#v", snapshot.ClaudeMapping)
	}
}
// TestConfigCloneNilMaps checks that cloning a Config with nil Accounts keeps
// the keys and leaves Accounts nil.
func TestConfigCloneNilMaps(t *testing.T) {
	original := Config{Keys: []string{"k"}, Accounts: nil}
	snapshot := original.Clone()

	if n := len(snapshot.Keys); n != 1 {
		t.Fatalf("unexpected keys length: %d", n)
	}
	if snapshot.Accounts != nil {
		t.Fatalf("expected nil accounts in clone, got %#v", snapshot.Accounts)
	}
}
// ─── Account.Identifier edge cases ───────────────────────────────────
// TestAccountIdentifierPreferenceMobileOverToken checks that Identifier
// returns the mobile number when both mobile and token are set.
func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) {
	account := Account{Mobile: "13800138000", Token: "tok"}
	if account.Identifier() == "13800138000" {
		return
	}
	t.Fatalf("expected mobile identifier, got %q", account.Identifier())
}
// TestAccountIdentifierPreferenceEmailOverMobile checks that Identifier
// returns the email when both email and mobile are set.
func TestAccountIdentifierPreferenceEmailOverMobile(t *testing.T) {
	account := Account{Email: "user@test.com", Mobile: "13800138000"}
	if account.Identifier() == "user@test.com" {
		return
	}
	t.Fatalf("expected email identifier, got %q", account.Identifier())
}
// TestAccountIdentifierEmptyAccount checks that a zero-value Account has an
// empty identifier.
func TestAccountIdentifierEmptyAccount(t *testing.T) {
	var account Account
	if account.Identifier() == "" {
		return
	}
	t.Fatalf("expected empty identifier for empty account, got %q", account.Identifier())
}
// ─── normalizeConfigInput ────────────────────────────────────────────
// TestNormalizeConfigInputStripsQuotes checks that a double-quoted,
// base64:-prefixed value is reduced to the bare payload. Asserting the exact
// result also covers prefix removal, which a check for leftover quotes alone
// would miss.
func TestNormalizeConfigInputStripsQuotes(t *testing.T) {
	got := normalizeConfigInput(`"base64:abc"`)
	if got != "abc" {
		t.Fatalf("expected quotes stripped, got %q", got)
	}
}
// TestNormalizeConfigInputStripsSingleQuotes checks that surrounding single
// quotes are removed and the inner value is returned unchanged.
func TestNormalizeConfigInputStripsSingleQuotes(t *testing.T) {
	got := normalizeConfigInput("'some-value'")
	if got != "some-value" {
		t.Fatalf("expected single quotes stripped, got %q", got)
	}
}
// TestNormalizeConfigInputTrimsWhitespace checks that leading and trailing
// spaces are removed.
func TestNormalizeConfigInputTrimsWhitespace(t *testing.T) {
	if trimmed := normalizeConfigInput(" hello "); trimmed != "hello" {
		t.Fatalf("expected trimmed, got %q", trimmed)
	}
}
// ─── parseConfigString edge cases ────────────────────────────────────
// TestParseConfigStringPlainJSON checks that raw JSON is parsed directly.
func TestParseConfigStringPlainJSON(t *testing.T) {
	parsed, err := parseConfigString(`{"keys":["k1"],"accounts":[]}`)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if keys := parsed.Keys; len(keys) != 1 || keys[0] != "k1" {
		t.Fatalf("unexpected keys: %#v", keys)
	}
}
// TestParseConfigStringBase64Prefix checks that a "base64:"-prefixed,
// standard-base64 payload is decoded and parsed.
func TestParseConfigStringBase64Prefix(t *testing.T) {
	const payload = `{"keys":["base64-key"],"accounts":[]}`
	encoded := base64.StdEncoding.EncodeToString([]byte(payload))

	parsed, err := parseConfigString("base64:" + encoded)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if keys := parsed.Keys; len(keys) != 1 || keys[0] != "base64-key" {
		t.Fatalf("unexpected keys: %#v", keys)
	}
}
// TestParseConfigStringInvalidBase64 checks that an undecodable base64
// payload produces an error.
func TestParseConfigStringInvalidBase64(t *testing.T) {
	if _, err := parseConfigString("base64:!!!invalid!!!"); err == nil {
		t.Fatal("expected error for invalid base64")
	}
}
// TestParseConfigStringEmptyString checks that an empty input is rejected.
func TestParseConfigStringEmptyString(t *testing.T) {
	if _, err := parseConfigString(""); err == nil {
		t.Fatal("expected error for empty string")
	}
}
// ─── Store methods ───────────────────────────────────────────────────
// TestStoreSnapshotReturnsClone checks that editing a Snapshot does not leak
// back into the store.
func TestStoreSnapshotReturnsClone(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com","token":"t1"}]}`)
	store := LoadStore()

	snapshot := store.Snapshot()
	snapshot.Keys[0] = "modified"

	if current := store.Keys(); current[0] != "k1" {
		t.Fatal("snapshot modification should not affect store")
	}
}
// TestStoreHasAPIKeyMultipleKeys checks that every configured key is found
// and that an unconfigured one is not.
func TestStoreHasAPIKeyMultipleKeys(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":["key1","key2","key3"],"accounts":[]}`)
	store := LoadStore()

	for _, key := range []string{"key1", "key2", "key3"} {
		if !store.HasAPIKey(key) {
			t.Fatal("expected " + key + " found")
		}
	}
	if store.HasAPIKey("nonexistent") {
		t.Fatal("expected nonexistent key not found")
	}
}
// TestStoreFindAccountNotFound checks that FindAccount reports a miss for an
// identifier that is not configured.
func TestStoreFindAccountNotFound(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com"}]}`)
	store := LoadStore()

	if _, found := store.FindAccount("nonexistent@test.com"); found {
		t.Fatal("expected account not found")
	}
}
// TestStoreCompatWideInputStrictOutputDefaultTrue checks that the compat flag
// defaults to true when the config has no compat section.
func TestStoreCompatWideInputStrictOutputDefaultTrue(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
	store := LoadStore()

	enabled := store.CompatWideInputStrictOutput()
	if !enabled {
		t.Fatal("expected default wide_input_strict_output=true when unset")
	}
}
// TestStoreCompatWideInputStrictOutputCanDisable checks that an explicit
// false is honoured by the accessor and is still present as false after the
// snapshot is marshaled back to JSON.
func TestStoreCompatWideInputStrictOutputCanDisable(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[],"compat":{"wide_input_strict_output":false}}`)
	store := LoadStore()

	if store.CompatWideInputStrictOutput() {
		t.Fatal("expected wide_input_strict_output=false when explicitly configured")
	}

	encoded, err := store.Snapshot().MarshalJSON()
	if err != nil {
		t.Fatalf("marshal failed: %v", err)
	}
	var decoded map[string]any
	if err := json.Unmarshal(encoded, &decoded); err != nil {
		t.Fatalf("decode failed: %v", err)
	}

	compat, isObject := decoded["compat"].(map[string]any)
	if !isObject {
		t.Fatalf("expected compat in marshaled output, got %#v", decoded)
	}
	if compat["wide_input_strict_output"] != false {
		t.Fatalf("expected explicit false in compat, got %#v", compat)
	}
}
// TestStoreIsEnvBacked checks that a store loaded from DS2API_CONFIG_JSON
// reports itself as env-backed.
func TestStoreIsEnvBacked(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
	store := LoadStore()

	envBacked := store.IsEnvBacked()
	if !envBacked {
		t.Fatal("expected env-backed store")
	}
}
// TestStoreReplace checks that Replace swaps the whole configuration: the new
// key is present and the old one is gone.
func TestStoreReplace(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
	store := LoadStore()

	replacement := Config{
		Keys:     []string{"new-key"},
		Accounts: []Account{{Email: "new@test.com"}},
	}
	if err := store.Replace(replacement); err != nil {
		t.Fatalf("replace error: %v", err)
	}

	switch {
	case !store.HasAPIKey("new-key"):
		t.Fatal("expected new key after replace")
	case store.HasAPIKey("k1"):
		t.Fatal("expected old key removed after replace")
	}
}
// TestStoreUpdate checks that a key appended inside the Update callback is
// visible through HasAPIKey afterwards.
func TestStoreUpdate(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
	s := LoadStore()

	appendKey := func(cfg *Config) error {
		cfg.Keys = append(cfg.Keys, "k2")
		return nil
	}
	if err := s.Update(appendKey); err != nil {
		t.Fatalf("update error: %v", err)
	}
	if !s.HasAPIKey("k2") {
		t.Fatal("expected k2 after update")
	}
}
// TestStoreClaudeMapping checks that the "fast" and "slow" entries of a
// configured claude_mapping are returned unchanged by ClaudeMapping.
func TestStoreClaudeMapping(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`)
	s := LoadStore()
	mapping := s.ClaudeMapping()

	// Ordered so that "fast" is still verified before "slow".
	expected := []struct{ alias, target string }{
		{"fast", "deepseek-chat"},
		{"slow", "deepseek-reasoner"},
	}
	for _, want := range expected {
		if got := mapping[want.alias]; got != want.target {
			t.Fatalf("unexpected %s mapping: %q", want.alias, got)
		}
	}
}
// TestStoreClaudeMappingEmpty checks that ClaudeMapping returns a non-nil map
// even when the config does not define claude_mapping.
func TestStoreClaudeMappingEmpty(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`)
	s := LoadStore()
	// The store is expected to fall back to defaults, so nil is a failure.
	if got := s.ClaudeMapping(); got == nil {
		t.Fatal("expected non-nil mapping (may contain defaults)")
	}
}
// TestStoreSetVercelSync checks that the hash and timestamp passed to
// SetVercelSync are readable from a subsequent Snapshot.
func TestStoreSetVercelSync(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`)
	s := LoadStore()
	err := s.SetVercelSync("hash123", 1234567890)
	if err != nil {
		t.Fatalf("setVercelSync error: %v", err)
	}

	snapshot := s.Snapshot()
	recorded := snapshot.VercelSyncHash == "hash123" && snapshot.VercelSyncTime == 1234567890
	if !recorded {
		t.Fatalf("unexpected vercel sync: hash=%q time=%d", snapshot.VercelSyncHash, snapshot.VercelSyncTime)
	}
}
// TestStoreExportJSONAndBase64 checks that both export forms carry the
// configured key: the plain JSON string directly, and the base64 string after
// standard-encoding decode.
func TestStoreExportJSONAndBase64(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":["export-key"],"accounts":[]}`)
	s := LoadStore()

	plain, encoded, err := s.ExportJSONAndBase64()
	if err != nil {
		t.Fatalf("export error: %v", err)
	}
	if !strings.Contains(plain, "export-key") {
		t.Fatalf("expected JSON to contain key: %q", plain)
	}

	raw, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		t.Fatalf("base64 decode error: %v", err)
	}
	if text := string(raw); !strings.Contains(text, "export-key") {
		t.Fatalf("expected base64-decoded to contain key: %q", text)
	}
}
// ─── OpenAIModelsResponse / ClaudeModelsResponse ─────────────────────
// TestOpenAIModelsResponse checks the envelope of the OpenAI model listing:
// object is "list" and data is a non-empty []ModelInfo.
func TestOpenAIModelsResponse(t *testing.T) {
	resp := OpenAIModelsResponse()
	if kind := resp["object"]; kind != "list" {
		t.Fatalf("unexpected object: %v", kind)
	}

	models, isList := resp["data"].([]ModelInfo)
	switch {
	case !isList:
		t.Fatalf("unexpected data type: %T", resp["data"])
	case len(models) == 0:
		t.Fatal("expected non-empty models list")
	}
}
// TestClaudeModelsResponse checks the envelope of the Claude model listing:
// object is "list" and data is a non-empty []ModelInfo.
func TestClaudeModelsResponse(t *testing.T) {
	resp := ClaudeModelsResponse()
	if kind := resp["object"]; kind != "list" {
		t.Fatalf("unexpected object: %v", kind)
	}

	models, isList := resp["data"].([]ModelInfo)
	switch {
	case !isList:
		t.Fatalf("unexpected data type: %T", resp["data"])
	case len(models) == 0:
		t.Fatal("expected non-empty models list")
	}
}

View File

@@ -1,6 +1,7 @@
package config
import (
"encoding/base64"
"strings"
"testing"
)
@@ -70,3 +71,53 @@ func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T)
t.Fatalf("expected find by old identifier alias")
}
}
// TestLoadStoreRejectsInvalidFieldType checks that a config whose "keys" field
// has the wrong JSON type yields a store with no keys and no accounts.
func TestLoadStoreRejectsInvalidFieldType(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":"not-array","accounts":[]}`)
	s := LoadStore()
	empty := len(s.Keys()) == 0 && len(s.Accounts()) == 0
	if !empty {
		t.Fatalf("expected empty store when config type is invalid")
	}
}
// TestParseConfigStringSupportsQuotedBase64Prefix checks that parseConfigString
// accepts a value of the form "base64:<payload>" wrapped in double quotes.
func TestParseConfigStringSupportsQuotedBase64Prefix(t *testing.T) {
	const rawJSON = `{"keys":["k1"],"accounts":[{"email":"u@example.com","password":"p"}]}`
	payload := base64.StdEncoding.EncodeToString([]byte(rawJSON))
	input := `"base64:` + payload + `"`

	cfg, err := parseConfigString(input)
	if err != nil {
		t.Fatalf("unexpected parse error: %v", err)
	}
	if keys := cfg.Keys; len(keys) != 1 || keys[0] != "k1" {
		t.Fatalf("unexpected keys: %#v", cfg.Keys)
	}
}
// TestParseConfigStringSupportsRawURLBase64 checks that parseConfigString
// accepts a payload encoded with the unpadded URL-safe base64 alphabet.
func TestParseConfigStringSupportsRawURLBase64(t *testing.T) {
	const rawJSON = `{"keys":["k-url"],"accounts":[]}`
	payload := base64.RawURLEncoding.EncodeToString([]byte(rawJSON))

	cfg, err := parseConfigString(payload)
	if err != nil {
		t.Fatalf("unexpected parse error: %v", err)
	}
	if keys := cfg.Keys; len(keys) != 1 || keys[0] != "k-url" {
		t.Fatalf("unexpected keys: %#v", cfg.Keys)
	}
}
// TestLoadConfigOnVercelWithoutConfigFileFallsBackToMemory checks that, with
// VERCEL set, no inline config and a config path that does not exist,
// loadConfig succeeds with fromEnv=true and an empty config.
func TestLoadConfigOnVercelWithoutConfigFileFallsBackToMemory(t *testing.T) {
	environment := [][2]string{
		{"VERCEL", "1"},
		{"DS2API_CONFIG_JSON", ""},
		{"CONFIG_JSON", ""},
		{"DS2API_CONFIG_PATH", "testdata/does-not-exist.json"},
	}
	for _, kv := range environment {
		t.Setenv(kv[0], kv[1])
	}

	cfg, fromEnv, err := loadConfig()
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case !fromEnv:
		t.Fatalf("expected fromEnv=true for vercel fallback")
	case len(cfg.Keys) != 0 || len(cfg.Accounts) != 0:
		t.Fatalf("expected empty bootstrap config, got keys=%d accounts=%d", len(cfg.Keys), len(cfg.Accounts))
	}
}

View File

@@ -0,0 +1,44 @@
package config
import "testing"
// TestResolveModelDirectDeepSeek checks that a native DeepSeek model name
// resolves to itself without a store.
func TestResolveModelDirectDeepSeek(t *testing.T) {
	resolved, ok := ResolveModel(nil, "deepseek-chat")
	if matched := ok && resolved == "deepseek-chat"; !matched {
		t.Fatalf("expected deepseek-chat, got ok=%v model=%q", ok, resolved)
	}
}
// TestResolveModelAlias checks that the built-in alias gpt-4.1 resolves to
// deepseek-chat without a store.
func TestResolveModelAlias(t *testing.T) {
	resolved, ok := ResolveModel(nil, "gpt-4.1")
	if matched := ok && resolved == "deepseek-chat"; !matched {
		t.Fatalf("expected alias gpt-4.1 -> deepseek-chat, got ok=%v model=%q", ok, resolved)
	}
}
// TestResolveModelHeuristicReasoner checks that an unlisted name starting with
// "o3" is resolved to deepseek-reasoner.
func TestResolveModelHeuristicReasoner(t *testing.T) {
	resolved, ok := ResolveModel(nil, "o3-super")
	if matched := ok && resolved == "deepseek-reasoner"; !matched {
		t.Fatalf("expected heuristic reasoner, got ok=%v model=%q", ok, resolved)
	}
}
// TestResolveModelUnknown checks that a name outside every known model family
// is not resolved.
func TestResolveModelUnknown(t *testing.T) {
	if _, resolved := ResolveModel(nil, "totally-custom-model"); resolved {
		t.Fatal("expected unknown model to fail resolve")
	}
}
// TestClaudeModelsResponsePaginationFields checks that the Claude model listing
// always carries the first_id, last_id and has_more keys.
func TestClaudeModelsResponsePaginationFields(t *testing.T) {
	resp := ClaudeModelsResponse()
	for _, field := range []string{"first_id", "last_id", "has_more"} {
		if _, present := resp[field]; !present {
			t.Fatalf("expected %s in response: %#v", field, resp)
		}
	}
}

View File

@@ -1,5 +1,7 @@
package config
import "strings"
type ModelInfo struct {
ID string `json:"id"`
Object string `json:"object"`
@@ -71,6 +73,91 @@ func GetModelConfig(model string) (thinking bool, search bool, ok bool) {
}
}
// IsSupportedDeepSeekModel reports whether GetModelConfig recognizes model.
func IsSupportedDeepSeekModel(model string) bool {
	_, _, supported := GetModelConfig(model)
	return supported
}
// DefaultModelAliases returns the built-in mapping from third-party model names
// to DeepSeek model IDs. A new map is built on every call, so callers may
// modify the result freely.
func DefaultModelAliases() map[string]string {
	const (
		chat     = "deepseek-chat"
		reasoner = "deepseek-reasoner"
	)
	chatNames := []string{
		"gpt-4o",
		"gpt-4.1",
		"gpt-4.1-mini",
		"gpt-4.1-nano",
		"gpt-5",
		"gpt-5-mini",
		"claude-sonnet-4-5",
		"claude-haiku-4-5",
		"claude-3-5-sonnet",
		"claude-3-5-haiku",
		"gemini-2.5-pro",
		"gemini-2.5-flash",
		"llama-3.1-70b-instruct",
		"qwen-max",
	}
	reasonerNames := []string{
		"gpt-5-codex",
		"o1",
		"o1-mini",
		"o3",
		"o3-mini",
		"claude-opus-4-6",
		"claude-3-opus",
	}

	aliases := make(map[string]string, len(chatNames)+len(reasonerNames))
	for _, name := range chatNames {
		aliases[name] = chat
	}
	for _, name := range reasonerNames {
		aliases[name] = reasoner
	}
	return aliases
}
// ResolveModel maps a client-requested model name onto a supported DeepSeek
// model ID. The second result is false when no mapping applies.
//
// The requested name is trimmed and lower-cased, then resolved in this order:
//  1. a name that is already a supported DeepSeek model is returned as-is;
//  2. the default aliases, overlaid with store.ModelAliases() when store is
//     non-nil, are consulted (an alias is used only if its target is supported);
//  3. any other "deepseek-" name is rejected;
//  4. names from a known model family are mapped heuristically to a
//     chat/reasoner model, with a "-search" variant when the name contains
//     "search".
func ResolveModel(store *Store, requested string) (string, bool) {
	name := lower(strings.TrimSpace(requested))
	switch {
	case name == "":
		return "", false
	case IsSupportedDeepSeekModel(name):
		return name, true
	}

	// Store-provided aliases override the defaults; both sides are normalized.
	table := DefaultModelAliases()
	if store != nil {
		for from, to := range store.ModelAliases() {
			table[lower(strings.TrimSpace(from))] = lower(strings.TrimSpace(to))
		}
	}
	if target, found := table[name]; found && IsSupportedDeepSeekModel(target) {
		return target, true
	}

	// An unsupported deepseek-* name is never guessed at.
	if strings.HasPrefix(name, "deepseek-") {
		return "", false
	}

	inKnownFamily := false
	for _, family := range []string{
		"gpt-", "o1", "o3", "claude-", "gemini-", "llama-", "qwen-", "mistral-", "command-",
	} {
		if strings.HasPrefix(name, family) {
			inKnownFamily = true
			break
		}
	}
	if !inKnownFamily {
		return "", false
	}

	// Substring "reason" also covers "reasoner", so one check is sufficient.
	reasoning := strings.HasPrefix(name, "o1") || strings.HasPrefix(name, "o3")
	for _, hint := range []string{"reason", "opus", "r1"} {
		reasoning = reasoning || strings.Contains(name, hint)
	}

	target := "deepseek-chat"
	if reasoning {
		target = "deepseek-reasoner"
	}
	if strings.Contains(name, "search") {
		target += "-search"
	}
	return target, true
}
func lower(s string) string {
b := []byte(s)
for i, c := range b {
@@ -85,6 +172,28 @@ func OpenAIModelsResponse() map[string]any {
return map[string]any{"object": "list", "data": DeepSeekModels}
}
func ClaudeModelsResponse() map[string]any {
return map[string]any{"object": "list", "data": ClaudeModels}
// OpenAIModelByID resolves id through ResolveModel and returns the matching
// entry of DeepSeekModels. The second result is false when the id cannot be
// resolved or the resolved ID is not listed in DeepSeekModels.
func OpenAIModelByID(store *Store, id string) (ModelInfo, bool) {
	if canonical, resolved := ResolveModel(store, id); resolved {
		for i := range DeepSeekModels {
			if DeepSeekModels[i].ID == canonical {
				return DeepSeekModels[i], true
			}
		}
	}
	return ModelInfo{}, false
}
// ClaudeModelsResponse builds the model-list payload for ClaudeModels. Besides
// "object" and "data" it always includes "first_id", "last_id" (both nil when
// ClaudeModels is empty) and "has_more", which is always false.
func ClaudeModelsResponse() map[string]any {
	var firstID, lastID any
	if count := len(ClaudeModels); count > 0 {
		firstID = ClaudeModels[0].ID
		lastID = ClaudeModels[count-1].ID
	}
	return map[string]any{
		"object":   "list",
		"data":     ClaudeModels,
		"first_id": firstID,
		"last_id":  lastID,
		"has_more": false,
	}
}

View File

@@ -0,0 +1,165 @@
package deepseek
import (
"context"
"testing"
)
// ─── toFloat64 edge cases ────────────────────────────────────────────
// TestToFloat64FromFloat64 checks that a float64 input is returned unchanged.
func TestToFloat64FromFloat64(t *testing.T) {
	got := toFloat64(float64(3.14), 0)
	if got != 3.14 {
		t.Fatalf("expected 3.14, got %f", got)
	}
}
// TestToFloat64FromInt checks that an int input is converted to its float64 value.
func TestToFloat64FromInt(t *testing.T) {
	got := toFloat64(42, 0)
	if got != 42.0 {
		t.Fatalf("expected 42.0, got %f", got)
	}
}
// TestToFloat64FromInt64 checks that an int64 input is converted to its float64 value.
func TestToFloat64FromInt64(t *testing.T) {
	got := toFloat64(int64(100), 0)
	if got != 100.0 {
		t.Fatalf("expected 100.0, got %f", got)
	}
}
// TestToFloat64FromStringDefault checks that a string input is not parsed and
// the supplied default is returned instead.
func TestToFloat64FromStringDefault(t *testing.T) {
	got := toFloat64("42", 99.0)
	if got != 99.0 {
		t.Fatalf("expected default 99.0, got %f", got)
	}
}
// TestToFloat64FromNilDefault checks that a nil input yields the supplied default.
func TestToFloat64FromNilDefault(t *testing.T) {
	got := toFloat64(nil, 5.5)
	if got != 5.5 {
		t.Fatalf("expected default 5.5, got %f", got)
	}
}
// TestToFloat64FromBoolDefault checks that a bool input yields the supplied default.
func TestToFloat64FromBoolDefault(t *testing.T) {
	got := toFloat64(true, 1.0)
	if got != 1.0 {
		t.Fatalf("expected default 1.0, got %f", got)
	}
}
// ─── toInt64 edge cases ──────────────────────────────────────────────
// TestToInt64FromFloat64 checks that a float64 input of 42.9 converts to 42.
func TestToInt64FromFloat64(t *testing.T) {
	got := toInt64(float64(42.9), 0)
	if got != 42 {
		t.Fatalf("expected 42, got %d", got)
	}
}
// TestToInt64FromInt checks that an int input keeps its value.
func TestToInt64FromInt(t *testing.T) {
	got := toInt64(42, 0)
	if got != 42 {
		t.Fatalf("expected 42, got %d", got)
	}
}
// TestToInt64FromInt64 checks that an int64 input is returned unchanged.
func TestToInt64FromInt64(t *testing.T) {
	got := toInt64(int64(100), 0)
	if got != 100 {
		t.Fatalf("expected 100, got %d", got)
	}
}
// TestToInt64FromStringDefault checks that a string input is not parsed and
// the supplied default is returned instead.
func TestToInt64FromStringDefault(t *testing.T) {
	got := toInt64("42", 99)
	if got != 99 {
		t.Fatalf("expected default 99, got %d", got)
	}
}
// TestToInt64FromNilDefault checks that a nil input yields the supplied default.
func TestToInt64FromNilDefault(t *testing.T) {
	got := toInt64(nil, 7)
	if got != 7 {
		t.Fatalf("expected default 7, got %d", got)
	}
}
// ─── BuildPowHeader edge cases ───────────────────────────────────────
// TestBuildPowHeaderBasicChallenge checks that BuildPowHeader succeeds and
// returns a non-empty header for a fully populated challenge.
func TestBuildPowHeaderBasicChallenge(t *testing.T) {
	header, err := BuildPowHeader(map[string]any{
		"algorithm":   "DeepSeekHashV1",
		"challenge":   "abc123",
		"salt":        "salt456",
		"signature":   "sig789",
		"target_path": "/path",
	}, 42)
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case header == "":
		t.Fatal("expected non-empty result")
	}
}
// TestBuildPowHeaderEmptyChallenge checks that BuildPowHeader still succeeds
// and returns a non-empty header when the challenge map has no fields.
func TestBuildPowHeaderEmptyChallenge(t *testing.T) {
	header, err := BuildPowHeader(map[string]any{}, 0)
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case header == "":
		// Missing fields are expected to be encoded rather than rejected.
		t.Fatal("expected non-empty result for empty challenge")
	}
}
// ─── PowSolver pool size ─────────────────────────────────────────────
// TestPowPoolSizeFromEnvDefault checks that an empty DS2API_POW_POOL_SIZE
// yields a pool size of at least 1.
func TestPowPoolSizeFromEnvDefault(t *testing.T) {
	t.Setenv("DS2API_POW_POOL_SIZE", "")
	if size := powPoolSizeFromEnv(); size < 1 {
		t.Fatalf("expected positive default pool size, got %d", size)
	}
}
// TestPowPoolSizeFromEnvInvalid checks that a non-numeric DS2API_POW_POOL_SIZE
// still yields a pool size of at least 1.
func TestPowPoolSizeFromEnvInvalid(t *testing.T) {
	t.Setenv("DS2API_POW_POOL_SIZE", "abc")
	if size := powPoolSizeFromEnv(); size < 1 {
		t.Fatalf("expected positive default for invalid, got %d", size)
	}
}
// TestPowPoolSizeFromEnvSpecificValue checks that DS2API_POW_POOL_SIZE=5 is
// honored exactly.
func TestPowPoolSizeFromEnvSpecificValue(t *testing.T) {
	t.Setenv("DS2API_POW_POOL_SIZE", "5")
	if size := powPoolSizeFromEnv(); size != 5 {
		t.Fatalf("expected 5, got %d", size)
	}
}
// ─── NewClient ───────────────────────────────────────────────────────
// TestNewClientInitialState checks that NewClient populates powSolver even
// when both constructor arguments are nil.
func TestNewClientInitialState(t *testing.T) {
	c := NewClient(nil, nil)
	if solver := c.powSolver; solver == nil {
		t.Fatal("expected powSolver to be initialized")
	}
}
// TestNewClientPreloadPowIdempotent checks that PreloadPow can be called twice
// on the same client without returning an error either time.
func TestNewClientPreloadPowIdempotent(t *testing.T) {
	t.Setenv("DS2API_POW_POOL_SIZE", "1")
	c := NewClient(nil, nil)
	for _, attempt := range []string{"first", "second"} {
		if err := c.PreloadPow(context.Background()); err != nil {
			t.Fatalf("%s preload failed: %v", attempt, err)
		}
	}
}
// ─── PowSolver init and module pool ──────────────────────────────────
// TestPowSolverPoolSizeMatchesEnv checks that, after init, the solver's pool
// has the capacity requested through DS2API_POW_POOL_SIZE.
func TestPowSolverPoolSizeMatchesEnv(t *testing.T) {
	t.Setenv("DS2API_POW_POOL_SIZE", "2")
	s := NewPowSolver("test.wasm")
	err := s.init(context.Background())
	if err != nil {
		t.Fatalf("init failed: %v", err)
	}
	if capacity := cap(s.pool); capacity != 2 {
		t.Fatalf("expected pool capacity 2, got %d", capacity)
	}
}

View File

@@ -92,7 +92,7 @@ func cors(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return

View File

@@ -0,0 +1,140 @@
package sse
import (
"io"
"net/http"
"strings"
"testing"
)
// ─── CollectStream edge cases ────────────────────────────────────────
func makeHTTPResponse(body string) *http.Response {
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func TestCollectStreamEmpty(t *testing.T) {
resp := makeHTTPResponse("")
result := CollectStream(resp, false, false)
if result.Text != "" || result.Thinking != "" {
t.Fatalf("expected empty result, got text=%q think=%q", result.Text, result.Thinking)
}
}
func TestCollectStreamTextOnly(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\" World\"}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, false, false)
if result.Text != "Hello World" {
t.Fatalf("expected 'Hello World', got %q", result.Text)
}
if result.Thinking != "" {
t.Fatalf("expected no thinking, got %q", result.Thinking)
}
}
func TestCollectStreamThinkingAndText(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/thinking_content\",\"v\":\"Thinking...\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\"Answer\"}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, true, true)
if result.Thinking != "Thinking..." {
t.Fatalf("expected 'Thinking...', got %q", result.Thinking)
}
if result.Text != "Answer" {
t.Fatalf("expected 'Answer', got %q", result.Text)
}
}
func TestCollectStreamOnlyThinking(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/thinking_content\",\"v\":\"Only thinking\"}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, true, true)
if result.Thinking != "Only thinking" {
t.Fatalf("expected 'Only thinking', got %q", result.Thinking)
}
if result.Text != "" {
t.Fatalf("expected empty text, got %q", result.Text)
}
}
func TestCollectStreamSkipsInvalidLines(t *testing.T) {
resp := makeHTTPResponse(
"event: comment\n" +
"data: invalid_json\n" +
"data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, false, false)
if result.Text != "valid" {
t.Fatalf("expected 'valid', got %q", result.Text)
}
}
func TestCollectStreamWithFragments(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"Think\"}]}\n" +
"data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"Done\"}]}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, true, true)
if result.Thinking != "Think" {
t.Fatalf("expected 'Think' thinking, got %q", result.Thinking)
}
if result.Text != "Done" {
t.Fatalf("expected 'Done' text, got %q", result.Text)
}
}
func TestCollectStreamWithCitation(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\"[citation:1] cited text\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\" more\"}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, false, false)
// CollectStream does NOT filter citations (that's done by the adapters)
// So citations are passed through as-is
if !strings.Contains(result.Text, "[citation:1]") {
t.Fatalf("expected citations to be passed through, got %q", result.Text)
}
if result.Text != "Hello[citation:1] cited text more" {
t.Fatalf("expected full text with citation, got %q", result.Text)
}
}
func TestCollectStreamMultipleThinkingChunks(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/thinking_content\",\"v\":\"part1\"}\n" +
"data: {\"p\":\"response/thinking_content\",\"v\":\" part2\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\"answer\"}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, true, true)
if result.Thinking != "part1 part2" {
t.Fatalf("expected 'part1 part2', got %q", result.Thinking)
}
}
func TestCollectStreamStatusFinished(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" +
"data: {\"p\":\"response/status\",\"v\":\"FINISHED\"}\n",
)
result := CollectStream(resp, false, false)
if result.Text != "Hello" {
t.Fatalf("expected 'Hello', got %q", result.Text)
}
}

View File

@@ -0,0 +1,70 @@
package sse
import "testing"
// TestParseDeepSeekContentLineNotParsed checks that a line without the data
// prefix is reported as not parsed and leaves the current type untouched.
func TestParseDeepSeekContentLineNotParsed(t *testing.T) {
	line := []byte("not a data line")
	got := ParseDeepSeekContentLine(line, false, "text")
	switch {
	case got.Parsed:
		t.Fatal("expected not parsed")
	case got.NextType != "text":
		t.Fatalf("expected nextType preserved, got %q", got.NextType)
	}
}
// TestParseDeepSeekContentLinePreservesNextType checks that a thinking_content
// line parses to exactly one part of type "thinking" without stopping.
func TestParseDeepSeekContentLinePreservesNextType(t *testing.T) {
	line := []byte(`data: {"p":"response/thinking_content","v":"think"}`)
	got := ParseDeepSeekContentLine(line, true, "thinking")
	if parsedNonStop := got.Parsed && !got.Stop; !parsedNonStop {
		t.Fatalf("expected parsed non-stop: %#v", got)
	}
	if len(got.Parts) != 1 || got.Parts[0].Type != "thinking" {
		t.Fatalf("unexpected parts: %#v", got.Parts)
	}
}
// TestParseDeepSeekContentLineFragmentSwitchType checks that an appended
// RESPONSE fragment switches NextType from "thinking" to "text".
func TestParseDeepSeekContentLineFragmentSwitchType(t *testing.T) {
	line := []byte(`data: {"p":"response/fragments","o":"APPEND","v":[{"type":"RESPONSE","content":"hi"}]}`)
	got := ParseDeepSeekContentLine(line, true, "thinking")
	if parsedNonStop := got.Parsed && !got.Stop; !parsedNonStop {
		t.Fatalf("expected parsed non-stop: %#v", got)
	}
	if got.NextType != "text" {
		t.Fatalf("expected nextType text after RESPONSE fragment, got %q", got.NextType)
	}
}
// TestParseDeepSeekContentLineContentFilterMessage checks that a
// code=content_filter payload sets ContentFilter and a non-empty ErrorMessage.
func TestParseDeepSeekContentLineContentFilterMessage(t *testing.T) {
	line := []byte(`data: {"code":"content_filter"}`)
	got := ParseDeepSeekContentLine(line, false, "text")
	switch {
	case !got.ContentFilter:
		t.Fatal("expected content filter flag")
	case got.ErrorMessage == "":
		t.Fatal("expected error message on content filter")
	}
}
// TestParseDeepSeekContentLineErrorObjectFormat checks that an {"error":{...}}
// payload is parsed, stops the stream and carries a non-empty ErrorMessage.
func TestParseDeepSeekContentLineErrorObjectFormat(t *testing.T) {
	line := []byte(`data: {"error":{"message":"rate limit","code":429}}`)
	got := ParseDeepSeekContentLine(line, false, "text")
	if parsedStop := got.Parsed && got.Stop; !parsedStop {
		t.Fatalf("expected parsed stop: %#v", got)
	}
	if got.ErrorMessage == "" {
		t.Fatal("expected non-empty error message")
	}
}
// TestParseDeepSeekContentLineInvalidJSON checks that a data line with
// malformed JSON is reported as not parsed.
func TestParseDeepSeekContentLineInvalidJSON(t *testing.T) {
	line := []byte("data: {broken")
	if got := ParseDeepSeekContentLine(line, false, "text"); got.Parsed {
		t.Fatal("expected not parsed for broken JSON")
	}
}
// TestParseDeepSeekContentLineEmptyBytes checks that an empty byte slice is
// reported as not parsed.
func TestParseDeepSeekContentLineEmptyBytes(t *testing.T) {
	if got := ParseDeepSeekContentLine([]byte{}, false, "text"); got.Parsed {
		t.Fatal("expected not parsed for empty bytes")
	}
}

View File

@@ -0,0 +1,631 @@
package sse
import "testing"
// ─── ParseDeepSeekSSELine edge cases ─────────────────────────────────
// TestParseDeepSeekSSELineEmptyLine checks that an empty line is not parsed.
func TestParseDeepSeekSSELineEmptyLine(t *testing.T) {
	if _, _, parsed := ParseDeepSeekSSELine([]byte("")); parsed {
		t.Fatal("expected not parsed for empty line")
	}
}
// TestParseDeepSeekSSELineNoDataPrefix checks that an "event:" line is not parsed.
func TestParseDeepSeekSSELineNoDataPrefix(t *testing.T) {
	if _, _, parsed := ParseDeepSeekSSELine([]byte("event: message")); parsed {
		t.Fatal("expected not parsed for non-data line")
	}
}
// TestParseDeepSeekSSELineInvalidJSON checks that a data line with malformed
// JSON is not parsed.
func TestParseDeepSeekSSELineInvalidJSON(t *testing.T) {
	if _, _, parsed := ParseDeepSeekSSELine([]byte("data: {invalid json")); parsed {
		t.Fatal("expected not parsed for invalid JSON")
	}
}
// TestParseDeepSeekSSELineWhitespaceOnly checks that a line made only of
// whitespace is not parsed.
func TestParseDeepSeekSSELineWhitespaceOnly(t *testing.T) {
	if _, _, parsed := ParseDeepSeekSSELine([]byte(" ")); parsed {
		t.Fatal("expected not parsed for whitespace-only line")
	}
}
// TestParseDeepSeekSSELineDataWithExtraSpaces checks that surrounding
// whitespace on a data line does not prevent the JSON payload from parsing.
func TestParseDeepSeekSSELineDataWithExtraSpaces(t *testing.T) {
	line := []byte(`data: {"v":"hello"} `)
	payload, done, parsed := ParseDeepSeekSSELine(line)
	if parsedChunk := parsed && !done; !parsedChunk {
		t.Fatalf("expected parsed chunk for spaced data line")
	}
	if payload["v"] != "hello" {
		t.Fatalf("unexpected chunk: %#v", payload)
	}
}
// ─── shouldSkipPath edge cases ───────────────────────────────────────
// TestShouldSkipPathQuasiStatus checks that the quasi_status path is skipped.
func TestShouldSkipPathQuasiStatus(t *testing.T) {
	if skipped := shouldSkipPath("response/quasi_status"); !skipped {
		t.Fatal("expected skip for quasi_status path")
	}
}
// TestShouldSkipPathElapsedSecs checks that the elapsed_secs path is skipped.
func TestShouldSkipPathElapsedSecs(t *testing.T) {
	if skipped := shouldSkipPath("response/elapsed_secs"); !skipped {
		t.Fatal("expected skip for elapsed_secs path")
	}
}
// TestShouldSkipPathTokenUsage checks that the token_usage path is skipped.
func TestShouldSkipPathTokenUsage(t *testing.T) {
	if skipped := shouldSkipPath("response/token_usage"); !skipped {
		t.Fatal("expected skip for token_usage path")
	}
}
// TestShouldSkipPathPendingFragment checks that the pending_fragment path is skipped.
func TestShouldSkipPathPendingFragment(t *testing.T) {
	if skipped := shouldSkipPath("response/pending_fragment"); !skipped {
		t.Fatal("expected skip for pending_fragment path")
	}
}
// TestShouldSkipPathConversationMode checks that the conversation_mode path is skipped.
func TestShouldSkipPathConversationMode(t *testing.T) {
	if skipped := shouldSkipPath("response/conversation_mode"); !skipped {
		t.Fatal("expected skip for conversation_mode path")
	}
}
// TestShouldSkipPathSearchStatus checks that the search_status path is skipped.
func TestShouldSkipPathSearchStatus(t *testing.T) {
	if skipped := shouldSkipPath("response/search_status"); !skipped {
		t.Fatal("expected skip for search_status path")
	}
}
// TestShouldSkipPathFragmentStatus checks that the status path of fragments
// addressed by the negative indices -1, -2 and -3 is skipped.
func TestShouldSkipPathFragmentStatus(t *testing.T) {
	for _, index := range []string{"-1", "-2", "-3"} {
		if !shouldSkipPath("response/fragments/" + index + "/status") {
			t.Fatalf("expected skip for fragment %s status", index)
		}
	}
}
// TestShouldSkipPathRegularContent checks that the content and
// thinking_content paths are not skipped.
func TestShouldSkipPathRegularContent(t *testing.T) {
	for _, field := range []string{"content", "thinking_content"} {
		if shouldSkipPath("response/" + field) {
			t.Fatalf("expected not skip for %s path", field)
		}
	}
}
// ─── ParseSSEChunkForContent edge cases ──────────────────────────────
// TestParseSSEChunkForContentNoVField checks that a chunk without "v" yields
// no parts, is not finished and keeps the current type.
func TestParseSSEChunkForContentNoVField(t *testing.T) {
	chunk := map[string]any{"p": "response/content"}
	parts, finished, nextType := ParseSSEChunkForContent(chunk, false, "text")
	switch {
	case finished:
		t.Fatal("expected not finished")
	case len(parts) != 0:
		t.Fatalf("expected no parts when v is missing, got %#v", parts)
	case nextType != "text":
		t.Fatalf("expected type preserved, got %q", nextType)
	}
}
// TestParseSSEChunkForContentSkippedPath checks that a chunk on a skipped path
// produces no parts, does not finish and keeps the current type.
func TestParseSSEChunkForContentSkippedPath(t *testing.T) {
	chunk := map[string]any{
		"p": "response/token_usage",
		"v": "some data",
	}
	parts, finished, nextType := ParseSSEChunkForContent(chunk, false, "text")
	switch {
	case finished || len(parts) > 0:
		t.Fatalf("expected skipped path to produce no output")
	case nextType != "text":
		t.Fatalf("expected type preserved for skipped path")
	}
}
// TestParseSSEChunkForContentFinishedStatus checks that response/status with
// value FINISHED marks the stream finished and emits no parts.
func TestParseSSEChunkForContentFinishedStatus(t *testing.T) {
	chunk := map[string]any{
		"p": "response/status",
		"v": "FINISHED",
	}
	parts, finished, _ := ParseSSEChunkForContent(chunk, false, "text")
	switch {
	case !finished:
		t.Fatal("expected finished on status FINISHED")
	case len(parts) != 0:
		t.Fatalf("expected no parts on finished, got %d", len(parts))
	}
}
// TestParseSSEChunkForContentStatusNotFinished checks that a response/status
// value other than FINISHED does not finish and is emitted as a single part.
func TestParseSSEChunkForContentStatusNotFinished(t *testing.T) {
	chunk := map[string]any{
		"p": "response/status",
		"v": "IN_PROGRESS",
	}
	parts, finished, _ := ParseSSEChunkForContent(chunk, false, "text")
	switch {
	case finished:
		t.Fatal("expected not finished for non-FINISHED status")
	case len(parts) != 1 || parts[0].Text != "IN_PROGRESS":
		t.Fatalf("expected content for non-FINISHED status, got %#v", parts)
	}
}
// TestParseSSEChunkForContentEmptyStringV checks that an empty string value on
// the content path emits no parts and does not finish.
func TestParseSSEChunkForContentEmptyStringV(t *testing.T) {
	chunk := map[string]any{
		"p": "response/content",
		"v": "",
	}
	parts, finished, _ := ParseSSEChunkForContent(chunk, false, "text")
	switch {
	case finished:
		t.Fatal("expected not finished")
	case len(parts) != 0:
		t.Fatalf("expected no parts for empty string v, got %#v", parts)
	}
}
// TestParseSSEChunkForContentFinishedOnEmptyPath checks that FINISHED on an
// empty path marks the stream finished and emits no parts.
func TestParseSSEChunkForContentFinishedOnEmptyPath(t *testing.T) {
	chunk := map[string]any{
		"p": "",
		"v": "FINISHED",
	}
	parts, finished, _ := ParseSSEChunkForContent(chunk, false, "text")
	switch {
	case !finished:
		t.Fatal("expected finished on empty path with FINISHED value")
	case len(parts) != 0:
		t.Fatalf("expected no parts on finished")
	}
}
// TestParseSSEChunkForContentFinishedOnStatusPath checks that FINISHED on the
// bare "status" path marks the stream finished.
func TestParseSSEChunkForContentFinishedOnStatusPath(t *testing.T) {
	chunk := map[string]any{
		"p": "status",
		"v": "FINISHED",
	}
	if _, finished, _ := ParseSSEChunkForContent(chunk, false, "text"); !finished {
		t.Fatal("expected finished on status path with FINISHED value")
	}
}
// TestParseSSEChunkForContentThinkingPathEmptyPath checks that a path-less
// chunk inherits the current "thinking" type when thinking is enabled.
func TestParseSSEChunkForContentThinkingPathEmptyPath(t *testing.T) {
	chunk := map[string]any{"v": "some thought"}
	parts, _, nextType := ParseSSEChunkForContent(chunk, true, "thinking")
	switch {
	case len(parts) != 1 || parts[0].Type != "thinking":
		t.Fatalf("expected thinking part on empty path, got %#v", parts)
	case nextType != "thinking":
		t.Fatalf("expected nextType thinking, got %q", nextType)
	}
}
// TestParseSSEChunkForContentThinkingEnabledTextType checks that a path-less
// chunk stays "text" when the current type is text, even with thinking enabled.
func TestParseSSEChunkForContentThinkingEnabledTextType(t *testing.T) {
	chunk := map[string]any{"v": "text content"}
	parts, _, nextType := ParseSSEChunkForContent(chunk, true, "text")
	switch {
	case len(parts) != 1 || parts[0].Type != "text":
		t.Fatalf("expected text part when currentType=text, got %#v", parts)
	case nextType != "text":
		t.Fatalf("expected nextType text, got %q", nextType)
	}
}
// ─── ParseSSEChunkForContent: fragments path with THINK type ─────────
// TestParseSSEChunkForContentFragmentsAppendThink checks that appending a THINK
// fragment emits its content as a thinking part and switches the type to thinking.
func TestParseSSEChunkForContentFragmentsAppendThink(t *testing.T) {
	fragment := map[string]any{
		"type":    "THINK",
		"content": "深入思考...",
	}
	chunk := map[string]any{
		"p": "response/fragments",
		"o": "APPEND",
		"v": []any{fragment},
	}

	parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text")
	switch {
	case finished:
		t.Fatal("expected not finished")
	case nextType != "thinking":
		t.Fatalf("expected nextType thinking, got %q", nextType)
	case len(parts) != 1 || parts[0].Type != "thinking" || parts[0].Text != "深入思考...":
		t.Fatalf("unexpected parts: %#v", parts)
	}
}
// TestParseSSEChunkForContentFragmentsAppendEmptyContent checks that a RESPONSE
// fragment with empty content still switches the type to text but emits no parts.
func TestParseSSEChunkForContentFragmentsAppendEmptyContent(t *testing.T) {
	fragment := map[string]any{
		"type":    "RESPONSE",
		"content": "",
	}
	chunk := map[string]any{
		"p": "response/fragments",
		"o": "APPEND",
		"v": []any{fragment},
	}

	parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "thinking")
	switch {
	case finished:
		t.Fatal("expected not finished")
	case nextType != "text":
		t.Fatalf("expected nextType text, got %q", nextType)
	case len(parts) != 0:
		t.Fatalf("expected no parts for empty content, got %#v", parts)
	}
}
// TestParseSSEChunkForContentFragmentsAppendDefaultType checks that a fragment
// with an unrecognized type is emitted as a text part.
func TestParseSSEChunkForContentFragmentsAppendDefaultType(t *testing.T) {
	fragment := map[string]any{
		"type":    "UNKNOWN",
		"content": "some text",
	}
	chunk := map[string]any{
		"p": "response/fragments",
		"o": "APPEND",
		"v": []any{fragment},
	}

	parts, _, _ := ParseSSEChunkForContent(chunk, true, "text")
	if single := len(parts) == 1; !single || parts[0].Type != "text" {
		t.Fatalf("expected text type for unknown fragment type, got %#v", parts)
	}
}
// TestParseSSEChunkForContentFragmentsAppendNonArray checks that a string value
// on a fragments APPEND chunk is emitted as an ordinary string part.
func TestParseSSEChunkForContentFragmentsAppendNonArray(t *testing.T) {
	chunk := map[string]any{
		"p": "response/fragments",
		"o": "APPEND",
		"v": "not an array",
	}

	parts, finished, _ := ParseSSEChunkForContent(chunk, true, "text")
	switch {
	case finished:
		t.Fatal("expected not finished")
	case len(parts) != 1 || parts[0].Text != "not an array":
		// The value is not a fragment list, so it falls through as plain text.
		t.Fatalf("unexpected parts: %#v", parts)
	}
}
// TestParseSSEChunkForContentFragmentsAppendNonMap is a smoke test: a fragment
// array holding a non-map item must be handled without panicking. The returned
// values are intentionally not inspected.
func TestParseSSEChunkForContentFragmentsAppendNonMap(t *testing.T) {
	items := []any{"string item"}
	chunk := map[string]any{
		"p": "response/fragments",
		"o": "APPEND",
		"v": items,
	}
	ParseSSEChunkForContent(chunk, false, "text")
}
// ─── ParseSSEChunkForContent: response path with nested fragment ─────
// TestParseSSEChunkForContentResponsePathFragmentsAppend checks that a nested
// fragments APPEND under the "response" path with a THINKING fragment switches
// the type to thinking.
func TestParseSSEChunkForContentResponsePathFragmentsAppend(t *testing.T) {
	fragment := map[string]any{"type": "THINKING"}
	nested := map[string]any{
		"p": "fragments",
		"o": "APPEND",
		"v": []any{fragment},
	}
	chunk := map[string]any{
		"p": "response",
		"v": []any{nested},
	}

	if _, _, nextType := ParseSSEChunkForContent(chunk, true, "text"); nextType != "thinking" {
		t.Fatalf("expected nextType thinking from response path fragments, got %q", nextType)
	}
}
// TestParseSSEChunkForContentResponsePathResponseFragment checks that a nested
// fragments APPEND under the "response" path with a RESPONSE fragment switches
// the type from thinking to text.
func TestParseSSEChunkForContentResponsePathResponseFragment(t *testing.T) {
	fragment := map[string]any{"type": "RESPONSE"}
	nested := map[string]any{
		"p": "fragments",
		"o": "APPEND",
		"v": []any{fragment},
	}
	chunk := map[string]any{
		"p": "response",
		"v": []any{nested},
	}

	if _, _, nextType := ParseSSEChunkForContent(chunk, true, "thinking"); nextType != "text" {
		t.Fatalf("expected nextType text for RESPONSE fragment, got %q", nextType)
	}
}
// ─── ParseSSEChunkForContent: map value with wrapped response ────────
// TestParseSSEChunkForContentMapValueWithFragments checks that a map value
// wrapping response.fragments yields one part per fragment, in order, typed by
// the fragment type, and that the final type follows the last fragment.
func TestParseSSEChunkForContentMapValueWithFragments(t *testing.T) {
	fragments := []any{
		map[string]any{
			"type":    "THINK",
			"content": "思考...",
		},
		map[string]any{
			"type":    "RESPONSE",
			"content": "回答...",
		},
	}
	chunk := map[string]any{
		"v": map[string]any{
			"response": map[string]any{
				"fragments": fragments,
			},
		},
	}

	parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text")
	switch {
	case finished:
		t.Fatal("expected not finished")
	case nextType != "text":
		t.Fatalf("expected nextType text after RESPONSE, got %q", nextType)
	case len(parts) != 2:
		t.Fatalf("expected 2 parts, got %d: %#v", len(parts), parts)
	}

	if first := parts[0]; first.Type != "thinking" || first.Text != "思考..." {
		t.Fatalf("first part mismatch: %#v", parts[0])
	}
	if second := parts[1]; second.Type != "text" || second.Text != "回答..." {
		t.Fatalf("second part mismatch: %#v", parts[1])
	}
}
// TestParseSSEChunkForContentMapValueDirectFragments checks that a map value
// with a top-level "fragments" list emits the RESPONSE fragment as a text part.
func TestParseSSEChunkForContentMapValueDirectFragments(t *testing.T) {
	fragment := map[string]any{
		"type":    "RESPONSE",
		"content": "直接回答",
	}
	chunk := map[string]any{
		"v": map[string]any{
			"fragments": []any{fragment},
		},
	}

	parts, _, _ := ParseSSEChunkForContent(chunk, false, "text")
	if len(parts) != 1 || parts[0].Text != "直接回答" || parts[0].Type != "text" {
		t.Fatalf("unexpected parts for direct fragments: %#v", parts)
	}
}
// TestParseSSEChunkForContentMapValueUnknownType checks that a fragment with an
// unrecognized type in a map value falls back to the text part type.
func TestParseSSEChunkForContentMapValueUnknownType(t *testing.T) {
	fragment := map[string]any{
		"type":    "CUSTOM",
		"content": "custom content",
	}
	chunk := map[string]any{
		"v": map[string]any{
			"fragments": []any{fragment},
		},
	}

	parts, _, _ := ParseSSEChunkForContent(chunk, false, "text")
	if len(parts) != 1 || parts[0].Type != "text" {
		t.Fatalf("expected partType fallback for unknown type, got %#v", parts)
	}
}
// TestParseSSEChunkForContentMapValueEmptyFragmentContent checks that a
// fragment with empty content in a map value emits no parts.
func TestParseSSEChunkForContentMapValueEmptyFragmentContent(t *testing.T) {
	fragment := map[string]any{
		"type":    "RESPONSE",
		"content": "",
	}
	chunk := map[string]any{
		"v": map[string]any{
			"fragments": []any{fragment},
		},
	}

	parts, _, _ := ParseSSEChunkForContent(chunk, false, "text")
	if count := len(parts); count != 0 {
		t.Fatalf("expected no parts for empty fragment content, got %#v", parts)
	}
}
// ─── ParseSSEChunkForContent: fragments/-1/content path ──────────────
// TestParseSSEChunkForContentFragmentContentPathInheritsType checks that content
// written to response/fragments/-1/content takes on the current type.
func TestParseSSEChunkForContentFragmentContentPathInheritsType(t *testing.T) {
	chunk := map[string]any{
		"p": "response/fragments/-1/content",
		"v": "继续思考",
	}

	parts, _, _ := ParseSSEChunkForContent(chunk, true, "thinking")
	if len(parts) != 1 || parts[0].Type != "thinking" {
		t.Fatalf("expected inherited thinking type, got %#v", parts)
	}
}
// ─── IsCitation edge cases ───────────────────────────────────────────
// TestIsCitationWithLeadingWhitespace checks that leading whitespace before the
// citation marker does not prevent detection.
func TestIsCitationWithLeadingWhitespace(t *testing.T) {
	if detected := IsCitation(" [citation:2] text"); !detected {
		t.Fatal("expected citation true with leading whitespace")
	}
}
// TestIsCitationEmpty checks that the empty string is not treated as a citation.
func TestIsCitationEmpty(t *testing.T) {
	detected := IsCitation("")
	if detected {
		t.Fatal("expected citation false for empty string")
	}
}
// TestIsCitationSimilarPrefix checks that a look-alike "[cite:" prefix is not
// mistaken for a citation marker.
func TestIsCitationSimilarPrefix(t *testing.T) {
	const lookAlike = "[cite:1] text"
	if detected := IsCitation(lookAlike); detected {
		t.Fatal("expected citation false for [cite: prefix")
	}
}
// ─── extractContentRecursive edge cases ──────────────────────────────
// TestExtractContentRecursiveFinishedStatus checks that a {"p":"status",
// "v":"FINISHED"} item reports finished and contributes no parts.
func TestExtractContentRecursiveFinishedStatus(t *testing.T) {
	statusItem := map[string]any{"p": "status", "v": "FINISHED"}

	parts, finished := extractContentRecursive([]any{statusItem}, "text")
	switch {
	case !finished:
		t.Fatal("expected finished on status FINISHED")
	case len(parts) != 0:
		t.Fatalf("expected no parts, got %#v", parts)
	}
}
// TestExtractContentRecursiveSkipsPath checks that a "token_usage" item is
// ignored: it neither finishes the stream nor yields parts.
func TestExtractContentRecursiveSkipsPath(t *testing.T) {
	usageItem := map[string]any{"p": "token_usage", "v": "data"}

	parts, finished := extractContentRecursive([]any{usageItem}, "text")
	switch {
	case finished:
		t.Fatal("expected not finished")
	case len(parts) != 0:
		t.Fatalf("expected no parts for skipped path, got %#v", parts)
	}
}
// TestExtractContentRecursiveContentField checks that when an item carries a
// "content" field with type RESPONSE, that content (not "v") becomes a single
// text part.
func TestExtractContentRecursiveContentField(t *testing.T) {
	item := map[string]any{"p": "x", "v": "val", "content": "actual content", "type": "RESPONSE"}

	parts, _ := extractContentRecursive([]any{item}, "text")
	expected := len(parts) == 1 && parts[0].Text == "actual content" && parts[0].Type == "text"
	if !expected {
		t.Fatalf("unexpected parts: %#v", parts)
	}
}
// TestExtractContentRecursiveContentFieldThinkType checks that an item whose
// "type" is THINK is emitted as a "thinking" part.
func TestExtractContentRecursiveContentFieldThinkType(t *testing.T) {
	item := map[string]any{"p": "x", "v": "val", "content": "think text", "type": "THINK"}

	parts, _ := extractContentRecursive([]any{item}, "text")
	if isThinking := len(parts) == 1 && parts[0].Type == "thinking"; !isThinking {
		t.Fatalf("expected thinking type for THINK content, got %#v", parts)
	}
}
// TestExtractContentRecursiveThinkingPath checks that a "thinking_content"
// path produces one thinking part with the item's value as text.
func TestExtractContentRecursiveThinkingPath(t *testing.T) {
	item := map[string]any{"p": "thinking_content", "v": "deep thought"}

	parts, _ := extractContentRecursive([]any{item}, "text")
	expected := len(parts) == 1 && parts[0].Type == "thinking" && parts[0].Text == "deep thought"
	if !expected {
		t.Fatalf("unexpected parts for thinking path: %#v", parts)
	}
}
// TestExtractContentRecursiveContentPath checks that a "content" path is typed
// "text" even when the default part type passed in is "thinking".
func TestExtractContentRecursiveContentPath(t *testing.T) {
	item := map[string]any{"p": "content", "v": "text content"}

	parts, _ := extractContentRecursive([]any{item}, "thinking")
	if isText := len(parts) == 1 && parts[0].Type == "text"; !isText {
		t.Fatalf("expected text type for content path, got %#v", parts)
	}
}
// TestExtractContentRecursiveResponsePath checks that a "response" path is
// typed "text" even when the default part type passed in is "thinking".
func TestExtractContentRecursiveResponsePath(t *testing.T) {
	item := map[string]any{"p": "response", "v": "text content"}

	parts, _ := extractContentRecursive([]any{item}, "thinking")
	if isText := len(parts) == 1 && parts[0].Type == "text"; !isText {
		t.Fatalf("expected text type for response path, got %#v", parts)
	}
}
// TestExtractContentRecursiveFragmentsPath checks that a string value on the
// "fragments" path is typed "text" even when the default type is "thinking".
func TestExtractContentRecursiveFragmentsPath(t *testing.T) {
	item := map[string]any{"p": "fragments", "v": "fragment text"}

	parts, _ := extractContentRecursive([]any{item}, "thinking")
	if isText := len(parts) == 1 && parts[0].Type == "text"; !isText {
		t.Fatalf("expected text type for fragments path, got %#v", parts)
	}
}
// TestExtractContentRecursiveNestedArrayWithTypes checks that an array value
// under "fragments" is flattened in order: a THINKING map, a RESPONSE map and
// a bare string become thinking, text and text parts respectively.
func TestExtractContentRecursiveNestedArrayWithTypes(t *testing.T) {
	nested := []any{
		map[string]any{"content": "thought", "type": "THINKING"},
		map[string]any{"content": "answer", "type": "RESPONSE"},
		"raw string",
	}
	items := []any{map[string]any{"p": "fragments", "v": nested}}

	parts, _ := extractContentRecursive(items, "text")
	if len(parts) != 3 {
		t.Fatalf("expected 3 parts, got %d: %#v", len(parts), parts)
	}

	wanted := []struct {
		ordinal  string
		partType string
		text     string
	}{
		{"first", "thinking", "thought"},
		{"second", "text", "answer"},
		{"third", "text", "raw string"},
	}
	for i, w := range wanted {
		if parts[i].Type != w.partType || parts[i].Text != w.text {
			t.Fatalf("%s part mismatch: %#v", w.ordinal, parts[i])
		}
	}
}
// TestExtractContentRecursiveEmptyContentSkipped checks that a nested fragment
// with empty content is dropped rather than emitted as an empty part.
func TestExtractContentRecursiveEmptyContentSkipped(t *testing.T) {
	nested := []any{
		map[string]any{"content": "", "type": "RESPONSE"},
	}
	items := []any{map[string]any{"p": "fragments", "v": nested}}

	parts, _ := extractContentRecursive(items, "text")
	if len(parts) == 0 {
		return
	}
	t.Fatalf("expected no parts for empty nested content, got %#v", parts)
}
// TestExtractContentRecursiveFinishedString checks that the literal value
// "FINISHED" appearing on a non-status path ("content") is not emitted as text.
func TestExtractContentRecursiveFinishedString(t *testing.T) {
	item := map[string]any{"p": "content", "v": "FINISHED"}

	// The sentinel string must be skipped even though the path is a content path.
	parts, _ := extractContentRecursive([]any{item}, "text")
	if count := len(parts); count != 0 {
		t.Fatalf("expected FINISHED string to be skipped, got %#v", parts)
	}
}
// TestExtractContentRecursiveNoVField checks that an item without a "v" key
// yields no parts.
func TestExtractContentRecursiveNoVField(t *testing.T) {
	withoutValue := map[string]any{"p": "content"}

	parts, _ := extractContentRecursive([]any{withoutValue}, "text")
	if count := len(parts); count != 0 {
		t.Fatalf("expected no parts for missing v field, got %#v", parts)
	}
}
// TestExtractContentRecursiveNonMapItem checks that top-level items which are
// not maps (a string and an int here) are ignored.
func TestExtractContentRecursiveNonMapItem(t *testing.T) {
	var items []any
	items = append(items, "just a string", 42)

	parts, _ := extractContentRecursive(items, "text")
	if count := len(parts); count != 0 {
		t.Fatalf("expected no parts for non-map items, got %#v", parts)
	}
}

View File

@@ -0,0 +1,177 @@
package sse
import (
"context"
"io"
"strings"
"testing"
)
// TestStartParsedLinePumpEmptyBody checks that pumping an empty reader closes
// the results channel without emitting anything and reports no error.
func TestStartParsedLinePumpEmptyBody(t *testing.T) {
	results, done := StartParsedLinePump(context.Background(), strings.NewReader(""), false, "text")

	var collected []LineResult
	for item := range results {
		collected = append(collected, item)
	}

	err := <-done
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got := len(collected); got != 0 {
		t.Fatalf("expected no results for empty body, got %d", got)
	}
}
// TestStartParsedLinePumpMultipleLines feeds a thinking line, a content line
// and a [DONE] marker through the pump and checks that at least three results
// arrive, the first one starts with a thinking part, and the last one is a
// stop result.
func TestStartParsedLinePumpMultipleLines(t *testing.T) {
	body := strings.NewReader(
		"data: {\"p\":\"response/thinking_content\",\"v\":\"think\"}\n" +
			"data: {\"p\":\"response/content\",\"v\":\"text\"}\n" +
			"data: [DONE]\n",
	)
	results, done := StartParsedLinePump(context.Background(), body, true, "thinking")
	collected := make([]LineResult, 0)
	for r := range results {
		collected = append(collected, r)
	}
	if err := <-done; err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(collected) < 3 {
		t.Fatalf("expected at least 3 results, got %d", len(collected))
	}
	// Guard the index below so a regression fails the test with a readable
	// message instead of panicking on an empty Parts slice.
	if len(collected[0].Parts) == 0 {
		t.Fatalf("expected first result to carry parts, got %#v", collected[0])
	}
	// First should be thinking
	if collected[0].Parts[0].Type != "thinking" {
		t.Fatalf("expected first part thinking, got %q", collected[0].Parts[0].Type)
	}
	// Last should be stop
	last := collected[len(collected)-1]
	if !last.Stop {
		t.Fatal("expected last result to be stop")
	}
}
// TestStartParsedLinePumpTypeTracking checks that the pump carries the current
// fragment type across lines: content appended via "fragments/-1/content"
// takes the type of the most recently APPENDed fragment (THINK, then RESPONSE).
func TestStartParsedLinePumpTypeTracking(t *testing.T) {
	body := strings.NewReader(
		"data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"思\"}]}\n" +
			"data: {\"p\":\"response/fragments/-1/content\",\"v\":\"考\"}\n" +
			"data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"答\"}]}\n" +
			"data: {\"p\":\"response/fragments/-1/content\",\"v\":\"案\"}\n" +
			"data: [DONE]\n",
	)
	results, done := StartParsedLinePump(context.Background(), body, true, "text")
	types := make([]string, 0)
	for r := range results {
		for _, p := range r.Parts {
			types = append(types, p.Type)
		}
	}
	// Surface pump failures instead of silently discarding them, matching the
	// sibling pump tests.
	if err := <-done; err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Should have: thinking, thinking, text, text
	expected := []string{"thinking", "thinking", "text", "text"}
	if len(types) != len(expected) {
		t.Fatalf("expected types %v, got %v", expected, types)
	}
	for i, want := range expected {
		if types[i] != want {
			t.Fatalf("type[%d] mismatch: want %q got %q (all=%v)", i, want, types[i], types)
		}
	}
}
// TestStartParsedLinePumpContextCancellation reads one result from a pump fed
// by a pipe, then cancels the context and closes the pipe, and checks that the
// pump terminates with either no error or context.Canceled.
func TestStartParsedLinePumpContextCancellation(t *testing.T) {
	reader, writer := io.Pipe()
	ctx, cancel := context.WithCancel(context.Background())
	results, done := StartParsedLinePump(ctx, reader, false, "text")

	// Feed a single line from another goroutine; the pipe is deliberately left
	// open so the pump is still running when the context is cancelled.
	go func() {
		_, _ = io.WriteString(writer, "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n")
	}()

	first := <-results
	if !(first.Parsed && len(first.Parts) != 0) {
		t.Fatalf("expected first parsed result, got %#v", first)
	}

	// Cancel first, then close the pipe so a blocked scanner.Scan() returns.
	cancel()
	_ = writer.Close()

	// Drain whatever is left so the pump can finish.
	for range results {
	}

	// Either outcome is acceptable depending on which event the pump saw first.
	switch err := <-done; err {
	case nil, context.Canceled:
	default:
		t.Fatalf("expected context.Canceled or nil error, got %v", err)
	}
}
// TestStartParsedLinePumpOnlyDONE checks that a body consisting solely of the
// [DONE] marker produces exactly one result and that it is a stop result.
func TestStartParsedLinePumpOnlyDONE(t *testing.T) {
	body := strings.NewReader("data: [DONE]\n")
	results, done := StartParsedLinePump(context.Background(), body, false, "text")
	collected := make([]LineResult, 0)
	for r := range results {
		collected = append(collected, r)
	}
	// Surface pump failures instead of silently discarding them, matching the
	// sibling pump tests.
	if err := <-done; err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(collected) != 1 {
		t.Fatalf("expected 1 result, got %d", len(collected))
	}
	if !collected[0].Stop {
		t.Fatal("expected stop on [DONE]")
	}
}
// TestStartParsedLinePumpNonSSELines checks that "event:" and comment lines do
// not count as parsed content: only the single data line with a payload is
// reported as a parsed result carrying parts.
func TestStartParsedLinePumpNonSSELines(t *testing.T) {
	body := strings.NewReader(
		"event: update\n" +
			": comment line\n" +
			"data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" +
			"data: [DONE]\n",
	)
	results, done := StartParsedLinePump(context.Background(), body, false, "text")
	var validCount int
	for r := range results {
		if r.Parsed && len(r.Parts) > 0 {
			validCount++
		}
	}
	// Surface pump failures instead of silently discarding them, matching the
	// sibling pump tests.
	if err := <-done; err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if validCount != 1 {
		t.Fatalf("expected 1 valid result, got %d", validCount)
	}
}
// TestStartParsedLinePumpThinkingDisabled runs the pump with thinking disabled
// over a stream containing both a thinking line and a content line, and checks
// that at least one content part is still produced.
func TestStartParsedLinePumpThinkingDisabled(t *testing.T) {
	body := strings.NewReader(
		"data: {\"p\":\"response/thinking_content\",\"v\":\"thought\"}\n" +
			"data: {\"p\":\"response/content\",\"v\":\"response\"}\n" +
			"data: [DONE]\n",
	)
	// This test only requires that some part is emitted with thinking disabled;
	// it intentionally does not pin how the thinking line is classified.
	results, done := StartParsedLinePump(context.Background(), body, false, "text")
	var parts []ContentPart
	for r := range results {
		parts = append(parts, r.Parts...)
	}
	// Surface pump failures instead of silently discarding them, matching the
	// sibling pump tests.
	if err := <-done; err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(parts) < 1 {
		t.Fatalf("expected at least 1 part, got %d", len(parts))
	}
}

View File

@@ -755,11 +755,15 @@ func (r *Runner) cases() []caseDef {
{ID: "healthz_ok", Run: r.caseHealthz},
{ID: "readyz_ok", Run: r.caseReadyz},
{ID: "models_openai", Run: r.caseModelsOpenAI},
{ID: "model_openai_by_id", Run: r.caseModelOpenAIByID},
{ID: "models_claude", Run: r.caseModelsClaude},
{ID: "admin_login_verify", Run: r.caseAdminLoginVerify},
{ID: "admin_queue_status", Run: r.caseAdminQueueStatus},
{ID: "chat_nonstream_basic", Run: r.caseChatNonstream},
{ID: "chat_stream_basic", Run: r.caseChatStream},
{ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream},
{ID: "responses_stream_basic", Run: r.caseResponsesStream},
{ID: "embeddings_contract", Run: r.caseEmbeddings},
{ID: "reasoner_stream", Run: r.caseReasonerStream},
{ID: "toolcall_nonstream", Run: r.caseToolcallNonstream},
{ID: "toolcall_stream", Run: r.caseToolcallStream},
@@ -817,6 +821,19 @@ func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error {
return nil
}
// caseModelOpenAIByID requests GET /v1/models/gpt-4o and asserts a 200 status,
// an "object" of "model", and an "id" of "deepseek-chat" in the JSON body.
// A request error is returned to the caller; assertion outcomes are recorded
// through cc.assert.
func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error {
	spec := requestSpec{
		Method:    http.MethodGet,
		Path:      "/v1/models/gpt-4o",
		Retryable: true,
	}
	resp, err := cc.request(ctx, spec)
	if err != nil {
		return err
	}
	cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))

	// A body that fails to decode leaves the map nil, so the field assertions
	// below fail and report the raw body.
	var model map[string]any
	_ = json.Unmarshal(resp.Body, &model)
	bodyDetail := fmt.Sprintf("body=%s", string(resp.Body))
	cc.assert("object_model", asString(model["object"]) == "model", bodyDetail)
	cc.assert("id_deepseek_chat", asString(model["id"]) == "deepseek-chat", bodyDetail)
	return nil
}
func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true})
if err != nil {
@@ -942,6 +959,115 @@ func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error {
return nil
}
// caseResponsesNonstream posts a non-streaming request to /v1/responses,
// asserts a 200 status, "object" == "response" and a non-empty id, and, when an
// id was returned, fetches GET /v1/responses/{id} and asserts that it also
// answers 200. Request errors are returned to the caller.
func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error {
	createSpec := requestSpec{
		Method:  http.MethodPost,
		Path:    "/v1/responses",
		Headers: map[string]string{"Authorization": "Bearer " + r.apiKey},
		Body: map[string]any{
			"model": "gpt-4o",
			"input": "请简要回答 hello",
		},
		Retryable: true,
	}
	created, err := cc.request(ctx, createSpec)
	if err != nil {
		return err
	}
	cc.assert("status_200", created.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", created.StatusCode))

	var decoded map[string]any
	_ = json.Unmarshal(created.Body, &decoded)
	bodyDetail := fmt.Sprintf("body=%s", string(created.Body))
	cc.assert("object_response", asString(decoded["object"]) == "response", bodyDetail)
	id := asString(decoded["id"])
	cc.assert("response_id_present", id != "", bodyDetail)
	if id == "" {
		// Nothing to look up; the missing id has already been recorded above.
		return nil
	}

	fetchSpec := requestSpec{
		Method:    http.MethodGet,
		Path:      "/v1/responses/" + id,
		Headers:   map[string]string{"Authorization": "Bearer " + r.apiKey},
		Retryable: true,
	}
	fetched, fetchErr := cc.request(ctx, fetchSpec)
	if fetchErr != nil {
		return fetchErr
	}
	cc.assert("get_status_200", fetched.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", fetched.StatusCode))
	return nil
}
// caseResponsesStream posts a streaming request to /v1/responses and asserts a
// 200 status, at least one SSE frame, the presence of both a
// "response.created" and a "response.completed" frame, and a terminating
// [DONE] marker. Request errors are returned to the caller.
func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error {
	spec := requestSpec{
		Method:  http.MethodPost,
		Path:    "/v1/responses",
		Headers: map[string]string{"Authorization": "Bearer " + r.apiKey},
		Body: map[string]any{
			"model":  "gpt-4o",
			"input":  "请流式回答 hello",
			"stream": true,
		},
		Stream:    true,
		Retryable: true,
	}
	resp, err := cc.request(ctx, spec)
	if err != nil {
		return err
	}
	cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))

	frames, done := parseSSEFrames(resp.Body)
	cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames)))

	// Record which lifecycle events appeared anywhere in the stream.
	var sawCreated, sawCompleted bool
	for _, frame := range frames {
		eventType := asString(frame["type"])
		if eventType == "response.created" {
			sawCreated = true
		} else if eventType == "response.completed" {
			sawCompleted = true
		}
	}

	bodyDetail := fmt.Sprintf("body=%s", string(resp.Body))
	cc.assert("has_response_created", sawCreated, bodyDetail)
	cc.assert("has_response_completed", sawCompleted, bodyDetail)
	cc.assert("done_terminated", done, "expected [DONE]")
	return nil
}
// caseEmbeddings posts to /v1/embeddings and accepts either outcome of the
// contract: a 200 with an "object" of "list" and non-empty "data", or a 501
// whose "error" object carries both "code" and "param" keys. Request errors
// are returned to the caller.
func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error {
	spec := requestSpec{
		Method:  http.MethodPost,
		Path:    "/v1/embeddings",
		Headers: map[string]string{"Authorization": "Bearer " + r.apiKey},
		Body: map[string]any{
			"model": "gpt-4o",
			"input": []string{"hello", "world"},
		},
		Retryable: true,
	}
	resp, err := cc.request(ctx, spec)
	if err != nil {
		return err
	}

	succeeded := resp.StatusCode == http.StatusOK
	cc.assert("status_200_or_501", succeeded || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode))

	var decoded map[string]any
	_ = json.Unmarshal(resp.Body, &decoded)
	bodyDetail := fmt.Sprintf("body=%s", string(resp.Body))

	if !succeeded {
		// Any non-200 answer is validated against the error-object shape.
		errObj, _ := decoded["error"].(map[string]any)
		_, hasCode := errObj["code"]
		_, hasParam := errObj["param"]
		cc.assert("error_has_code", hasCode, bodyDetail)
		cc.assert("error_has_param", hasParam, bodyDetail)
		return nil
	}

	cc.assert("object_list", asString(decoded["object"]) == "list", bodyDetail)
	data, _ := decoded["data"].([]any)
	cc.assert("data_non_empty", len(data) > 0, bodyDetail)
	return nil
}
func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,

View File

@@ -1,6 +1,8 @@
package util
import (
"encoding/json"
"fmt"
"regexp"
"strings"
@@ -68,15 +70,25 @@ func normalizeContent(v any) string {
if !ok {
continue
}
if m["type"] == "text" {
typeStr, _ := m["type"].(string)
typeStr = strings.ToLower(strings.TrimSpace(typeStr))
if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" {
if txt, ok := m["text"].(string); ok {
parts = append(parts, txt)
continue
}
if txt, ok := m["content"].(string); ok {
parts = append(parts, txt)
}
}
}
return strings.Join(parts, "\n")
default:
return ""
b, err := json.Marshal(v)
if err != nil {
return fmt.Sprintf("%v", v)
}
return string(b)
}
}

View File

@@ -33,6 +33,33 @@ func TestMessagesPrepareRoles(t *testing.T) {
}
}
// TestMessagesPrepareObjectContent checks that a message whose content is a
// JSON object is serialized into the prepared prompt rather than dropped.
func TestMessagesPrepareObjectContent(t *testing.T) {
	objectContent := map[string]any{"temp": 18, "ok": true}
	messages := []map[string]any{
		{"role": "user", "content": objectContent},
	}

	got := MessagesPrepare(messages)
	hasBothFields := contains(got, `"temp":18`) && contains(got, `"ok":true`)
	if !hasBothFields {
		t.Fatalf("expected serialized object content, got %q", got)
	}
}
// TestMessagesPrepareArrayTextVariants checks that array content keeps
// "output_text" and "input_text" items, joins them with a newline, and skips
// the "image_url" item.
func TestMessagesPrepareArrayTextVariants(t *testing.T) {
	contentItems := []any{
		map[string]any{"type": "output_text", "text": "line1"},
		map[string]any{"type": "input_text", "text": "line2"},
		map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"},
	}
	messages := []map[string]any{
		{"role": "user", "content": contentItems},
	}

	got := MessagesPrepare(messages)
	if got == "line1\nline2" {
		return
	}
	t.Fatalf("unexpected content from text variants: %q", got)
}
func TestConvertClaudeToDeepSeek(t *testing.T) {
store := config.LoadStore()
req := map[string]any{

140
internal/util/render.go Normal file
View File

@@ -0,0 +1,140 @@
package util
import (
"fmt"
"strings"
"time"
"github.com/google/uuid"
)
// BuildOpenAIChatCompletion assembles a non-streaming "chat.completion"
// response object.
//
// The assistant message carries finalText as content and, when finalThinking
// is not blank, a "reasoning_content" field. If ParseToolCalls detects tool
// calls in finalText, the message instead carries "tool_calls", its content is
// set to nil, and finish_reason becomes "tool_calls"; otherwise finish_reason
// is "stop". Usage figures are estimates and come from BuildOpenAIChatUsage so
// that streaming and non-streaming responses report the same shape.
func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
	message := map[string]any{"role": "assistant", "content": finalText}
	if strings.TrimSpace(finalThinking) != "" {
		message["reasoning_content"] = finalThinking
	}

	finishReason := "stop"
	if detected := ParseToolCalls(finalText, toolNames); len(detected) > 0 {
		finishReason = "tool_calls"
		message["tool_calls"] = FormatOpenAIToolCalls(detected)
		// Structured tool calls replace the raw text payload.
		message["content"] = nil
	}

	return map[string]any{
		"id":      completionID,
		"object":  "chat.completion",
		"created": time.Now().Unix(),
		"model":   model,
		"choices": []map[string]any{{"index": 0, "message": message, "finish_reason": finishReason}},
		"usage":   BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText),
	}
}
// BuildOpenAIResponseObject assembles a completed "response" object.
//
// When ParseToolCalls finds tool calls in finalText, the output holds a single
// "tool_calls" entry and output_text is left empty so the raw tool-call JSON
// is not exposed. Otherwise the output holds one assistant "message" whose
// content is an optional "reasoning" item (when finalThinking is non-empty)
// followed by an "output_text" item, and output_text equals finalText. Usage
// figures are estimates derived from EstimateTokens.
func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
	var (
		outputText = finalText
		output     = make([]any, 0, 2)
	)

	if detected := ParseToolCalls(finalText, toolNames); len(detected) > 0 {
		// Keep structured tool output only; avoid leaking raw tool-call JSON
		// into response.output_text for clients reading completed responses.
		outputText = ""
		calls := make([]any, 0, len(detected))
		for _, call := range detected {
			entry := map[string]any{
				"type":      "tool_call",
				"name":      call.Name,
				"arguments": call.Input,
			}
			calls = append(calls, entry)
		}
		output = append(output, map[string]any{
			"type":       "tool_calls",
			"tool_calls": calls,
		})
	} else {
		// Reasoning, when present, precedes the visible answer text.
		content := make([]any, 0, 2)
		if finalThinking != "" {
			content = append(content, map[string]any{
				"type": "reasoning",
				"text": finalThinking,
			})
		}
		content = append(content, map[string]any{
			"type": "output_text",
			"text": finalText,
		})
		output = append(output, map[string]any{
			"type":    "message",
			"id":      "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
			"role":    "assistant",
			"content": content,
		})
	}

	inputTokens := EstimateTokens(finalPrompt)
	outputTokens := EstimateTokens(finalThinking) + EstimateTokens(finalText)
	usage := map[string]any{
		"input_tokens":  inputTokens,
		"output_tokens": outputTokens,
		"total_tokens":  inputTokens + outputTokens,
	}

	return map[string]any{
		"id":          responseID,
		"type":        "response",
		"object":      "response",
		"created_at":  time.Now().Unix(),
		"status":      "completed",
		"model":       model,
		"output":      output,
		"output_text": outputText,
		"usage":       usage,
	}
}
// BuildClaudeMessageResponse assembles a Claude-style "message" response.
//
// Content starts with a "thinking" block when finalThinking is non-empty. If
// ParseToolCalls detects tool calls in finalText, one "tool_use" block is
// appended per call and stop_reason is "tool_use"; otherwise a single "text"
// block is appended (with a fixed fallback sentence when finalText is empty)
// and stop_reason is "end_turn". Token counts are estimates; output_tokens is
// computed from the text after the empty-text fallback has been applied.
func BuildClaudeMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any {
	detected := ParseToolCalls(finalText, toolNames)
	content := make([]map[string]any, 0, 4)
	if finalThinking != "" {
		content = append(content, map[string]any{"type": "thinking", "thinking": finalThinking})
	}
	stopReason := "end_turn"
	if len(detected) > 0 {
		stopReason = "tool_use"
		// Take one nanosecond-resolution timestamp per response: second
		// resolution made ids repeat across responses built within the same
		// second, and re-reading the clock per iteration was redundant because
		// the index already separates calls within one response.
		idBase := time.Now().UnixNano()
		for i, tc := range detected {
			content = append(content, map[string]any{
				"type":  "tool_use",
				"id":    fmt.Sprintf("toolu_%d_%d", idBase, i),
				"name":  tc.Name,
				"input": tc.Input,
			})
		}
	} else {
		if finalText == "" {
			finalText = "抱歉,没有生成有效的响应内容。"
		}
		content = append(content, map[string]any{"type": "text", "text": finalText})
	}
	return map[string]any{
		"id":            messageID,
		"type":          "message",
		"role":          "assistant",
		"model":         model,
		"content":       content,
		"stop_reason":   stopReason,
		"stop_sequence": nil,
		"usage": map[string]any{
			"input_tokens":  EstimateTokens(fmt.Sprintf("%v", normalizedMessages)),
			"output_tokens": EstimateTokens(finalThinking) + EstimateTokens(finalText),
		},
	}
}

View File

@@ -0,0 +1,93 @@
package util
// BuildOpenAIChatStreamDeltaChoice returns a streaming choice entry holding
// the given delta and choice index.
func BuildOpenAIChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
	choice := make(map[string]any, 2)
	choice["index"] = index
	choice["delta"] = delta
	return choice
}
// BuildOpenAIChatStreamFinishChoice returns the terminal streaming choice
// entry: an empty delta, the choice index, and the finish reason.
func BuildOpenAIChatStreamFinishChoice(index int, finishReason string) map[string]any {
	choice := make(map[string]any, 3)
	choice["finish_reason"] = finishReason
	choice["index"] = index
	choice["delta"] = map[string]any{}
	return choice
}
// BuildOpenAIChatStreamChunk returns a "chat.completion.chunk" envelope around
// the given choices. The "usage" key is included only when usage has at least
// one entry.
func BuildOpenAIChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any {
	chunk := make(map[string]any, 6)
	chunk["id"] = completionID
	chunk["object"] = "chat.completion.chunk"
	chunk["created"] = created
	chunk["model"] = model
	chunk["choices"] = choices

	if len(usage) == 0 {
		return chunk
	}
	chunk["usage"] = usage
	return chunk
}
// BuildOpenAIChatUsage returns a chat-completions usage object with token
// counts estimated by EstimateTokens. Reasoning tokens are counted inside
// completion_tokens and also reported separately under
// completion_tokens_details.
func BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
	prompt := EstimateTokens(finalPrompt)
	reasoning := EstimateTokens(finalThinking)
	answer := EstimateTokens(finalText)

	details := map[string]any{"reasoning_tokens": reasoning}
	usage := make(map[string]any, 4)
	usage["prompt_tokens"] = prompt
	usage["completion_tokens"] = reasoning + answer
	usage["total_tokens"] = prompt + reasoning + answer
	usage["completion_tokens_details"] = details
	return usage
}
// BuildOpenAIResponsesCreatedPayload returns the "response.created" event
// payload for the given response id and model, with status "in_progress".
func BuildOpenAIResponsesCreatedPayload(responseID, model string) map[string]any {
	payload := make(map[string]any, 5)
	payload["type"] = "response.created"
	payload["object"] = "response"
	payload["status"] = "in_progress"
	payload["id"] = responseID
	payload["model"] = model
	return payload
}
// BuildOpenAIResponsesTextDeltaPayload returns a "response.output_text.delta"
// event payload carrying one text delta for the given response id.
func BuildOpenAIResponsesTextDeltaPayload(responseID, delta string) map[string]any {
	payload := make(map[string]any, 3)
	payload["type"] = "response.output_text.delta"
	payload["id"] = responseID
	payload["delta"] = delta
	return payload
}
// BuildOpenAIResponsesReasoningDeltaPayload returns a
// "response.reasoning.delta" event payload carrying one reasoning delta for
// the given response id.
func BuildOpenAIResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
	payload := make(map[string]any, 3)
	payload["type"] = "response.reasoning.delta"
	payload["id"] = responseID
	payload["delta"] = delta
	return payload
}
// BuildOpenAIResponsesToolCallDeltaPayload returns a
// "response.output_tool_call.delta" event payload carrying the given tool
// calls for the given response id.
func BuildOpenAIResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
	payload := make(map[string]any, 3)
	payload["type"] = "response.output_tool_call.delta"
	payload["id"] = responseID
	payload["tool_calls"] = toolCalls
	return payload
}
// BuildOpenAIResponsesToolCallDonePayload returns a
// "response.output_tool_call.done" event payload carrying the given tool calls
// for the given response id.
func BuildOpenAIResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
	payload := make(map[string]any, 3)
	payload["type"] = "response.output_tool_call.done"
	payload["id"] = responseID
	payload["tool_calls"] = toolCalls
	return payload
}
// BuildOpenAIResponsesCompletedPayload wraps a finished response object in a
// "response.completed" event payload.
func BuildOpenAIResponsesCompletedPayload(response map[string]any) map[string]any {
	payload := make(map[string]any, 2)
	payload["type"] = "response.completed"
	payload["response"] = response
	return payload
}

View File

@@ -0,0 +1,48 @@
package util
import "testing"
func TestBuildOpenAIChatStreamChunk(t *testing.T) {
chunk := BuildOpenAIChatStreamChunk(
"cid",
123,
"deepseek-chat",
[]map[string]any{BuildOpenAIChatStreamDeltaChoice(0, map[string]any{"role": "assistant"})},
nil,
)
if chunk["object"] != "chat.completion.chunk" {
t.Fatalf("unexpected object: %#v", chunk["object"])
}
choices, _ := chunk["choices"].([]map[string]any)
if len(choices) == 0 {
rawChoices, _ := chunk["choices"].([]any)
if len(rawChoices) == 0 {
t.Fatalf("expected choices")
}
}
}
// TestBuildOpenAIChatUsage checks that the usage object exposes both
// "prompt_tokens" and "completion_tokens_details".
func TestBuildOpenAIChatUsage(t *testing.T) {
	usage := BuildOpenAIChatUsage("prompt", "think", "answer")

	_, hasPrompt := usage["prompt_tokens"]
	if !hasPrompt {
		t.Fatalf("expected prompt_tokens")
	}
	_, hasDetails := usage["completion_tokens_details"]
	if !hasDetails {
		t.Fatalf("expected completion_tokens_details")
	}
}
// TestBuildOpenAIResponsesEventPayloads checks the "type" field emitted by the
// created, tool-call-done and completed event payload builders.
func TestBuildOpenAIResponsesEventPayloads(t *testing.T) {
	checks := []struct {
		payload  map[string]any
		wantType string
	}{
		{BuildOpenAIResponsesCreatedPayload("resp_1", "gpt-4o"), "response.created"},
		{BuildOpenAIResponsesToolCallDonePayload("resp_1", []map[string]any{{"index": 0}}), "response.output_tool_call.done"},
		{BuildOpenAIResponsesCompletedPayload(map[string]any{"id": "resp_1"}), "response.completed"},
	}
	for _, check := range checks {
		if got := check.payload["type"]; got != check.wantType {
			t.Fatalf("unexpected type: %#v", got)
		}
	}
}

View File

@@ -0,0 +1,94 @@
package util
import "testing"
// TestBuildOpenAIChatCompletionWithToolCalls checks that a final text holding
// a tool_calls JSON payload produces a "chat.completion" whose first choice
// finishes with "tool_calls".
func TestBuildOpenAIChatCompletionWithToolCalls(t *testing.T) {
	const toolPayload = `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
	out := BuildOpenAIChatCompletion("cid1", "deepseek-chat", "prompt", "", toolPayload, []string{"search"})

	if object := out["object"]; object != "chat.completion" {
		t.Fatalf("unexpected object: %#v", object)
	}

	if typed, _ := out["choices"].([]map[string]any); len(typed) > 0 {
		if reason := typed[0]["finish_reason"]; reason != "tool_calls" {
			t.Fatalf("expected finish_reason=tool_calls, got %#v", reason)
		}
		return
	}

	// json-like map from generic marshalling may be []any in some paths
	generic, _ := out["choices"].([]any)
	if len(generic) == 0 {
		t.Fatalf("expected choices")
	}
	firstChoice, _ := generic[0].(map[string]any)
	if reason := firstChoice["finish_reason"]; reason != "tool_calls" {
		t.Fatalf("expected finish_reason=tool_calls, got %#v", reason)
	}
}
// TestBuildOpenAIResponseObjectWithText checks that plain text (no tool calls)
// produces a "response" object whose first output entry is a "message".
func TestBuildOpenAIResponseObjectWithText(t *testing.T) {
	out := BuildOpenAIResponseObject("resp_1", "gpt-4o", "prompt", "reasoning", "text", nil)

	if object := out["object"]; object != "response" {
		t.Fatalf("unexpected object: %#v", object)
	}

	entries, _ := out["output"].([]any)
	if len(entries) == 0 {
		t.Fatalf("expected output entries")
	}
	head, _ := entries[0].(map[string]any)
	if kind := head["type"]; kind != "message" {
		t.Fatalf("expected first output type message, got %#v", kind)
	}
}
// TestBuildOpenAIResponseObjectToolCallsHidesRawOutputText checks that when
// the final text is a tool_calls payload, output_text is empty and the first
// output entry has type "tool_calls".
func TestBuildOpenAIResponseObjectToolCallsHidesRawOutputText(t *testing.T) {
	const toolPayload = `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
	out := BuildOpenAIResponseObject("resp_2", "gpt-4o", "prompt", "", toolPayload, []string{"search"})

	if text := out["output_text"]; text != "" {
		t.Fatalf("expected empty output_text for tool_calls, got %#v", text)
	}

	entries, _ := out["output"].([]any)
	if len(entries) == 0 {
		t.Fatalf("expected output entries")
	}
	head, _ := entries[0].(map[string]any)
	if kind := head["type"]; kind != "tool_calls" {
		t.Fatalf("expected first output type tool_calls, got %#v", kind)
	}
}
// TestBuildClaudeMessageResponseToolUse checks that a tool_calls payload
// yields a Claude "message" whose stop_reason is "tool_use".
func TestBuildClaudeMessageResponseToolUse(t *testing.T) {
	const toolPayload = `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
	history := []any{map[string]any{"role": "user", "content": "hi"}}

	out := BuildClaudeMessageResponse("msg_1", "claude-sonnet-4-5", history, "", toolPayload, []string{"search"})

	if kind := out["type"]; kind != "message" {
		t.Fatalf("unexpected type: %#v", kind)
	}
	if reason := out["stop_reason"]; reason != "tool_use" {
		t.Fatalf("expected stop_reason=tool_use, got %#v", reason)
	}
}

View File

@@ -0,0 +1,30 @@
package util
// StandardRequest describes a request whose prompt and feature flags are
// turned into an upstream completion payload by CompletionPayload.
// NOTE(review): only FinalPrompt, Thinking, Search and PassThrough are read in
// this file; the meaning of the remaining fields is inferred from their names
// and should be confirmed against the callers.
type StandardRequest struct {
// presumably identifies the API surface the request arrived on — TODO confirm
Surface string
// presumably the model name as supplied by the client — TODO confirm
RequestedModel string
// presumably the upstream model the request was mapped to — TODO confirm
ResolvedModel string
// presumably the model name echoed back in responses — TODO confirm
ResponseModel string
Messages []any
// FinalPrompt is sent as the payload's "prompt".
FinalPrompt string
ToolNames []string
Stream bool
// Thinking is sent as the payload's "thinking_enabled".
Thinking bool
// Search is sent as the payload's "search_enabled".
Search bool
// PassThrough entries are copied into the payload last and therefore
// override any built-in key of the same name.
PassThrough map[string]any
}
// CompletionPayload builds the completion request body for the given chat
// session: the session id, a nil parent message id, the final prompt, an empty
// ref_file_ids list and the thinking/search flags. Entries from PassThrough
// are applied afterwards, so they override built-in keys of the same name.
func (r StandardRequest) CompletionPayload(sessionID string) map[string]any {
	body := make(map[string]any, 6+len(r.PassThrough))
	body["chat_session_id"] = sessionID
	body["parent_message_id"] = nil
	body["prompt"] = r.FinalPrompt
	body["ref_file_ids"] = []any{}
	body["thinking_enabled"] = r.Thinking
	body["search_enabled"] = r.Search

	for key, value := range r.PassThrough {
		body[key] = value
	}
	return body
}

View File

@@ -10,6 +10,7 @@ import (
// toolCallPattern matches an inline object of the form {"tool_calls": [ ... ]}
// (single or double quotes around the key) and captures the text between the
// brackets lazily. It has no (?s) flag, so "." does not match newlines and the
// bracketed content must sit on one line.
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
// fencedJSONPattern matches a triple-backtick fenced block, optionally tagged
// "json", across newlines and captures its inner content with surrounding
// whitespace excluded.
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
// fencedBlockPattern matches any complete triple-backtick fenced block,
// including the fences, across newlines (shortest match). A fence without a
// closing ``` is not matched.
var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```")
type ParsedToolCall struct {
Name string `json:"name"`
@@ -20,6 +21,10 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
if strings.TrimSpace(text) == "" {
return nil
}
text = stripFencedCodeBlocks(text)
if strings.TrimSpace(text) == "" {
return nil
}
candidates := buildToolCallCandidates(text)
var parsed []ParsedToolCall
@@ -33,6 +38,34 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
return nil
}
return filterToolCalls(parsed, availableToolNames)
}
// ParseStandaloneToolCalls parses tool calls only when the whole trimmed text
// is itself a JSON payload: it must start with "{" or "[" and must not look
// like an example (see looksLikeToolExampleContext). Parsed calls are filtered
// against availableToolNames. It returns nil when the text is blank, looks
// like an example, does not start with a JSON opener, or parses to no calls.
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
	payload := strings.TrimSpace(text)
	if payload == "" || looksLikeToolExampleContext(payload) {
		return nil
	}

	// Prose that merely contains a payload is rejected: the text must open
	// with a JSON object or array.
	startsLikeJSON := strings.HasPrefix(payload, "{") || strings.HasPrefix(payload, "[")
	if !startsLikeJSON {
		return nil
	}

	parsed := parseToolCallsPayload(payload)
	if len(parsed) == 0 {
		return nil
	}
	return filterToolCalls(parsed, availableToolNames)
}
func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []ParsedToolCall {
allowed := map[string]struct{}{}
for _, name := range availableToolNames {
allowed[name] = struct{}{}
@@ -283,6 +316,21 @@ func extractJSONObject(text string, start int) (string, int, bool) {
return "", 0, false
}
// looksLikeToolExampleContext reports whether text contains a triple-backtick
// code fence, which is treated as a sign that a tool-call payload is being
// shown as an example rather than issued.
//
// The fence marker has no letter case and is not whitespace, so the text is
// searched directly; blank or empty input contains no fence and yields false.
func looksLikeToolExampleContext(text string) bool {
	return strings.Contains(text, "```")
}
// stripFencedCodeBlocks replaces every complete triple-backtick fenced block
// in text with a single space. Blank input yields the empty string.
func stripFencedCodeBlocks(text string) string {
	if len(strings.TrimSpace(text)) == 0 {
		return ""
	}
	stripped := fencedBlockPattern.ReplaceAllString(text, " ")
	return stripped
}
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
out := make([]map[string]any, 0, len(calls))
for _, c := range calls {

View File

@@ -19,11 +19,8 @@ func TestParseToolCalls(t *testing.T) {
func TestParseToolCallsFromFencedJSON(t *testing.T) {
text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```"
calls := ParseToolCalls(text, []string{"search"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if calls[0].Input["q"] != "news" {
t.Fatalf("unexpected args: %#v", calls[0].Input)
if len(calls) != 0 {
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
}
}
@@ -62,3 +59,23 @@ func TestFormatOpenAIToolCalls(t *testing.T) {
t.Fatalf("unexpected function name: %#v", fn)
}
}
// TestParseStandaloneToolCallsOnlyMatchesStandalonePayload checks that a
// payload embedded in prose is ignored while the same payload on its own is
// parsed into one call.
func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) {
	tools := []string{"search"}

	const mixed = `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
	fromProse := ParseStandaloneToolCalls(mixed, tools)
	if len(fromProse) != 0 {
		t.Fatalf("expected standalone parser to ignore mixed prose, got %#v", fromProse)
	}

	const standalone = `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
	fromPayload := ParseStandaloneToolCalls(standalone, tools)
	if len(fromPayload) != 1 {
		t.Fatalf("expected standalone parser to match, got %#v", fromPayload)
	}
}
// TestParseStandaloneToolCallsIgnoresFencedCodeBlock checks that a payload
// wrapped in a fenced code block is treated as an example and not parsed.
func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) {
	const fenced = "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```"

	calls := ParseStandaloneToolCalls(fenced, []string{"search"})
	if len(calls) != 0 {
		t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
	}
}

View File

@@ -0,0 +1,429 @@
package util
import (
"encoding/json"
"net/http/httptest"
"strings"
"testing"
"ds2api/internal/config"
)
// ─── EstimateTokens edge cases ───────────────────────────────────────
// TestEstimateTokensEmpty checks that the empty string estimates to 0 tokens.
func TestEstimateTokensEmpty(t *testing.T) {
	got := EstimateTokens("")
	if got == 0 {
		return
	}
	t.Fatalf("expected 0 for empty string, got %d", got)
}
// TestEstimateTokensShortASCII checks that two ASCII characters estimate to
// one token.
func TestEstimateTokensShortASCII(t *testing.T) {
	const want = 1
	if got := EstimateTokens("ab"); got != want {
		t.Fatalf("expected 1 for 2 ascii chars, got %d", got)
	}
}
// TestEstimateTokensLongASCII checks that 100 ASCII characters estimate to 25
// tokens.
func TestEstimateTokensLongASCII(t *testing.T) {
	input := strings.Repeat("x", 100)
	const want = 25
	if got := EstimateTokens(input); got != want {
		t.Fatalf("expected 25 for 100 ascii chars, got %d", got)
	}
}
// TestEstimateTokensChinese checks that Chinese text estimates to at least one
// token.
func TestEstimateTokensChinese(t *testing.T) {
	if got := EstimateTokens("你好世界"); got < 1 {
		t.Fatalf("expected at least 1 token for Chinese text, got %d", got)
	}
}
// TestEstimateTokensMixed checks that mixed ASCII and Chinese text estimates
// to at least two tokens.
func TestEstimateTokensMixed(t *testing.T) {
	if got := EstimateTokens("Hello 你好世界"); got < 2 {
		t.Fatalf("expected at least 2 tokens for mixed text, got %d", got)
	}
}
// TestEstimateTokensSingleByte checks that a single ASCII character is
// estimated at exactly 1 token.
func TestEstimateTokensSingleByte(t *testing.T) {
	if tokens := EstimateTokens("x"); tokens != 1 {
		t.Fatalf("expected 1 for single char (minimum), got %d", tokens)
	}
}
// TestEstimateTokensSingleChinese checks that a single Chinese character is
// estimated at exactly 1 token.
func TestEstimateTokensSingleChinese(t *testing.T) {
	if tokens := EstimateTokens("你"); tokens != 1 {
		t.Fatalf("expected 1 for single Chinese char, got %d", tokens)
	}
}
// ─── ToBool edge cases ───────────────────────────────────────────────

// TestToBoolTrue checks that ToBool reports true for the bool value true.
func TestToBoolTrue(t *testing.T) {
	if ToBool(true) {
		return
	}
	t.Fatal("expected true")
}
// TestToBoolFalse checks that ToBool reports false for the bool value false.
func TestToBoolFalse(t *testing.T) {
	if result := ToBool(false); result {
		t.Fatal("expected false")
	}
}
// TestToBoolNonBool checks that ToBool reports false for values that are not
// bools: the string "true", the int 1, and nil.
func TestToBoolNonBool(t *testing.T) {
	// Cases run in declaration order; the first truthy result aborts the test.
	cases := []struct {
		value   any
		failure string
	}{
		{value: "true", failure: "expected false for string 'true'"},
		{value: 1, failure: "expected false for int 1"},
		{value: nil, failure: "expected false for nil"},
	}
	for _, tc := range cases {
		if ToBool(tc.value) {
			t.Fatal(tc.failure)
		}
	}
}
// ─── IntFrom edge cases ─────────────────────────────────────────────

// TestIntFromFloat64 checks that IntFrom converts float64(42.5) to 42.
func TestIntFromFloat64(t *testing.T) {
	converted := IntFrom(float64(42.5))
	if converted != 42 {
		t.Fatalf("expected 42 for float64(42.5), got %d", converted)
	}
}
// TestIntFromInt checks that IntFrom passes an int value of 42 through unchanged.
func TestIntFromInt(t *testing.T) {
	converted := IntFrom(int(42))
	if converted != 42 {
		t.Fatalf("expected 42, got %d", converted)
	}
}
// TestIntFromInt64 checks that IntFrom converts int64(42) to 42.
func TestIntFromInt64(t *testing.T) {
	converted := IntFrom(int64(42))
	if converted != 42 {
		t.Fatalf("expected 42, got %d", converted)
	}
}
// TestIntFromString checks that IntFrom yields 0 for a string argument, even
// when the string holds digits.
func TestIntFromString(t *testing.T) {
	converted := IntFrom("42")
	if converted != 0 {
		t.Fatalf("expected 0 for string, got %d", converted)
	}
}
// TestIntFromNil checks that IntFrom yields 0 for nil.
func TestIntFromNil(t *testing.T) {
	converted := IntFrom(nil)
	if converted != 0 {
		t.Fatalf("expected 0 for nil, got %d", converted)
	}
}
// ─── WriteJSON ───────────────────────────────────────────────────────

// TestWriteJSON checks the status code, the Content-Type header and the
// JSON-decoded body that WriteJSON produces for a small map payload.
func TestWriteJSON(t *testing.T) {
	recorder := httptest.NewRecorder()
	WriteJSON(recorder, 200, map[string]any{"key": "value"})

	if status := recorder.Code; status != 200 {
		t.Fatalf("expected 200, got %d", status)
	}

	contentType := recorder.Header().Get("Content-Type")
	if contentType != "application/json" {
		t.Fatalf("expected application/json content type, got %q", contentType)
	}

	// Round-trip the response body to confirm it is valid JSON with the sent key.
	var decoded map[string]any
	err := json.Unmarshal(recorder.Body.Bytes(), &decoded)
	if err != nil {
		t.Fatalf("decode error: %v", err)
	}
	if decoded["key"] != "value" {
		t.Fatalf("unexpected body: %#v", decoded)
	}
}
// TestWriteJSONStatusCodes checks that WriteJSON records the status code it
// is given, for a handful of success and error codes.
func TestWriteJSONStatusCodes(t *testing.T) {
	statuses := []int{200, 201, 400, 404, 500}
	for i := range statuses {
		want := statuses[i]
		// A fresh recorder per code keeps the runs independent of each other.
		recorder := httptest.NewRecorder()
		WriteJSON(recorder, want, map[string]any{"status": want})
		if recorder.Code != want {
			t.Fatalf("expected %d, got %d", want, recorder.Code)
		}
	}
}
// ─── MessagesPrepare edge cases ──────────────────────────────────────

// TestMessagesPrepareEmpty checks that MessagesPrepare returns the empty
// string for a nil message list.
func TestMessagesPrepareEmpty(t *testing.T) {
	if prompt := MessagesPrepare(nil); prompt != "" {
		t.Fatalf("expected empty for nil messages, got %q", prompt)
	}
}
// TestMessagesPrepareMergesConsecutiveSameRole checks that two consecutive
// user messages both appear in the prepared prompt and that the prompt
// contains no "<User>" marker.
func TestMessagesPrepareMergesConsecutiveSameRole(t *testing.T) {
	prompt := MessagesPrepare([]map[string]any{
		{"role": "user", "content": "Hello"},
		{"role": "user", "content": "World"},
	})

	hasBoth := strings.Contains(prompt, "Hello") && strings.Contains(prompt, "World")
	if !hasBoth {
		t.Fatalf("expected both messages, got %q", prompt)
	}

	// The merged pair is expected to carry no <User> marker at all.
	if markers := strings.Count(prompt, "<User>"); markers != 0 {
		t.Fatalf("expected no User marker for first message pair, got %d occurrences", markers)
	}
}
// TestMessagesPrepareAssistantMarkers checks that a user/assistant exchange
// produces a prompt containing both the assistant marker and the
// end-of-sentence marker.
func TestMessagesPrepareAssistantMarkers(t *testing.T) {
	prompt := MessagesPrepare([]map[string]any{
		{"role": "user", "content": "Hi"},
		{"role": "assistant", "content": "Hello!"},
	})

	hasAssistant := strings.Contains(prompt, "<Assistant>")
	if !hasAssistant {
		t.Fatalf("expected assistant marker, got %q", prompt)
	}

	hasEnd := strings.Contains(prompt, "<end▁of▁sentence>")
	if !hasEnd {
		t.Fatalf("expected end of sentence marker, got %q", prompt)
	}
}
// TestMessagesPrepareUnknownRole checks that the content of a message with an
// unrecognised role still appears in the prepared prompt.
func TestMessagesPrepareUnknownRole(t *testing.T) {
	prompt := MessagesPrepare([]map[string]any{
		{"role": "user", "content": "Hello"},
		{"role": "unknown_role", "content": "Unknown"},
	})
	if strings.Contains(prompt, "Unknown") {
		return
	}
	t.Fatalf("expected unknown role content, got %q", prompt)
}
// TestMessagesPrepareMarkdownImageReplaced checks that the "![alt]" prefix of
// a markdown image does not survive in the prepared prompt.
func TestMessagesPrepareMarkdownImageReplaced(t *testing.T) {
	prompt := MessagesPrepare([]map[string]any{
		{"role": "user", "content": "Look at this: ![alt](https://example.com/img.png)"},
	})
	stillPresent := strings.Contains(prompt, "![alt]")
	if stillPresent {
		t.Fatalf("expected markdown image to be replaced, got %q", prompt)
	}
}
// TestMessagesPrepareNilContent feeds MessagesPrepare a user message whose
// content is nil. It asserts nothing: a result other than "null" is only
// logged, so this test cannot fail on the returned value.
func TestMessagesPrepareNilContent(t *testing.T) {
	prompt := MessagesPrepare([]map[string]any{
		{"role": "user", "content": nil},
	})
	if prompt == "null" {
		return
	}
	t.Logf("nil content handled as: %q", prompt)
}
// ─── normalizeContent edge cases ─────────────────────────────────────

// TestNormalizeContentString checks that a plain string is returned unchanged.
func TestNormalizeContentString(t *testing.T) {
	if text := normalizeContent("hello"); text != "hello" {
		t.Fatalf("expected 'hello', got %q", text)
	}
}
// TestNormalizeContentArray checks that the "text" fields of two text parts
// are joined with a newline.
func TestNormalizeContentArray(t *testing.T) {
	parts := []any{
		map[string]any{"type": "text", "text": "line1"},
		map[string]any{"type": "text", "text": "line2"},
	}
	if joined := normalizeContent(parts); joined != "line1\nline2" {
		t.Fatalf("expected 'line1\\nline2', got %q", joined)
	}
}
// TestNormalizeContentArrayWithContentField checks that a text part carrying
// its text under "content" (with no "text" key) is still picked up.
func TestNormalizeContentArrayWithContentField(t *testing.T) {
	parts := []any{
		map[string]any{"type": "text", "content": "from-content"},
	}
	if text := normalizeContent(parts); text != "from-content" {
		t.Fatalf("expected 'from-content', got %q", text)
	}
}
// TestNormalizeContentArraySkipsImage checks that an image_url part adds
// nothing to the output and that only the text part's caption remains.
func TestNormalizeContentArraySkipsImage(t *testing.T) {
	parts := []any{
		map[string]any{"type": "image_url", "image_url": "https://example.com/img.png"},
		map[string]any{"type": "text", "text": "caption"},
	}
	text := normalizeContent(parts)

	// Checked first so that leaked image data is reported with its own message.
	if leaked := strings.Contains(text, "image"); leaked {
		t.Fatalf("expected image skipped, got %q", text)
	}
	if text != "caption" {
		t.Fatalf("expected 'caption', got %q", text)
	}
}
// TestNormalizeContentArrayNonMapItems checks that array items which are not
// maps (here a string and an int) yield the empty string.
func TestNormalizeContentArrayNonMapItems(t *testing.T) {
	items := []any{"string item", 42}
	if text := normalizeContent(items); text != "" {
		t.Fatalf("expected empty for non-map items, got %q", text)
	}
}
// TestNormalizeContentJSON checks that a map passed as content comes back as
// text containing its compact JSON form.
func TestNormalizeContentJSON(t *testing.T) {
	serialized := normalizeContent(map[string]any{"key": "value"})
	if strings.Contains(serialized, `"key":"value"`) {
		return
	}
	t.Fatalf("expected JSON serialized, got %q", serialized)
}
// ─── ConvertClaudeToDeepSeek edge cases ──────────────────────────────

// TestConvertClaudeToDeepSeekDefaultModel checks that a request without a
// "model" field is converted into one whose "model" is a non-empty string.
func TestConvertClaudeToDeepSeekDefaultModel(t *testing.T) {
	store := config.LoadStore()
	req := map[string]any{
		"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
	}
	out := ConvertClaudeToDeepSeek(req, store)

	// Comparing the interface value straight against "" lets a missing or nil
	// model through, because a nil interface is not equal to the string "".
	// Asserting the string type first makes absent, nil and non-string values
	// fail along with the empty string.
	model, ok := out["model"].(string)
	if !ok || model == "" {
		t.Fatal("expected default model")
	}
}
func TestConvertClaudeToDeepSeekWithStopSequences(t *testing.T) {
store := config.LoadStore()
req := map[string]any{
"model": "claude-sonnet-4-5",
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
"stop_sequences": []any{"\n\n"},
}
out := ConvertClaudeToDeepSeek(req, store)
if out["stop"] == nil {
t.Fatal("expected stop field from stop_sequences")
}
}
func TestConvertClaudeToDeepSeekWithTemperature(t *testing.T) {
store := config.LoadStore()
req := map[string]any{
"model": "claude-sonnet-4-5",
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
"temperature": 0.7,
"top_p": 0.9,
}
out := ConvertClaudeToDeepSeek(req, store)
if out["temperature"] != 0.7 {
t.Fatalf("expected temperature 0.7, got %v", out["temperature"])
}
if out["top_p"] != 0.9 {
t.Fatalf("expected top_p 0.9, got %v", out["top_p"])
}
}
func TestConvertClaudeToDeepSeekNoSystem(t *testing.T) {
store := config.LoadStore()
req := map[string]any{
"model": "claude-sonnet-4-5",
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
}
out := ConvertClaudeToDeepSeek(req, store)
msgs, _ := out["messages"].([]any)
if len(msgs) != 1 {
t.Fatalf("expected 1 message without system, got %d", len(msgs))
}
}
// TestConvertClaudeToDeepSeekOpusUsesSlowMapping checks that, with a
// claude_mapping supplied through DS2API_CONFIG_JSON, an opus model name
// converts to the model configured under "slow".
func TestConvertClaudeToDeepSeekOpusUsesSlowMapping(t *testing.T) {
	// The environment variable is set before the store is loaded.
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`)
	store := config.LoadStore()

	request := map[string]any{
		"model":    "claude-opus-4-6",
		"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
	}
	converted := ConvertClaudeToDeepSeek(request, store)
	if model := converted["model"]; model != "deepseek-reasoner" {
		t.Fatalf("expected opus to use slow mapping, got %q", model)
	}
}
// ─── FormatOpenAIStreamToolCalls ─────────────────────────────────────
func TestFormatOpenAIStreamToolCalls(t *testing.T) {
formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{
{Name: "search", Input: map[string]any{"q": "test"}},
})
if len(formatted) != 1 {
t.Fatalf("expected 1, got %d", len(formatted))
}
fn, _ := formatted[0]["function"].(map[string]any)
if fn["name"] != "search" {
t.Fatalf("unexpected function name: %#v", fn)
}
if formatted[0]["index"] != 0 {
t.Fatalf("expected index 0, got %v", formatted[0]["index"])
}
}
// ─── ParseToolCalls more edge cases ──────────────────────────────────

// TestParseToolCallsNoToolNames checks that one call is still parsed when the
// list of tool names is nil.
func TestParseToolCallsNoToolNames(t *testing.T) {
	payload := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
	if parsed := ParseToolCalls(payload, nil); len(parsed) != 1 {
		t.Fatalf("expected 1 call with nil tool names, got %d", len(parsed))
	}
}
// TestParseToolCallsEmptyText checks that empty input text yields no calls.
func TestParseToolCallsEmptyText(t *testing.T) {
	if parsed := ParseToolCalls("", []string{"search"}); len(parsed) != 0 {
		t.Fatalf("expected 0 calls for empty text, got %d", len(parsed))
	}
}
// TestParseToolCallsMultipleTools checks that a payload holding two calls to
// two different tools yields two parsed calls.
func TestParseToolCallsMultipleTools(t *testing.T) {
	payload := `{"tool_calls":[{"name":"search","input":{"q":"go"}},{"name":"get_weather","input":{"city":"beijing"}}]}`
	parsed := ParseToolCalls(payload, []string{"search", "get_weather"})
	if count := len(parsed); count != 2 {
		t.Fatalf("expected 2 calls, got %d", count)
	}
}
// TestParseToolCallsInputAsString checks that an "input" given as a
// JSON-encoded string is decoded into the call's Input map.
func TestParseToolCallsInputAsString(t *testing.T) {
	payload := `{"tool_calls":[{"name":"search","input":"{\"q\":\"golang\"}"}]}`
	parsed := ParseToolCalls(payload, []string{"search"})
	if count := len(parsed); count != 1 {
		t.Fatalf("expected 1 call, got %d", count)
	}

	input := parsed[0].Input
	if input["q"] != "golang" {
		t.Fatalf("expected parsed string input, got %#v", input)
	}
}
// TestParseToolCallsWithFunctionWrapper checks that a call nested under a
// "function" object, with "arguments" in place of "input", is parsed and
// keeps its name.
func TestParseToolCallsWithFunctionWrapper(t *testing.T) {
	payload := `{"tool_calls":[{"function":{"name":"calc","arguments":{"x":1,"y":2}}}]}`
	parsed := ParseToolCalls(payload, []string{"calc"})
	if count := len(parsed); count != 1 {
		t.Fatalf("expected 1 call, got %d", count)
	}
	if name := parsed[0].Name; name != "calc" {
		t.Fatalf("expected calc, got %q", name)
	}
}
// TestParseStandaloneToolCallsFencedCodeBlock checks that a fenced tool_calls
// example surrounded by prose yields no calls from the standalone parser.
func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) {
	example := "Here's an example:\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\nDon't execute this."
	if parsed := ParseStandaloneToolCalls(example, []string{"search"}); len(parsed) != 0 {
		t.Fatalf("expected fenced code block ignored, got %d calls", len(parsed))
	}
}
// ─── looksLikeToolExampleContext ─────────────────────────────────────

// TestLooksLikeToolExampleContextNone checks that an ordinary sentence is not
// classified as tool-example context.
func TestLooksLikeToolExampleContextNone(t *testing.T) {
	isExample := looksLikeToolExampleContext("I will call the tool now")
	if isExample {
		t.Fatal("expected false for non-example context")
	}
}
// TestLooksLikeToolExampleContextFenced checks that an opening ```json fence
// is classified as tool-example context.
func TestLooksLikeToolExampleContextFenced(t *testing.T) {
	if looksLikeToolExampleContext("```json") {
		return
	}
	t.Fatal("expected true for fenced code block context")
}

View File

@@ -9,6 +9,12 @@
"apiKey": "your-api-key"
},
"models": {
"gpt-4o": {
"name": "GPT-4o (aliased to deepseek-chat)"
},
"gpt-5-codex": {
"name": "GPT-5 Codex (aliased to deepseek-reasoner)"
},
"deepseek-chat": {
"name": "DeepSeek Chat (DS2API)"
},
@@ -18,5 +24,5 @@
}
}
},
"model": "ds2api/deepseek-chat"
"model": "ds2api/gpt-5-codex"
}

View File

@@ -39,6 +39,10 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
const [loadingAccounts, setLoadingAccounts] = useState(false)
const apiFetch = authFetch || fetch
/**
 * Resolve the string used to identify an account.
 *
 * Candidates are tried in order (`identifier`, `email`, `mobile`) and the
 * first one that is still non-empty after trimming is returned. A
 * whitespace-only field therefore no longer hides a usable later field.
 *
 * @param {object} acc - Account record; anything that is not a non-null object yields ''.
 * @returns {string} The trimmed identifier, or '' when no candidate is usable.
 */
const resolveAccountIdentifier = (acc) => {
    if (!acc || typeof acc !== 'object') return ''
    for (const candidate of [acc.identifier, acc.email, acc.mobile]) {
        // Falsy candidates (undefined, null, '', 0) are skipped, as the `||` chain did.
        if (!candidate) continue
        const trimmed = String(candidate).trim()
        if (trimmed) return trimmed
    }
    return ''
}
const fetchAccounts = async (targetPage = page) => {
setLoadingAccounts(true)
@@ -147,9 +151,14 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
}
const deleteAccount = async (id) => {
const identifier = String(id || '').trim()
if (!identifier) {
onMessage('error', t('accountManager.invalidIdentifier'))
return
}
if (!confirm(t('accountManager.deleteAccountConfirm'))) return
try {
const res = await apiFetch(`/admin/accounts/${encodeURIComponent(id)}`, { method: 'DELETE' })
const res = await apiFetch(`/admin/accounts/${encodeURIComponent(identifier)}`, { method: 'DELETE' })
if (res.ok) {
onMessage('success', t('messages.deleted'))
fetchAccounts() // 刷新当前页
@@ -163,24 +172,29 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
}
const testAccount = async (identifier) => {
setTesting(prev => ({ ...prev, [identifier]: true }))
const accountID = String(identifier || '').trim()
if (!accountID) {
onMessage('error', t('accountManager.invalidIdentifier'))
return
}
setTesting(prev => ({ ...prev, [accountID]: true }))
try {
const res = await apiFetch('/admin/accounts/test', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ identifier }),
body: JSON.stringify({ identifier: accountID }),
})
const data = await res.json()
const statusMessage = data.success
? t('apiTester.testSuccess', { account: identifier, time: data.response_time })
: `${identifier}: ${data.message}`
? t('apiTester.testSuccess', { account: accountID, time: data.response_time })
: `${accountID}: ${data.message}`
onMessage(data.success ? 'success' : 'error', statusMessage)
fetchAccounts() // 刷新当前页
onRefresh()
} catch (e) {
onMessage('error', t('accountManager.testFailed', { error: e.message }))
} finally {
setTesting(prev => ({ ...prev, [identifier]: false }))
setTesting(prev => ({ ...prev, [accountID]: false }))
}
}
@@ -197,7 +211,12 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
for (let i = 0; i < allAccounts.length; i++) {
const acc = allAccounts[i]
const id = acc.email || acc.mobile
const id = resolveAccountIdentifier(acc)
if (!id) {
results.push({ id: '-', success: false, message: t('accountManager.invalidIdentifier') })
setBatchProgress({ current: i + 1, total: allAccounts.length, results: [...results] })
continue
}
try {
const res = await apiFetch('/admin/accounts/test', {
@@ -387,7 +406,7 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
<div className="p-8 text-center text-muted-foreground">{t('actions.loading')}</div>
) : accounts.length > 0 ? (
accounts.map((acc, i) => {
const id = acc.email || acc.mobile
const id = resolveAccountIdentifier(acc)
return (
<div key={i} className="p-4 flex flex-col md:flex-row md:items-center justify-between gap-4 hover:bg-muted/50 transition-colors">
<div className="flex items-center gap-3 min-w-0">
@@ -396,7 +415,7 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
acc.has_token ? "bg-emerald-500 shadow-[0_0_8px_rgba(16,185,129,0.5)]" : "bg-amber-500"
)} />
<div className="min-w-0">
<div className="font-medium truncate">{id}</div>
<div className="font-medium truncate">{id || '-'}</div>
<div className="flex items-center gap-2 text-xs text-muted-foreground mt-0.5">
<span>{acc.has_token ? t('accountManager.sessionActive') : t('accountManager.reauthRequired')}</span>
{acc.token_preview && (
@@ -419,7 +438,7 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
onClick={() => deleteAccount(id)}
className="p-1 lg:p-1.5 text-muted-foreground hover:text-destructive hover:bg-destructive/10 rounded-md transition-colors"
>
<Trash2 className="w-3.5 h-3.5 lg:w-4 h-4" />
<Trash2 className="w-3.5 h-3.5 lg:w-4 lg:h-4" />
</button>
</div>
</div>

View File

@@ -42,6 +42,10 @@ export default function ApiTester({ config, onMessage, authFetch }) {
const apiFetch = authFetch || fetch
const accounts = config.accounts || []
/**
 * Resolve the string used to identify an account.
 *
 * Candidates are tried in order (`identifier`, `email`, `mobile`) and the
 * first one that is still non-empty after trimming is returned. A
 * whitespace-only field therefore no longer hides a usable later field.
 *
 * @param {object} acc - Account record; anything that is not a non-null object yields ''.
 * @returns {string} The trimmed identifier, or '' when no candidate is usable.
 */
const resolveAccountIdentifier = (acc) => {
    if (!acc || typeof acc !== 'object') return ''
    for (const candidate of [acc.identifier, acc.email, acc.mobile]) {
        // Falsy candidates (undefined, null, '', 0) are skipped, as the `||` chain did.
        if (!candidate) continue
        const trimmed = String(candidate).trim()
        if (trimmed) return trimmed
    }
    return ''
}
const configuredKeys = config.keys || []
const trimmedApiKey = apiKey.trim()
const defaultKey = configuredKeys[0] || ''
@@ -297,11 +301,15 @@ return (
onChange={e => setSelectedAccount(e.target.value)}
>
<option value="" className="bg-popover text-popover-foreground">{t('apiTester.autoRandom')}</option>
{accounts.map((acc, i) => (
<option key={i} value={acc.email || acc.mobile} className="bg-popover text-popover-foreground">
👤 {acc.email || acc.mobile}
</option>
))}
{accounts.map((acc, i) => {
const id = resolveAccountIdentifier(acc)
if (!id) return null
return (
<option key={i} value={id} className="bg-popover text-popover-foreground">
👤 {id}
</option>
)
})}
</select>
<ChevronDown className="absolute right-2.5 top-3 w-4 h-4 text-muted-foreground pointer-events-none" />
</div>

View File

@@ -86,6 +86,7 @@
"requiredFields": "Password and email/mobile are required.",
"deleteKeyConfirm": "Are you sure you want to delete this API key?",
"deleteAccountConfirm": "Are you sure you want to delete this account?",
"invalidIdentifier": "Invalid account identifier. Operation aborted.",
"testAllConfirm": "Test API connectivity for all accounts?",
"testAllCompleted": "Completed: {success}/{total} available",
"testFailed": "Test failed: {error}",

View File

@@ -86,6 +86,7 @@
"requiredFields": "需要填写密码以及邮箱或手机号",
"deleteKeyConfirm": "确定要删除此 API 密钥吗?",
"deleteAccountConfirm": "确定要删除此账号吗?",
"invalidIdentifier": "账号标识无效,无法执行操作",
"testAllConfirm": "测试所有账号的 API 连通性?",
"testAllCompleted": "完成:{success}/{total} 可用",
"testFailed": "测试失败: {error}",