mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-02 23:45:27 +08:00
Compare commits
55 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d2c6445cfc | ||
|
|
b6fba47bcf | ||
|
|
e8d1aee7ad | ||
|
|
5cf56e7628 | ||
|
|
c291d333c4 | ||
|
|
2788e20f05 | ||
|
|
f178000d69 | ||
|
|
e840743295 | ||
|
|
77484bf813 | ||
|
|
f14969eca5 | ||
|
|
fe8a6bd3cd | ||
|
|
797ab77873 | ||
|
|
8f09e3b381 | ||
|
|
3a79b07d33 | ||
|
|
df13f35f43 | ||
|
|
4422f989be | ||
|
|
6052a8d1e2 | ||
|
|
f125c7ab83 | ||
|
|
8ff923cd77 | ||
|
|
e9a544cc53 | ||
|
|
d848d24a82 | ||
|
|
0a2fc42dad | ||
|
|
e615f1710f | ||
|
|
8f01aa224c | ||
|
|
31e64ff31d | ||
|
|
5984802df4 | ||
|
|
e0ed4ba238 | ||
|
|
ae37654893 | ||
|
|
aa7f821151 | ||
|
|
f7426f9f04 | ||
|
|
787e034174 | ||
|
|
d73f7b8b73 | ||
|
|
b8d844e2f6 | ||
|
|
2ba8b143d0 | ||
|
|
70603a5a90 | ||
|
|
fa51aafdc5 | ||
|
|
10d681ffe7 | ||
|
|
f313d0068f | ||
|
|
12256ceb24 | ||
|
|
2c08375b49 | ||
|
|
fa38934114 | ||
|
|
69eb71159d | ||
|
|
0e7f5cdc86 | ||
|
|
5b7cdaa729 | ||
|
|
08f32c4c40 | ||
|
|
69b7bc0c1a | ||
|
|
0f2b5fee23 | ||
|
|
26d195f2a6 | ||
|
|
790a8ca980 | ||
|
|
a1ce954ad5 | ||
|
|
6688e0ba35 | ||
|
|
c945f49fc4 | ||
|
|
0c644d1f4d | ||
|
|
146d59e7bf | ||
|
|
daf3307b88 |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -62,3 +62,8 @@ CLAUDE.local.md
|
||||
|
||||
# Local tool bootstrap cache
|
||||
.tmp/
|
||||
|
||||
# Chat history
|
||||
data/
|
||||
.codex
|
||||
.roomodes
|
||||
|
||||
117
API.en.md
117
API.en.md
@@ -31,13 +31,13 @@ Docs: [Overview](README.en.md) / [Architecture](docs/ARCHITECTURE.en.md) / [Depl
|
||||
| Base URL | `http://localhost:5001` or your deployment domain |
|
||||
| Default Content-Type | `application/json` |
|
||||
| Health probes | `GET /healthz`, `GET /readyz` |
|
||||
| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
|
||||
| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Ds2-Source`, `X-Vercel-Protection-Bypass`) |
|
||||
|
||||
### 3.0 Adapter-Layer Notes
|
||||
|
||||
- OpenAI / Claude / Gemini protocols are now mounted on one shared `chi` router tree assembled in `internal/server/router.go`.
|
||||
- Adapter responsibilities are streamlined to: **request normalization → DeepSeek invocation → protocol-shaped rendering**, reducing legacy split-logic paths.
|
||||
- Tool-calling semantics are aligned between Go and Node runtime: structured parsing first (JSON/XML/invoke/markup), plus stream-time anti-leak filtering.
|
||||
- Tool-calling semantics are aligned between Go and Node runtime: parsing is now centered on XML/Markup-family tool syntax (`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml variants), plus stream-time anti-leak filtering.
|
||||
- `Admin API` separates static config from runtime policy: `/admin/config*` for configuration state, `/admin/settings*` for runtime behavior.
|
||||
|
||||
---
|
||||
@@ -108,6 +108,7 @@ Gemini-compatible clients can also send `x-goog-api-key`, `?key=`, or `?api_key=
|
||||
| POST | `/v1/responses` | Business | OpenAI Responses API (stream/non-stream) |
|
||||
| GET | `/v1/responses/{response_id}` | Business | Query stored response (in-memory TTL) |
|
||||
| POST | `/v1/embeddings` | Business | OpenAI Embeddings API |
|
||||
| POST | `/v1/files` | Business | OpenAI Files upload (multipart/form-data) |
|
||||
| GET | `/anthropic/v1/models` | None | Claude model list |
|
||||
| POST | `/anthropic/v1/messages` | Business | Claude messages |
|
||||
| POST | `/anthropic/v1/messages/count_tokens` | Business | Claude token counting |
|
||||
@@ -129,11 +130,19 @@ Gemini-compatible clients can also send `x-goog-api-key`, `?key=`, or `?api_key=
|
||||
| POST | `/admin/settings/password` | Admin | Update admin password and invalidate old JWTs |
|
||||
| POST | `/admin/config/import` | Admin | Import config (merge/replace) |
|
||||
| GET | `/admin/config/export` | Admin | Export full config (`config`/`json`/`base64`) |
|
||||
| POST | `/admin/keys` | Admin | Add API key |
|
||||
| POST | `/admin/keys` | Admin | Add API key (optional `name`/`remark`) |
|
||||
| PUT | `/admin/keys/{key}` | Admin | Update API key metadata |
|
||||
| DELETE | `/admin/keys/{key}` | Admin | Delete API key |
|
||||
| GET | `/admin/proxies` | Admin | List proxies |
|
||||
| POST | `/admin/proxies` | Admin | Add proxy |
|
||||
| PUT | `/admin/proxies/{proxyID}` | Admin | Update proxy (empty password keeps old secret) |
|
||||
| DELETE | `/admin/proxies/{proxyID}` | Admin | Delete proxy (auto-unbind referenced accounts) |
|
||||
| POST | `/admin/proxies/test` | Admin | Test proxy connectivity |
|
||||
| GET | `/admin/accounts` | Admin | Paginated account list |
|
||||
| POST | `/admin/accounts` | Admin | Add account |
|
||||
| PUT | `/admin/accounts/{identifier}` | Admin | Update account name/remark |
|
||||
| DELETE | `/admin/accounts/{identifier}` | Admin | Delete account |
|
||||
| PUT | `/admin/accounts/{identifier}/proxy` | Admin | Bind/unbind proxy for an account |
|
||||
| GET | `/admin/queue/status` | Admin | Account queue status |
|
||||
| POST | `/admin/accounts/test` | Admin | Test one account |
|
||||
| POST | `/admin/accounts/test-all` | Admin | Test all accounts |
|
||||
@@ -149,6 +158,10 @@ Gemini-compatible clients can also send `x-goog-api-key`, `?key=`, or `?api_key=
|
||||
| GET | `/admin/export` | Admin | Export config JSON/Base64 |
|
||||
| GET | `/admin/dev/captures` | Admin | Read local packet-capture entries |
|
||||
| DELETE | `/admin/dev/captures` | Admin | Clear local packet-capture entries |
|
||||
| GET | `/admin/chat-history` | Admin | Read server-side conversation history |
|
||||
| DELETE | `/admin/chat-history` | Admin | Clear server-side conversation history |
|
||||
| DELETE | `/admin/chat-history/{id}` | Admin | Delete one server-side conversation entry |
|
||||
| PUT | `/admin/chat-history/settings` | Admin | Update conversation history retention limit |
|
||||
| GET | `/admin/version` | Admin | Check current version and latest Release |
|
||||
|
||||
---
|
||||
@@ -208,6 +221,13 @@ For `chat` / `responses` / `embeddings`, DS2API follows a wide-input/strict-outp
|
||||
3. If still unmatched, fall back by known family heuristics (`o*`, `gpt-*`, `claude-*`, etc.).
|
||||
4. If still unmatched, return `invalid_request_error`.
|
||||
|
||||
Current built-in default aliases (excerpt):
|
||||
|
||||
- OpenAI: `gpt-4o`, `gpt-4.1`, `gpt-4.1-mini`, `gpt-4.1-nano`, `gpt-5`, `gpt-5-mini`, `gpt-5-codex`
|
||||
- OpenAI reasoning: `o1`, `o1-mini`, `o3`, `o3-mini`
|
||||
- Claude: `claude-sonnet-4-5`, `claude-haiku-4-5`, `claude-opus-4-6` (plus compatibility aliases `claude-3-5-sonnet` / `claude-3-5-haiku` / `claude-3-opus`)
|
||||
- Gemini: `gemini-2.5-pro`, `gemini-2.5-flash`
|
||||
|
||||
### `POST /v1/chat/completions`
|
||||
|
||||
**Headers**:
|
||||
@@ -221,7 +241,7 @@ Content-Type: application/json
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | DeepSeek native models + common aliases (`gpt-4o`, `gpt-5-codex`, `o3`, `claude-sonnet-4-5`, `gemini-2.5-pro`, etc.) |
|
||||
| `model` | string | ✅ | DeepSeek native models + common aliases (`gpt-5`, `gpt-5-mini`, `gpt-5-codex`, `o3`, `claude-opus-4-6`, `gemini-2.5-pro`, `gemini-2.5-flash`, etc.) |
|
||||
| `messages` | array | ✅ | OpenAI-style messages |
|
||||
| `stream` | boolean | ❌ | Default `false` |
|
||||
| `tools` | array | ❌ | Function calling schema |
|
||||
@@ -312,7 +332,12 @@ When `tools` is present, DS2API performs anti-leak handling:
|
||||
}
|
||||
```
|
||||
|
||||
**Stream**: Once high-confidence toolcall features are matched, DS2API emits `delta.tool_calls` immediately (without waiting for full JSON closure), then keeps sending argument deltas; confirmed raw tool JSON is never forwarded as `delta.content`.
|
||||
**Stream**: Once high-confidence toolcall features are matched, DS2API emits `delta.tool_calls` immediately (without waiting for full argument closure), then keeps sending argument deltas; confirmed tool-call fragments are not forwarded as `delta.content`.
|
||||
|
||||
Additional notes:
|
||||
|
||||
- The parser currently follows XML/Markup-family tool payloads (`<tool_call>`, `<function_call>`, `<invoke>`, `tool_use`, antml variants). Standalone JSON `tool_calls` payloads are not treated as executable tool calls by default.
|
||||
- `tool_calls` shown inside fenced markdown code blocks (for example, ```json ... ```) are treated as examples, not executable calls.
|
||||
|
||||
---
|
||||
|
||||
@@ -391,6 +416,21 @@ Business auth required. Returns OpenAI-compatible embeddings shape.
|
||||
|
||||
> Requires `embeddings.provider`. Current supported values: `mock` / `deterministic` / `builtin`. If missing/unsupported, returns standard error shape with HTTP 501.
|
||||
|
||||
### `POST /v1/files`
|
||||
|
||||
Business auth required. OpenAI Files-compatible upload endpoint; currently only `multipart/form-data` is supported.
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `file` | file | ✅ | Binary payload |
|
||||
| `purpose` | string | ❌ | Forwarded purpose field |
|
||||
|
||||
Constraints and behavior:
|
||||
|
||||
- `Content-Type` must be `multipart/form-data` (otherwise `400`).
|
||||
- Total request size limit is `100 MiB` (over-limit returns `413`).
|
||||
- Success returns an OpenAI `file` object (`id/object/bytes/filename/purpose/status`, etc.) and includes `account_id` for source-account tracing.
|
||||
|
||||
---
|
||||
|
||||
## Claude-Compatible API
|
||||
@@ -609,11 +649,15 @@ Returns Vercel preconfiguration status.
|
||||
|
||||
### `GET /admin/config`
|
||||
|
||||
Returns sanitized config.
|
||||
Returns sanitized config, including both `keys` and `api_keys`.
|
||||
|
||||
```json
|
||||
{
|
||||
"keys": ["k1", "k2"],
|
||||
"api_keys": [
|
||||
{"key": "k1", "name": "Primary", "remark": "Production"},
|
||||
{"key": "k2", "name": "Backup", "remark": "Load test"}
|
||||
],
|
||||
"env_backed": false,
|
||||
"env_source_present": true,
|
||||
"env_writeback_enabled": true,
|
||||
@@ -637,13 +681,18 @@ Returns sanitized config.
|
||||
|
||||
### `POST /admin/config`
|
||||
|
||||
Only updates `keys`, `accounts`, and `claude_mapping`.
|
||||
Only updates `keys`, `api_keys`, `accounts`, and `claude_mapping`.
|
||||
If both `api_keys` and `keys` are sent, the structured `api_keys` entries win so `name` / `remark` metadata is preserved; `keys` remains a legacy fallback.
|
||||
|
||||
**Request**:
|
||||
|
||||
```json
|
||||
{
|
||||
"keys": ["k1", "k2"],
|
||||
"api_keys": [
|
||||
{"key": "k1", "name": "Primary", "remark": "Production"},
|
||||
{"key": "k2", "name": "Backup", "remark": "Load test"}
|
||||
],
|
||||
"accounts": [
|
||||
{"email": "user@example.com", "password": "pwd", "token": ""}
|
||||
],
|
||||
@@ -703,7 +752,7 @@ Imports full config with:
|
||||
|
||||
The request can send config directly, or wrapped as `{"config": {...}, "mode":"merge"}`.
|
||||
Query params `?mode=merge` / `?mode=replace` are also supported.
|
||||
Import accepts `keys`, `accounts`, `claude_mapping` / `claude_model_mapping`, `model_aliases`, `admin`, `runtime`, `responses`, `embeddings`, and `auto_delete`; legacy `toolcall` fields are ignored.
|
||||
Import accepts `keys`, `api_keys`, `accounts`, `claude_mapping` / `claude_model_mapping`, `model_aliases`, `admin`, `runtime`, `responses`, `embeddings`, and `auto_delete`; legacy `toolcall` fields are ignored.
|
||||
|
||||
> `compat` fields are managed via `/admin/settings` or the config file; this import endpoint does not update `compat`.
|
||||
|
||||
@@ -714,7 +763,17 @@ Exports full config in three forms: `config`, `json`, and `base64`.
|
||||
### `POST /admin/keys`
|
||||
|
||||
```json
|
||||
{"key": "new-api-key"}
|
||||
{"key": "new-api-key", "name": "Primary", "remark": "Production"}
|
||||
```
|
||||
|
||||
**Response**: `{"success": true, "total_keys": 3}`
|
||||
|
||||
### `PUT /admin/keys/{key}`
|
||||
|
||||
Updates the `name` / `remark` of the specified API key. The path `key` is read-only and cannot be changed.
|
||||
|
||||
```json
|
||||
{"name": "Backup", "remark": "Load test"}
|
||||
```
|
||||
|
||||
**Response**: `{"success": true, "total_keys": 3}`
|
||||
@@ -723,6 +782,26 @@ Exports full config in three forms: `config`, `json`, and `base64`.
|
||||
|
||||
**Response**: `{"success": true, "total_keys": 2}`
|
||||
|
||||
### `GET /admin/proxies`
|
||||
|
||||
Lists proxy configs (password is never returned; use `has_password` as a marker).
|
||||
|
||||
### `POST /admin/proxies`
|
||||
|
||||
Adds a proxy. Request accepts `id` (optional; auto-generated when omitted), `name`, `type` (`http` / `socks5`), `host`, `port`, `username`, `password`.
|
||||
|
||||
### `PUT /admin/proxies/{proxyID}`
|
||||
|
||||
Updates a proxy. If `password` is an empty string, the existing secret is preserved.
|
||||
|
||||
### `DELETE /admin/proxies/{proxyID}`
|
||||
|
||||
Deletes a proxy and automatically clears `proxy_id` on all accounts that reference it.
|
||||
|
||||
### `POST /admin/proxies/test`
|
||||
|
||||
Tests proxy connectivity: provide `proxy_id` to test a saved proxy; omit it to run a one-off test using proxy fields in the request body.
|
||||
|
||||
### `GET /admin/accounts`
|
||||
|
||||
**Query params**:
|
||||
@@ -730,7 +809,7 @@ Exports full config in three forms: `config`, `json`, and `base64`.
|
||||
| Param | Default | Range |
|
||||
| --- | --- | --- |
|
||||
| `page` | `1` | ≥ 1 |
|
||||
| `page_size` | `10` | 1–100 |
|
||||
| `page_size` | `10` | 1–5000 |
|
||||
| `q` | empty | Filter by identifier / email / mobile |
|
||||
|
||||
**Response**:
|
||||
@@ -765,12 +844,30 @@ Returned items also include `test_status`, usually `ok` or `failed`.
|
||||
|
||||
**Response**: `{"success": true, "total_accounts": 6}`
|
||||
|
||||
### `PUT /admin/accounts/{identifier}`
|
||||
|
||||
Updates the `name` / `remark` of the specified account. The path `identifier` can be email or mobile and cannot be changed.
|
||||
|
||||
```json
|
||||
{"name": "Primary account", "remark": "Shared with the team"}
|
||||
```
|
||||
|
||||
**Response**: `{"success": true, "total_accounts": 6}`
|
||||
|
||||
### `DELETE /admin/accounts/{identifier}`
|
||||
|
||||
`identifier` can be email, mobile, or the synthetic id for token-only accounts (`token:<hash>`).
|
||||
|
||||
**Response**: `{"success": true, "total_accounts": 5}`
|
||||
|
||||
### `PUT /admin/accounts/{identifier}/proxy`
|
||||
|
||||
Updates proxy binding for a specific account.
|
||||
|
||||
- Request body: `{"proxy_id":"..."}`.
|
||||
- Use empty `proxy_id` to unbind proxy.
|
||||
- `identifier` supports email / mobile / token-only synthetic id.
|
||||
|
||||
### `GET /admin/queue/status`
|
||||
|
||||
```json
|
||||
|
||||
119
API.md
119
API.md
@@ -31,13 +31,13 @@
|
||||
| Base URL | `http://localhost:5001` 或你的部署域名 |
|
||||
| 默认 Content-Type | `application/json` |
|
||||
| 健康检查 | `GET /healthz`、`GET /readyz` |
|
||||
| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
|
||||
| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Ds2-Source`, `X-Vercel-Protection-Bypass`) |
|
||||
|
||||
### 3.0 接口适配层说明
|
||||
|
||||
- OpenAI / Claude / Gemini 三套协议已统一挂在同一 `chi` 路由树上,由 `internal/server/router.go` 负责装配。
|
||||
- 适配器层职责收敛为:**请求归一化 → DeepSeek 调用 → 协议形态渲染**,减少历史版本中“同能力多处实现”的分叉。
|
||||
- Tool Calling 的解析策略在 Go 与 Node Runtime 间保持一致:优先结构化解析(JSON/XML/invoke/markup),并在流式场景执行防泄漏筛分。
|
||||
- Tool Calling 的解析策略在 Go 与 Node Runtime 间保持一致:当前以 XML/Markup 家族解析为主(含 `<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml 变体),并在流式场景执行防泄漏筛分。
|
||||
- `Admin API` 将配置与运行时策略分开:`/admin/config*` 管静态配置,`/admin/settings*` 管运行时行为。
|
||||
|
||||
---
|
||||
@@ -108,6 +108,7 @@ Gemini 兼容客户端还可以使用 `x-goog-api-key`、`?key=` 或 `?api_key=`
|
||||
| POST | `/v1/responses` | 业务 | OpenAI Responses 接口(流式/非流式) |
|
||||
| GET | `/v1/responses/{response_id}` | 业务 | 查询已生成 response(内存 TTL) |
|
||||
| POST | `/v1/embeddings` | 业务 | OpenAI Embeddings 接口 |
|
||||
| POST | `/v1/files` | 业务 | OpenAI Files 上传(multipart/form-data) |
|
||||
| GET | `/anthropic/v1/models` | 无 | Claude 模型列表 |
|
||||
| POST | `/anthropic/v1/messages` | 业务 | Claude 消息接口 |
|
||||
| POST | `/anthropic/v1/messages/count_tokens` | 业务 | Claude token 计数 |
|
||||
@@ -129,11 +130,19 @@ Gemini 兼容客户端还可以使用 `x-goog-api-key`、`?key=` 或 `?api_key=`
|
||||
| POST | `/admin/settings/password` | Admin | 更新 Admin 密码并使旧 JWT 失效 |
|
||||
| POST | `/admin/config/import` | Admin | 导入配置(merge/replace) |
|
||||
| GET | `/admin/config/export` | Admin | 导出完整配置(含 `config`/`json`/`base64`) |
|
||||
| POST | `/admin/keys` | Admin | 添加 API key |
|
||||
| POST | `/admin/keys` | Admin | 添加 API key(可附 name/remark) |
|
||||
| PUT | `/admin/keys/{key}` | Admin | 更新 API key 备注信息 |
|
||||
| DELETE | `/admin/keys/{key}` | Admin | 删除 API key |
|
||||
| GET | `/admin/proxies` | Admin | 代理列表 |
|
||||
| POST | `/admin/proxies` | Admin | 添加代理 |
|
||||
| PUT | `/admin/proxies/{proxyID}` | Admin | 更新代理(留空 password 表示保留原密码) |
|
||||
| DELETE | `/admin/proxies/{proxyID}` | Admin | 删除代理(自动解绑引用该代理的账号) |
|
||||
| POST | `/admin/proxies/test` | Admin | 测试代理连通性 |
|
||||
| GET | `/admin/accounts` | Admin | 分页账号列表 |
|
||||
| POST | `/admin/accounts` | Admin | 添加账号 |
|
||||
| PUT | `/admin/accounts/{identifier}` | Admin | 更新账号 name/remark |
|
||||
| DELETE | `/admin/accounts/{identifier}` | Admin | 删除账号 |
|
||||
| PUT | `/admin/accounts/{identifier}/proxy` | Admin | 为账号绑定/解绑代理 |
|
||||
| GET | `/admin/queue/status` | Admin | 账号队列状态 |
|
||||
| POST | `/admin/accounts/test` | Admin | 测试单个账号 |
|
||||
| POST | `/admin/accounts/test-all` | Admin | 测试全部账号 |
|
||||
@@ -149,6 +158,10 @@ Gemini 兼容客户端还可以使用 `x-goog-api-key`、`?key=` 或 `?api_key=`
|
||||
| GET | `/admin/export` | Admin | 导出配置 JSON/Base64 |
|
||||
| GET | `/admin/dev/captures` | Admin | 查看本地抓包记录 |
|
||||
| DELETE | `/admin/dev/captures` | Admin | 清空本地抓包记录 |
|
||||
| GET | `/admin/chat-history` | Admin | 查看服务器端对话记录 |
|
||||
| DELETE | `/admin/chat-history` | Admin | 清空服务器端对话记录 |
|
||||
| DELETE | `/admin/chat-history/{id}` | Admin | 删除单条服务器端对话记录 |
|
||||
| PUT | `/admin/chat-history/settings` | Admin | 更新对话记录保留条数 |
|
||||
| GET | `/admin/version` | Admin | 查询当前版本与最新 Release |
|
||||
|
||||
---
|
||||
@@ -208,6 +221,13 @@ Gemini 兼容客户端还可以使用 `x-goog-api-key`、`?key=` 或 `?api_key=`
|
||||
3. 未命中时按模型家族规则回退(如 `o*`、`gpt-*`、`claude-*`)。
|
||||
4. 仍未命中则返回 `invalid_request_error`。
|
||||
|
||||
当前内置默认 alias(节选):
|
||||
|
||||
- OpenAI:`gpt-4o`、`gpt-4.1`、`gpt-4.1-mini`、`gpt-4.1-nano`、`gpt-5`、`gpt-5-mini`、`gpt-5-codex`
|
||||
- OpenAI Reasoning:`o1`、`o1-mini`、`o3`、`o3-mini`
|
||||
- Claude:`claude-sonnet-4-5`、`claude-haiku-4-5`、`claude-opus-4-6`(及 `claude-3-5-sonnet` / `claude-3-5-haiku` / `claude-3-opus` 兼容别名)
|
||||
- Gemini:`gemini-2.5-pro`、`gemini-2.5-flash`
|
||||
|
||||
### `POST /v1/chat/completions`
|
||||
|
||||
**请求头**:
|
||||
@@ -221,7 +241,7 @@ Content-Type: application/json
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | 支持 DeepSeek 原生模型 + 常见 alias(如 `gpt-4o`、`gpt-5-codex`、`o3`、`claude-sonnet-4-5`、`gemini-2.5-pro` 等) |
|
||||
| `model` | string | ✅ | 支持 DeepSeek 原生模型 + 常见 alias(如 `gpt-5`、`gpt-5-mini`、`gpt-5-codex`、`o3`、`claude-opus-4-6`、`gemini-2.5-pro`、`gemini-2.5-flash` 等) |
|
||||
| `messages` | array | ✅ | OpenAI 风格消息数组 |
|
||||
| `stream` | boolean | ❌ | 默认 `false` |
|
||||
| `tools` | array | ❌ | Function Calling 定义 |
|
||||
@@ -312,12 +332,12 @@ data: [DONE]
|
||||
}
|
||||
```
|
||||
|
||||
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。
|
||||
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整工具参数闭合),并持续发送 arguments 增量;已确认的工具调用片段不会回流到 `delta.content`。
|
||||
|
||||
补充说明:
|
||||
|
||||
- **非代码块上下文**下,工具负载即使与普通文本混合,也会按特征识别并产出可执行 tool call(前后普通文本仍可透传)。
|
||||
- 解析器以 XML/Markup 为最高优先级,并兼容 JSON、ANTML、text-kv 等格式输入;最终按客户端协议转译为对应 tool call 结构(OpenAI/Claude/Gemini)。
|
||||
- 解析器当前走 XML/Markup 家族(包含 `<tool_call>`、`<function_call>`、`<invoke>`、`tool_use`、antml 风格);纯 JSON `tool_calls` 片段默认不会直接作为可执行调用解析。
|
||||
- Markdown fenced code block(例如 ```json ... ```)中的 `tool_calls` 仅视为示例文本,不会被执行。
|
||||
|
||||
---
|
||||
@@ -397,6 +417,21 @@ data: [DONE]
|
||||
|
||||
> 需配置 `embeddings.provider`。当前支持:`mock` / `deterministic` / `builtin`。未配置或不支持时返回标准错误结构(HTTP 501)。
|
||||
|
||||
### `POST /v1/files`
|
||||
|
||||
需要业务鉴权。兼容 OpenAI Files 上传接口,当前仅支持 `multipart/form-data`。
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
| --- | --- | --- | --- |
|
||||
| `file` | file | ✅ | 上传文件二进制 |
|
||||
| `purpose` | string | ❌ | 透传到上游用途字段 |
|
||||
|
||||
约束与行为:
|
||||
|
||||
- 请求必须为 `multipart/form-data`,否则返回 `400`。
|
||||
- 请求体总大小上限 `100 MiB`(超限返回 `413`)。
|
||||
- 成功返回 OpenAI `file` 对象(`id/object/bytes/filename/purpose/status` 等字段),并附带 `account_id` 便于定位来源账号。
|
||||
|
||||
---
|
||||
|
||||
## Claude 兼容接口
|
||||
@@ -615,11 +650,15 @@ data: {"type":"message_stop"}
|
||||
|
||||
### `GET /admin/config`
|
||||
|
||||
返回脱敏后的配置。
|
||||
返回脱敏后的配置,包含 `keys` 与 `api_keys`。
|
||||
|
||||
```json
|
||||
{
|
||||
"keys": ["k1", "k2"],
|
||||
"api_keys": [
|
||||
{"key": "k1", "name": "主 Key", "remark": "生产流量"},
|
||||
{"key": "k2", "name": "备用 Key", "remark": "压测"}
|
||||
],
|
||||
"env_backed": false,
|
||||
"env_source_present": true,
|
||||
"env_writeback_enabled": true,
|
||||
@@ -643,13 +682,18 @@ data: {"type":"message_stop"}
|
||||
|
||||
### `POST /admin/config`
|
||||
|
||||
只更新 `keys`、`accounts`、`claude_mapping`。
|
||||
只更新 `keys`、`api_keys`、`accounts`、`claude_mapping`。
|
||||
如果同时发送 `api_keys` 与 `keys`,优先保留 `api_keys` 中的结构化 `name` / `remark`;`keys` 仅作为旧格式兼容回退。
|
||||
|
||||
**请求**:
|
||||
|
||||
```json
|
||||
{
|
||||
"keys": ["k1", "k2"],
|
||||
"api_keys": [
|
||||
{"key": "k1", "name": "主 Key", "remark": "生产流量"},
|
||||
{"key": "k2", "name": "备用 Key", "remark": "压测"}
|
||||
],
|
||||
"accounts": [
|
||||
{"email": "user@example.com", "password": "pwd", "token": ""}
|
||||
],
|
||||
@@ -709,7 +753,7 @@ data: {"type":"message_stop"}
|
||||
|
||||
请求可直接传配置对象,或使用 `{"config": {...}, "mode":"merge"}` 包裹格式。
|
||||
也支持在查询参数里传 `?mode=merge` / `?mode=replace`。
|
||||
导入时会接受 `keys`、`accounts`、`claude_mapping` / `claude_model_mapping`、`model_aliases`、`admin`、`runtime`、`responses`、`embeddings`、`auto_delete` 等字段;`toolcall` 相关字段会被忽略。
|
||||
导入时会接受 `keys`、`api_keys`、`accounts`、`claude_mapping` / `claude_model_mapping`、`model_aliases`、`admin`、`runtime`、`responses`、`embeddings`、`auto_delete` 等字段;`toolcall` 相关字段会被忽略。
|
||||
|
||||
> `compat` 相关字段请通过 `/admin/settings` 或配置文件管理;该导入接口不会更新 `compat`。
|
||||
|
||||
@@ -717,10 +761,25 @@ data: {"type":"message_stop"}
|
||||
|
||||
导出完整配置,返回 `config`、`json`、`base64` 三种格式。
|
||||
|
||||
响应示例:
|
||||
|
||||
|
||||
> 注:`_vercel_sync_hash` 和 `_vercel_sync_time` 为内部同步元数据字段,用于 Vercel 配置漂移检测。
|
||||
|
||||
### `POST /admin/keys`
|
||||
|
||||
```json
|
||||
{"key": "new-api-key"}
|
||||
{"key": "new-api-key", "name": "主 Key", "remark": "生产流量"}
|
||||
```
|
||||
|
||||
**响应**:`{"success": true, "total_keys": 3}`
|
||||
|
||||
### `PUT /admin/keys/{key}`
|
||||
|
||||
更新指定 API key 的 `name` / `remark`,路径参数中的 `key` 为只读标识,不可修改。
|
||||
|
||||
```json
|
||||
{"name": "备用 Key", "remark": "压测"}
|
||||
```
|
||||
|
||||
**响应**:`{"success": true, "total_keys": 3}`
|
||||
@@ -729,6 +788,26 @@ data: {"type":"message_stop"}
|
||||
|
||||
**响应**:`{"success": true, "total_keys": 2}`
|
||||
|
||||
### `GET /admin/proxies`
|
||||
|
||||
列出代理配置(密码不回传,仅返回 `has_password` 标记)。
|
||||
|
||||
### `POST /admin/proxies`
|
||||
|
||||
新增代理。请求体支持 `id`(可选,未传则自动生成)、`name`、`type`(`http` / `socks5`)、`host`、`port`、`username`、`password`。
|
||||
|
||||
### `PUT /admin/proxies/{proxyID}`
|
||||
|
||||
更新指定代理。若请求中 `password` 为空字符串,则保留原密码。
|
||||
|
||||
### `DELETE /admin/proxies/{proxyID}`
|
||||
|
||||
删除代理,并自动清空所有引用该代理账号的 `proxy_id`。
|
||||
|
||||
### `POST /admin/proxies/test`
|
||||
|
||||
测试代理连通性:传 `proxy_id` 时测试已保存代理;不传时按请求体代理字段做临时连通性测试。
|
||||
|
||||
### `GET /admin/accounts`
|
||||
|
||||
**查询参数**:
|
||||
@@ -736,7 +815,7 @@ data: {"type":"message_stop"}
|
||||
| 参数 | 默认 | 范围 |
|
||||
| --- | --- | --- |
|
||||
| `page` | `1` | ≥ 1 |
|
||||
| `page_size` | `10` | 1–100 |
|
||||
| `page_size` | `10` | 1–5000 |
|
||||
| `q` | 空 | 按 identifier / email / mobile 过滤 |
|
||||
|
||||
**响应**:
|
||||
@@ -769,12 +848,30 @@ data: {"type":"message_stop"}
|
||||
|
||||
**响应**:`{"success": true, "total_accounts": 6}`
|
||||
|
||||
### `PUT /admin/accounts/{identifier}`
|
||||
|
||||
更新指定账号的 `name` / `remark`。路径参数中的 `identifier` 可以是 email 或 mobile,且不可修改。
|
||||
|
||||
```json
|
||||
{"name": "主账号", "remark": "团队共享"}
|
||||
```
|
||||
|
||||
**响应**:`{"success": true, "total_accounts": 6}`
|
||||
|
||||
### `DELETE /admin/accounts/{identifier}`
|
||||
|
||||
`identifier` 可为 email、mobile,或 token-only 账号的合成标识(`token:<hash>`)。
|
||||
|
||||
**响应**:`{"success": true, "total_accounts": 5}`
|
||||
|
||||
### `PUT /admin/accounts/{identifier}/proxy`
|
||||
|
||||
更新指定账号绑定代理。
|
||||
|
||||
- 请求体:`{"proxy_id":"..."}`;
|
||||
- `proxy_id` 传空字符串时表示解绑代理;
|
||||
- `identifier` 支持 email / mobile / token-only 合成标识。
|
||||
|
||||
### `GET /admin/queue/status`
|
||||
|
||||
```json
|
||||
|
||||
143
LICENSE
143
LICENSE
@@ -1,5 +1,5 @@
|
||||
GNU GENERAL PUBLIC LICENSE
|
||||
Version 3, 29 June 2007
|
||||
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||
Version 3, 19 November 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
@@ -7,17 +7,15 @@
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU General Public License is a free, copyleft license for
|
||||
software and other kinds of works.
|
||||
The GNU Affero General Public License is a free, copyleft license for
|
||||
software and other kinds of works, specifically designed to ensure
|
||||
cooperation with the community in the case of network server software.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
the GNU General Public License is intended to guarantee your freedom to
|
||||
our General Public Licenses are intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users. We, the Free Software Foundation, use the
|
||||
GNU General Public License for most of our software; it applies also to
|
||||
any other work released this way by its authors. You can apply it to
|
||||
your programs, too.
|
||||
software for all its users.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
@@ -26,44 +24,34 @@ them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
To protect your rights, we need to prevent others from denying you
|
||||
these rights or asking you to surrender the rights. Therefore, you have
|
||||
certain responsibilities if you distribute copies of the software, or if
|
||||
you modify it: responsibilities to respect the freedom of others.
|
||||
Developers that use our General Public Licenses protect your rights
|
||||
with two steps: (1) assert copyright on the software, and (2) offer
|
||||
you this License which gives you legal permission to copy, distribute
|
||||
and/or modify the software.
|
||||
|
||||
For example, if you distribute copies of such a program, whether
|
||||
gratis or for a fee, you must pass on to the recipients the same
|
||||
freedoms that you received. You must make sure that they, too, receive
|
||||
or can get the source code. And you must show them these terms so they
|
||||
know their rights.
|
||||
A secondary benefit of defending all users' freedom is that
|
||||
improvements made in alternate versions of the program, if they
|
||||
receive widespread use, become available for other developers to
|
||||
incorporate. Many developers of free software are heartened and
|
||||
encouraged by the resulting cooperation. However, in the case of
|
||||
software used on network servers, this result may fail to come about.
|
||||
The GNU General Public License permits making a modified version and
|
||||
letting the public access it on a server without ever releasing its
|
||||
source code to the public.
|
||||
|
||||
Developers that use the GNU GPL protect your rights with two steps:
|
||||
(1) assert copyright on the software, and (2) offer you this License
|
||||
giving you legal permission to copy, distribute and/or modify it.
|
||||
The GNU Affero General Public License is designed specifically to
|
||||
ensure that, in such cases, the modified source code becomes available
|
||||
to the community. It requires the operator of a network server to
|
||||
provide the source code of the modified version running there to the
|
||||
users of that server. Therefore, public use of a modified version, on
|
||||
a publicly accessible server, gives the public access to the source
|
||||
code of the modified version.
|
||||
|
||||
For the developers' and authors' protection, the GPL clearly explains
|
||||
that there is no warranty for this free software. For both users' and
|
||||
authors' sake, the GPL requires that modified versions be marked as
|
||||
changed, so that their problems will not be attributed erroneously to
|
||||
authors of previous versions.
|
||||
|
||||
Some devices are designed to deny users access to install or run
|
||||
modified versions of the software inside them, although the manufacturer
|
||||
can do so. This is fundamentally incompatible with the aim of
|
||||
protecting users' freedom to change the software. The systematic
|
||||
pattern of such abuse occurs in the area of products for individuals to
|
||||
use, which is precisely where it is most unacceptable. Therefore, we
|
||||
have designed this version of the GPL to prohibit the practice for those
|
||||
products. If such problems arise substantially in other domains, we
|
||||
stand ready to extend this provision to those domains in future versions
|
||||
of the GPL, as needed to protect the freedom of users.
|
||||
|
||||
Finally, every program is threatened constantly by software patents.
|
||||
States should not allow patents to restrict development and use of
|
||||
software on general-purpose computers, but in those that do, we wish to
|
||||
avoid the special danger that patents applied to a free program could
|
||||
make it effectively proprietary. To prevent this, the GPL assures that
|
||||
patents cannot be used to render the program non-free.
|
||||
An older license, called the Affero General Public License and
|
||||
published by Affero, was designed to accomplish similar goals. This is
|
||||
a different license, not a version of the Affero GPL, but Affero has
|
||||
released a new version of the Affero GPL which permits relicensing under
|
||||
this license.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
@@ -72,7 +60,7 @@ modification follow.
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU General Public License.
|
||||
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
@@ -549,35 +537,45 @@ to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Use with the GNU Affero General Public License.
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU Affero General Public License into a single
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the special requirements of the GNU Affero General Public License,
|
||||
section 13, concerning interaction through a network will apply to the
|
||||
combination as such.
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU General Public License from time to time. Such new versions will
|
||||
be similar in spirit to the present version, but may differ in detail to
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU General
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU General Public License, you may choose any version ever published
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU General Public License can be used, that proxy's
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
@@ -635,40 +633,29 @@ the "copyright" line and a pointer to where the full notice is found.
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
it under the terms of the GNU Affero General Public License as published
|
||||
by the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If the program does terminal interaction, make it output a short
|
||||
notice like this when it starts in an interactive mode:
|
||||
|
||||
<program> Copyright (C) <year> <name of author>
|
||||
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
||||
This is free software, and you are welcome to redistribute it
|
||||
under certain conditions; type `show c' for details.
|
||||
|
||||
The hypothetical commands `show w' and `show c' should show the appropriate
|
||||
parts of the General Public License. Of course, your program's commands
|
||||
might be different; for a GUI interface, you would use an "about box".
|
||||
If your software can interact with users remotely through a computer
|
||||
network, you should also make sure that it provides a way for users to
|
||||
get its source. For example, if your program is a web application, its
|
||||
interface could display a "Source" link that leads users to an archive
|
||||
of the code. There are many ways you could offer source, and different
|
||||
solutions will be better for different programs; see section 13 for the
|
||||
specific requirements.
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU GPL, see
|
||||
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
|
||||
The GNU General Public License does not permit incorporating your program
|
||||
into proprietary programs. If your program is a subroutine library, you
|
||||
may consider it more useful to permit linking proprietary applications with
|
||||
the library. If this is what you want to do, use the GNU Lesser General
|
||||
Public License instead of this License. But first, please read
|
||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
||||
|
||||
56
README.MD
56
README.MD
@@ -82,29 +82,19 @@ flowchart LR
|
||||
- **前端**:React 管理台(`webui/`),运行时托管静态构建产物
|
||||
- **部署**:本地运行、Docker、Vercel Serverless、Linux systemd
|
||||
|
||||
### 3.X 底层架构调整(相较旧版本)
|
||||
|
||||
- **统一路由内核**:所有协议入口统一汇聚到 `internal/server/router.go`,并在同一路由树中注册 OpenAI / Claude / Gemini / Admin / WebUI 路由,避免多入口行为漂移。
|
||||
- **统一执行链路**:Claude / Gemini 入口先经 `internal/translatorcliproxy` 做协议转换,再进入 `openai.ChatCompletions` 统一处理工具调用与流式语义,最后再转换回原协议响应。
|
||||
- **适配器分层更清晰**:`internal/adapter/{claude,gemini}` 负责入口/出口协议封装,`internal/adapter/openai` 负责核心执行,DeepSeek 侧调用只保留在 OpenAI 内核中。
|
||||
- **Tool Calling 双运行时对齐**:Go 侧(`internal/toolcall`)与 Vercel Node 侧(`internal/js/helpers/stream-tool-sieve`)保持一致的解析/防泄漏语义,覆盖 JSON / XML / invoke / text-kv 多风格输入。
|
||||
- **配置与运行时设置解耦**:静态配置(`config`)与运行时策略(`settings`)通过 Admin API 分离管理,支持热更新和密码轮换失效旧 JWT。
|
||||
- **流式能力升级**:`/v1/responses` 与 `/v1/chat/completions` 共享更一致的工具调用增量输出策略,降低不同 SDK 下的行为差异。
|
||||
- **可观测与可运维增强**:`/healthz`、`/readyz`、`/admin/version`、`/admin/dev/captures` 形成排障闭环,便于发布后验证。
|
||||
|
||||
## 核心能力
|
||||
|
||||
| 能力 | 说明 |
|
||||
| --- | --- |
|
||||
| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings` |
|
||||
| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings`、`POST /v1/files` |
|
||||
| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens`(及快捷路径 `/v1/messages`、`/messages`) |
|
||||
| Gemini 兼容 | `POST /v1beta/models/{model}:generateContent`、`POST /v1beta/models/{model}:streamGenerateContent`(及 `/v1/models/{model}:*` 路径) |
|
||||
| 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 |
|
||||
| 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 |
|
||||
| DeepSeek PoW | 纯 Go 高性能实现(DeepSeekHashV1),毫秒级响应 |
|
||||
| Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 |
|
||||
| Admin API | 配置管理、运行时设置热更新、账号测试 / 批量测试、会话清理、导入导出、Vercel 同步、版本检查 |
|
||||
| WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) |
|
||||
| Admin API | 配置管理、运行时设置热更新、代理管理、账号测试 / 批量测试、会话清理、导入导出、Vercel 同步、版本检查 |
|
||||
| WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式,支持查看服务器端对话记录) |
|
||||
| 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) |
|
||||
|
||||
## 平台兼容矩阵
|
||||
@@ -137,7 +127,7 @@ flowchart LR
|
||||
| vision | `deepseek-vision-chat-search` | ❌ | ✅ |
|
||||
| vision | `deepseek-vision-reasoner-search` | ✅ | ✅ |
|
||||
|
||||
除原生模型外,也支持常见 alias 输入(如 `gpt-4o`、`gpt-5-codex`、`o3`、`claude-sonnet-4-5`、`gemini-2.5-pro` 等),但 `/v1/models` 返回的是规范化后的 DeepSeek 原生模型 ID。
|
||||
除原生模型外,也支持常见 alias 输入(如 `gpt-5`、`gpt-5-mini`、`gpt-5-codex`、`gpt-4.1`、`o3`、`claude-opus-4-6`、`claude-sonnet-4-5`、`gemini-2.5-pro`、`gemini-2.5-flash` 等),但 `/v1/models` 返回的是规范化后的 DeepSeek 原生模型 ID。
|
||||
|
||||
### Claude 接口(`GET /anthropic/v1/models`)
|
||||
|
||||
@@ -155,7 +145,7 @@ flowchart LR
|
||||
- `ANTHROPIC_BASE_URL` 推荐直接指向 DS2API 根地址(例如 `http://127.0.0.1:5001`),Claude Code 会请求 `/v1/messages?beta=true`。
|
||||
- `ANTHROPIC_API_KEY` 需要与 `config.json` 中 `keys` 一致;建议同时保留常规 key 与 `sk-ant-*` 形态 key,兼容不同客户端校验习惯。
|
||||
- 若系统设置了代理,建议对 DS2API 地址配置 `NO_PROXY=127.0.0.1,localhost,<你的主机IP>`,避免本地回环请求被代理拦截。
|
||||
- 如遇“工具调用输出成文本、未执行”问题,请升级到包含 Claude 工具调用多格式解析(JSON/XML/ANTML/invoke)的版本。
|
||||
- 如遇“工具调用输出成文本、未执行”问题,请优先检查模型输出是否为受支持的 XML/Markup 工具块(例如 `<tool_call>` / `<function_call>` / `<invoke>` / `tool_use`),而不是纯 JSON `tool_calls` 片段。
|
||||
|
||||
### Gemini 接口
|
||||
|
||||
@@ -185,6 +175,8 @@ cp config.example.json config.json
|
||||
- 本地运行:直接读取 `config.json`
|
||||
- Docker / Vercel:由 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量,也可以直接写原始 JSON
|
||||
|
||||
WebUI 管理台里的“全量配置模板”也直接复用同一份 `config.example.json`,所以更新示例文件后,前端模板会自动保持一致。
|
||||
|
||||
### 方式一:下载 Release 构建包
|
||||
|
||||
每次发布 Release 时,GitHub Actions 会自动构建多平台二进制包:
|
||||
@@ -281,8 +273,17 @@ go run ./cmd/ds2api
|
||||
```json
|
||||
{
|
||||
"keys": ["your-api-key-1", "your-api-key-2"],
|
||||
"api_keys": [
|
||||
{
|
||||
"key": "your-api-key-1",
|
||||
"name": "主 Key",
|
||||
"remark": "生产流量"
|
||||
}
|
||||
],
|
||||
"accounts": [
|
||||
{
|
||||
"name": "账号 A",
|
||||
"remark": "主账号",
|
||||
"email": "user@example.com",
|
||||
"password": "your-password"
|
||||
},
|
||||
@@ -293,8 +294,12 @@ go run ./cmd/ds2api
|
||||
],
|
||||
"model_aliases": {
|
||||
"gpt-4o": "deepseek-chat",
|
||||
"gpt-5": "deepseek-chat",
|
||||
"gpt-5-mini": "deepseek-chat",
|
||||
"gpt-5-codex": "deepseek-reasoner",
|
||||
"o3": "deepseek-reasoner"
|
||||
"o3": "deepseek-reasoner",
|
||||
"claude-opus-4-6": "deepseek-reasoner",
|
||||
"gemini-2.5-flash": "deepseek-chat"
|
||||
},
|
||||
"compat": {
|
||||
"wide_input_strict_output": true,
|
||||
@@ -326,7 +331,8 @@ go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
- `keys`:API 访问密钥列表,客户端通过 `Authorization: Bearer <key>` 鉴权
|
||||
- `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录
|
||||
- `api_keys`:推荐使用的新结构化密钥列表,支持 `key` + `name` + `remark`(`keys` 仍兼容)
|
||||
- `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录;可额外填写 `name` / `remark` 便于管理
|
||||
- `token`:配置文件中即使填写也会在加载时被清空(不会从 `config.json` 读取 token);实际 token 仅在运行时内存中维护并自动刷新
|
||||
- `model_aliases`:常见模型名(如 GPT/Codex/Claude)到 DeepSeek 模型的映射
|
||||
- `compat.wide_input_strict_output`:建议保持 `true`(当前实现默认宽进严出)
|
||||
@@ -341,6 +347,8 @@ go run ./cmd/ds2api
|
||||
|
||||
### 环境变量
|
||||
|
||||
> 建议:长期维护请优先以 `config.json`(或其 Base64)为单一配置源。环境变量仅保留部署必需项;`DS2API_CONFIG_JSON` 主要用于 Vercel/无持久盘场景,后续可能进一步收敛。
|
||||
|
||||
| 变量 | 用途 | 默认值 |
|
||||
| --- | --- | --- |
|
||||
| `PORT` | 服务端口 | `5001` |
|
||||
@@ -350,6 +358,7 @@ go run ./cmd/ds2api
|
||||
| `DS2API_JWT_EXPIRE_HOURS` | Admin JWT 过期小时数 | `24` |
|
||||
| `DS2API_CONFIG_PATH` | 配置文件路径 | `config.json` |
|
||||
| `DS2API_CONFIG_JSON` | 直接注入配置(JSON 或 Base64) | — |
|
||||
| `DS2API_CHAT_HISTORY_PATH` | 服务器端对话记录文件路径 | `data/chat_history.json` |
|
||||
| `DS2API_ENV_WRITEBACK` | 环境变量模式下自动写回配置文件并切换文件模式(`1/true/yes/on`) | 关闭 |
|
||||
| `DS2API_STATIC_ADMIN_DIR` | 管理台静态文件目录 | `static/admin` |
|
||||
| `DS2API_AUTO_BUILD_WEBUI` | 启动时自动构建 WebUI | 本地开启,Vercel 关闭 |
|
||||
@@ -368,6 +377,15 @@ go run ./cmd/ds2api
|
||||
|
||||
> 提示:当检测到 `DS2API_CONFIG_JSON` 时,管理台会显示当前模式风险与自动持久化状态(含 `DS2API_CONFIG_PATH` 路径与模式切换说明)。
|
||||
|
||||
#### 必填 / 可选(按部署方式)
|
||||
|
||||
- **所有部署都必填**:`DS2API_ADMIN_KEY`
|
||||
- **配置来源二选一(推荐前者)**:
|
||||
- `config.json` 文件(推荐,持久化更直观)
|
||||
- `DS2API_CONFIG_JSON`(可选,适合 Vercel;支持 JSON 或 Base64)
|
||||
- **仅在环境变量配置模式建议开启**:`DS2API_ENV_WRITEBACK=1`(避免管理台改动重启后丢失)
|
||||
- 其余环境变量均为可选调优项。
|
||||
|
||||
## 鉴权模式
|
||||
|
||||
调用业务接口(`/v1/*`、`/anthropic/*`、Gemini 路由)时支持两种模式:
|
||||
@@ -398,7 +416,7 @@ Gemini 路由还可以使用 `x-goog-api-key`,或在没有认证头时使用 `
|
||||
当请求中带 `tools` 时,DS2API 会做防泄漏处理与结构化转译:
|
||||
|
||||
1. 只在**非代码块上下文**启用执行型 toolcall 识别(代码块示例默认不触发)
|
||||
2. 解析层以 XML/Markup 为最高优先级,同时兼容 JSON / ANTML / invoke / text-kv,并统一归一到内部工具调用结构
|
||||
2. 解析层当前以 XML/Markup 家族为准(`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml 变体);纯 JSON `tool_calls` 片段默认不作为可执行调用解析
|
||||
3. `responses` 流式严格使用官方 item 生命周期事件(`response.output_item.*`、`response.content_part.*`、`response.function_call_arguments.*`)
|
||||
4. `responses` 支持并执行 `tool_choice`(`auto`/`none`/`required`/强制函数);`required` 违规时非流式返回 `422`,流式返回 `response.failed`
|
||||
5. 客户端请求哪种协议,就按该协议返回工具调用(OpenAI/Claude/Gemini 各自原生结构);模型侧优先约束输出规范 XML,再由兼容层转译
|
||||
@@ -497,7 +515,7 @@ go test -v -run 'TestParseToolCalls|TestRepair' ./internal/toolcall/
|
||||
- **触发条件**:仅在 GitHub Release `published` 时触发(普通 push 不会触发)
|
||||
- **构建产物**:多平台二进制包(`linux/amd64`、`linux/arm64`、`darwin/amd64`、`darwin/arm64`、`windows/amd64`)+ `sha256sums.txt`
|
||||
- **容器镜像发布**:仅推送到 GHCR(`ghcr.io/cjackhwang/ds2api`)
|
||||
- **每个压缩包包含**:`ds2api` 可执行文件、`static/admin`、WASM 文件(同时支持内置 fallback)、配置示例、README、LICENSE
|
||||
- **每个压缩包包含**:`ds2api` 可执行文件、`static/admin`、WASM 文件(同时支持内置 fallback)、`config.example.json` 配置示例、README、LICENSE
|
||||
|
||||
## 免责声明
|
||||
|
||||
|
||||
33
README.en.md
33
README.en.md
@@ -80,29 +80,19 @@ For the full module-by-module architecture and directory responsibilities, see [
|
||||
- **Frontend**: React admin panel (`webui/`), served as static build at runtime
|
||||
- **Deployment**: local run, Docker, Vercel serverless, Linux systemd
|
||||
|
||||
### 3.X Architecture Changes (vs older releases)
|
||||
|
||||
- **Unified routing core**: all protocol entries are now centralized through `internal/server/router.go`, with OpenAI / Claude / Gemini / Admin / WebUI routes registered in one tree to avoid multi-entry drift.
|
||||
- **Unified execution chain**: Claude/Gemini entries are translated by `internal/translatorcliproxy`, then executed through `openai.ChatCompletions` for shared tool-calling and stream semantics, then translated back to the client protocol.
|
||||
- **Cleaner adapter boundaries**: `internal/adapter/{claude,gemini}` handles protocol wrappers, while `internal/adapter/openai` remains the execution core; upstream DeepSeek calls are retained only in the OpenAI core.
|
||||
- **Tool-calling parity across runtimes**: Go (`internal/toolcall`) and Vercel Node (`internal/js/helpers/stream-tool-sieve`) follow aligned parsing/anti-leak semantics across JSON / XML / invoke / text-kv inputs.
|
||||
- **Config/runtime separation**: static config (`config`) and runtime policy (`settings`) are managed independently via Admin APIs, enabling hot updates and password rotation with JWT invalidation.
|
||||
- **Streaming behavior upgrade**: `/v1/responses` and `/v1/chat/completions` now share a more consistent incremental tool-call emission strategy across SDK ecosystems.
|
||||
- **Improved operability**: `/healthz`, `/readyz`, `/admin/version`, and `/admin/dev/captures` form a tighter post-deploy diagnostics loop.
|
||||
|
||||
## Key Capabilities
|
||||
|
||||
| Capability | Details |
|
||||
| --- | --- |
|
||||
| OpenAI compatible | `GET /v1/models`, `GET /v1/models/{id}`, `POST /v1/chat/completions`, `POST /v1/responses`, `GET /v1/responses/{response_id}`, `POST /v1/embeddings` |
|
||||
| OpenAI compatible | `GET /v1/models`, `GET /v1/models/{id}`, `POST /v1/chat/completions`, `POST /v1/responses`, `GET /v1/responses/{response_id}`, `POST /v1/embeddings`, `POST /v1/files` |
|
||||
| Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` (plus shortcut paths `/v1/messages`, `/messages`) |
|
||||
| Gemini compatible | `POST /v1beta/models/{model}:generateContent`, `POST /v1beta/models/{model}:streamGenerateContent` (plus `/v1/models/{model}:*` paths) |
|
||||
| Multi-account rotation | Auto token refresh, email/mobile dual login |
|
||||
| Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency |
|
||||
| DeepSeek PoW | Pure Go high-performance solver (DeepSeekHashV1), ms-level response |
|
||||
| Tool Calling | Anti-leak handling: non-code-block feature match, early `delta.tool_calls`, structured incremental output |
|
||||
| Admin API | Config management, runtime settings hot-reload, account testing/batch test, session cleanup, import/export, Vercel sync, version check |
|
||||
| WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) |
|
||||
| Admin API | Config management, runtime settings hot-reload, proxy management, account testing/batch test, session cleanup, import/export, Vercel sync, version check |
|
||||
| WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode, with server-side conversation history) |
|
||||
| Health Probes | `GET /healthz` (liveness), `GET /readyz` (readiness) |
|
||||
|
||||
## Platform Compatibility Matrix
|
||||
@@ -135,7 +125,7 @@ For the full module-by-module architecture and directory responsibilities, see [
|
||||
| vision | `deepseek-vision-chat-search` | ❌ | ✅ |
|
||||
| vision | `deepseek-vision-reasoner-search` | ✅ | ✅ |
|
||||
|
||||
Besides native IDs, DS2API also accepts common aliases as input (for example `gpt-4o`, `gpt-5-codex`, `o3`, `claude-sonnet-4-5`, `gemini-2.5-pro`), but `/v1/models` returns normalized DeepSeek native model IDs.
|
||||
Besides native IDs, DS2API also accepts common aliases as input (for example `gpt-5`, `gpt-5-mini`, `gpt-5-codex`, `gpt-4.1`, `o3`, `claude-opus-4-6`, `claude-sonnet-4-5`, `gemini-2.5-pro`, `gemini-2.5-flash`), but `/v1/models` returns normalized DeepSeek native model IDs.
|
||||
|
||||
### Claude Endpoint (`GET /anthropic/v1/models`)
|
||||
|
||||
@@ -153,7 +143,7 @@ Besides the current primary aliases above, `/anthropic/v1/models` also returns C
|
||||
- Set `ANTHROPIC_BASE_URL` to the DS2API root URL (for example `http://127.0.0.1:5001`). Claude Code sends requests to `/v1/messages?beta=true`.
|
||||
- `ANTHROPIC_API_KEY` must match an entry in `keys` from `config.json`. Keeping both a regular key and an `sk-ant-*` style key improves client compatibility.
|
||||
- If your environment has proxy variables, set `NO_PROXY=127.0.0.1,localhost,<your_host_ip>` for DS2API to avoid proxy interception of local traffic.
|
||||
- If tool calls are rendered as plain text and not executed, upgrade to a build that includes multi-format Claude tool-call parsing (JSON/XML/ANTML/invoke).
|
||||
- If tool calls are rendered as plain text and not executed, first verify the model output uses supported XML/Markup tool blocks (`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use`) rather than standalone JSON `tool_calls`.
|
||||
|
||||
### Gemini Endpoint
|
||||
|
||||
@@ -183,6 +173,8 @@ Recommended per deployment mode:
|
||||
- Local run: read `config.json` directly
|
||||
- Docker / Vercel: generate Base64 from `config.json` and inject as `DS2API_CONFIG_JSON`, or paste raw JSON directly
|
||||
|
||||
The WebUI admin panel’s “Full configuration template” is loaded from the same `config.example.json`, so updating that file keeps the frontend template in sync.
|
||||
|
||||
### Option 1: Download Release Binaries
|
||||
|
||||
GitHub Actions automatically builds multi-platform archives on each Release:
|
||||
@@ -291,8 +283,12 @@ The server actually binds to `0.0.0.0:5001`, so devices on the same LAN can usua
|
||||
],
|
||||
"model_aliases": {
|
||||
"gpt-4o": "deepseek-chat",
|
||||
"gpt-5": "deepseek-chat",
|
||||
"gpt-5-mini": "deepseek-chat",
|
||||
"gpt-5-codex": "deepseek-reasoner",
|
||||
"o3": "deepseek-reasoner"
|
||||
"o3": "deepseek-reasoner",
|
||||
"claude-opus-4-6": "deepseek-reasoner",
|
||||
"gemini-2.5-flash": "deepseek-chat"
|
||||
},
|
||||
"compat": {
|
||||
"wide_input_strict_output": true,
|
||||
@@ -348,6 +344,7 @@ The server actually binds to `0.0.0.0:5001`, so devices on the same LAN can usua
|
||||
| `DS2API_JWT_EXPIRE_HOURS` | Admin JWT TTL in hours | `24` |
|
||||
| `DS2API_CONFIG_PATH` | Config file path | `config.json` |
|
||||
| `DS2API_CONFIG_JSON` | Inline config (JSON or Base64) | — |
|
||||
| `DS2API_CHAT_HISTORY_PATH` | Server-side conversation history file path | `data/chat_history.json` |
|
||||
| `DS2API_ENV_WRITEBACK` | Auto-write env-backed config to file and transition to file mode (`1/true/yes/on`) | Disabled |
|
||||
| `DS2API_STATIC_ADMIN_DIR` | Admin static assets dir | `static/admin` |
|
||||
| `DS2API_AUTO_BUILD_WEBUI` | Auto-build WebUI on startup | Enabled locally, disabled on Vercel |
|
||||
@@ -396,7 +393,7 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
|
||||
When `tools` is present in the request, DS2API performs anti-leak handling:
|
||||
|
||||
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
|
||||
2. The parser prioritizes XML/Markup, while also accepting JSON / ANTML / invoke / text-kv, and normalizes everything into the internal tool-call structure
|
||||
2. The parser currently targets XML/Markup-family tool syntax (`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml variants); standalone JSON `tool_calls` payloads are not treated as executable calls by default
|
||||
3. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`)
|
||||
4. `responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream
|
||||
5. The output protocol follows the client request (OpenAI / Claude / Gemini native shapes); model-side prompting can prefer XML, and the compatibility layer handles the protocol-specific translation
|
||||
@@ -476,7 +473,7 @@ Workflow: `.github/workflows/release-artifacts.yml`
|
||||
- **Trigger**: only on GitHub Release `published` (normal pushes do not trigger builds)
|
||||
- **Outputs**: multi-platform archives (`linux/amd64`, `linux/arm64`, `darwin/amd64`, `darwin/arm64`, `windows/amd64`) + `sha256sums.txt`
|
||||
- **Container publishing**: GHCR only (`ghcr.io/cjackhwang/ds2api`)
|
||||
- **Each archive includes**: `ds2api` executable, `static/admin`, WASM file (with embedded fallback support), config template, README, LICENSE
|
||||
- **Each archive includes**: `ds2api` executable, `static/admin`, WASM file (with embedded fallback support), `config.example.json`-based config template, README, LICENSE
|
||||
|
||||
## Disclaimer
|
||||
|
||||
|
||||
@@ -5,14 +5,29 @@
|
||||
"your-api-key-1",
|
||||
"your-api-key-2"
|
||||
],
|
||||
"api_keys": [
|
||||
{
|
||||
"key": "your-api-key-1",
|
||||
"name": "主 API Key",
|
||||
"remark": "给 OpenAI 客户端使用"
|
||||
},
|
||||
{
|
||||
"key": "your-api-key-2",
|
||||
"name": "备用 API Key",
|
||||
"remark": "压测或临时调试"
|
||||
}
|
||||
],
|
||||
"accounts": [
|
||||
{
|
||||
"_comment": "邮箱登录方式",
|
||||
"name": "主账号",
|
||||
"remark": "优先用于生产流量",
|
||||
"email": "example1@example.com",
|
||||
"password": "your-password-1"
|
||||
},
|
||||
{
|
||||
"_comment": "邮箱登录方式 - 账号2",
|
||||
"name": "备用账号",
|
||||
"email": "example2@example.com",
|
||||
"password": "your-password-2"
|
||||
},
|
||||
@@ -34,6 +49,10 @@
|
||||
"responses": {
|
||||
"store_ttl_seconds": 900
|
||||
},
|
||||
"history_split": {
|
||||
"enabled": true,
|
||||
"trigger_after_turns": 1
|
||||
},
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
|
||||
@@ -116,7 +116,7 @@ flowchart LR
|
||||
- `internal/translatorcliproxy`: structure translation between Claude/Gemini and OpenAI.
|
||||
- `internal/deepseek`: upstream request/session/PoW/SSE handling.
|
||||
- `internal/stream` + `internal/sse`: stream parsing and incremental assembly.
|
||||
- `internal/toolcall`: JSON/XML/invoke/text-kv tool-call parsing + anti-leak sieve.
|
||||
- `internal/toolcall`: XML/Markup-family tool-call parsing + anti-leak sieve (`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml variants).
|
||||
- `internal/admin`: config/accounts/vercel sync/version/dev-capture endpoints.
|
||||
- `internal/config`: config loading/validation + runtime settings hot-reload.
|
||||
- `internal/account`: managed account pool, inflight slots, waiting queue.
|
||||
|
||||
@@ -116,7 +116,7 @@ flowchart LR
|
||||
- `internal/translatorcliproxy`:Claude/Gemini 与 OpenAI 结构互转。
|
||||
- `internal/deepseek`:上游请求、会话、PoW、SSE 消费。
|
||||
- `internal/stream` + `internal/sse`:流式解析与增量处理。
|
||||
- `internal/toolcall`:JSON/XML/invoke/text-kv 工具调用解析及防泄漏筛分。
|
||||
- `internal/toolcall`:以 XML/Markup 家族为核心的工具调用解析与防泄漏筛分(`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml 变体)。
|
||||
- `internal/admin`:配置管理、账号管理、Vercel 同步、版本检查、开发抓包。
|
||||
- `internal/config`:配置加载、校验、运行时 settings 热更新。
|
||||
- `internal/account`:托管账号池、并发槽位、等待队列。
|
||||
|
||||
@@ -258,12 +258,22 @@ VERCEL_TEAM_ID=team_xxxxxxxxxxxx # 个人账号可留空
|
||||
| `DS2API_GLOBAL_MAX_INFLIGHT` | 全局并发上限 | `recommended_concurrency` |
|
||||
| `DS2API_ENV_WRITEBACK` | 检测到 `DS2API_CONFIG_JSON` 时自动写入 `DS2API_CONFIG_PATH`,并在成功后转为文件模式(`1/true/yes/on`) | 关闭 |
|
||||
| `DS2API_VERCEL_INTERNAL_SECRET` | 混合流式内部鉴权 | 回退用 `DS2API_ADMIN_KEY` |
|
||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease TTL | `900` |
|
||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease TTL | 默认与 `responses.store_ttl_seconds` 同步,若未设置则为 `900` |
|
||||
| `VERCEL_TOKEN` | Vercel 同步 token | — |
|
||||
| `VERCEL_PROJECT_ID` | Vercel 项目 ID | — |
|
||||
| `VERCEL_TEAM_ID` | Vercel 团队 ID | — |
|
||||
| `DS2API_VERCEL_PROTECTION_BYPASS` | 部署保护绕过密钥(内部 Node→Go 调用) | — |
|
||||
|
||||
### 3.3 运行时行为配置(通过 Admin API 设置)
|
||||
|
||||
部分运行时行为无法通过环境变量直接配置,需要在部署后通过 Admin API 设置,例如:
|
||||
|
||||
- **自动删除会话模式** (`auto_delete.mode`):支持 `none` / `single` / `all`,默认为 `none`。可通过 `PUT /admin/settings` 更新。
|
||||
- **每账号并发上限** (`account_max_inflight`):环境变量已支持,但也可通过 Admin API 热更新。
|
||||
- **全局并发上限** (`global_max_inflight`):同上。
|
||||
|
||||
详细说明参见 [API.md](../API.md#admin-接口) 中 `/admin/settings` 部分。
|
||||
|
||||
### 3.3 Vercel 架构说明
|
||||
|
||||
```text
|
||||
|
||||
@@ -1,74 +1,74 @@
|
||||
# Tool call parsing semantics(Go/Node 统一语义)
|
||||
|
||||
本文档描述当前代码中 `ParseToolCallsDetailed` / `parseToolCallsDetailed` 的**实际行为**,用于对齐 Go 与 Node Runtime。
|
||||
本文档描述当前代码中工具调用解析链路的**实际行为**(以 `internal/toolcall` 与 `internal/js/helpers/stream-tool-sieve` 为准)。
|
||||
|
||||
文档导航:[总览](../README.MD) / [架构说明](./ARCHITECTURE.md) / [测试指南](./TESTING.md)
|
||||
|
||||
## 1) 输出结构(当前实现)
|
||||
## 1) 当前输出结构
|
||||
|
||||
- `calls`:解析得到的工具调用列表(`name` + `input`)。
|
||||
- `sawToolCallSyntax`:检测到工具调用语法特征时为 `true`(例如 `tool_calls`、`<tool_call>`、`<function_call>`、`<invoke>`、`function.name:`)。
|
||||
- `rejectedByPolicy`:当前实现固定为 `false`(预留字段,尚未启用 allow-list 拒绝)。
|
||||
`ParseToolCallsDetailed` / `parseToolCallsDetailed` 返回:
|
||||
|
||||
- `calls`:解析出的工具调用列表(`name` + `input`)。
|
||||
- `sawToolCallSyntax`:检测到工具调用语法特征时为 `true`。
|
||||
- `rejectedByPolicy`:当前实现固定为 `false`(预留字段)。
|
||||
- `rejectedToolNames`:当前实现固定为空数组(预留字段)。
|
||||
|
||||
> 说明:`filterToolCallsDetailed` 当前仅做结构清洗,不做工具名策略拒绝。
|
||||
> 当前 `filterToolCallsDetailed` 仅做结构清洗,不做 allow-list 工具名硬拒绝。
|
||||
|
||||
## 2) 解析管线
|
||||
## 2) 解析范围(重点)
|
||||
|
||||
1. **示例保护**:若判定为 fenced code block 示例上下文,则跳过执行型解析。
|
||||
2. **候选片段构建**:从完整文本中构建候选(原文、围绕 `tool_calls` 的 JSON 片段、首尾大括号切片等)。
|
||||
3. **按序尝试解析(命中即停)**:
|
||||
- 对“明显 JSON 工具载荷候选”(以 `{`/`[` 开头且包含 `tool_calls`/`\"function\"`)先走 JSON 解析,避免 JSON 字符串内偶发 XML 片段误命中;
|
||||
- 其余候选优先 XML 解析(`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / `antml:function_call` 等);
|
||||
- JSON 解析(`{"tool_calls": [...]}`、列表、单对象);
|
||||
- Markup 解析;
|
||||
- Text-KV 回退(如 `function.name:` + `function.arguments:`)。
|
||||
4. **兜底**:候选全部失败后,再对全文做 XML / Text-KV 回退。
|
||||
当前版本的可执行解析以 **XML/Markup 家族**为主:
|
||||
|
||||
## 3) XML 能力边界(当前)
|
||||
- `<tool_call>...</tool_call>`
|
||||
- `<function_call>...</function_call>`
|
||||
- `<invoke ...>...</invoke>`(含自闭合)
|
||||
- `<tool_use>...</tool_use>`
|
||||
- antml 变体(如 `antml:function_call` / `antml:argument`)
|
||||
|
||||
当前已支持输入端的“多 XML/标记风格”解析,包括但不限于:
|
||||
并支持在这些标记块内部解析:
|
||||
|
||||
- `<tool_call><tool_name>...</tool_name><parameters>...</parameters></tool_call>`
|
||||
- `<function_call>tool</function_call><function parameter name="x">...</function parameter>`
|
||||
- `<invoke name="tool"><parameter name="x">...</parameter></invoke>`
|
||||
- `antml:function_call` / `antml:argument` / `antml:parameters`
|
||||
- `tool_use` 家族标签
|
||||
- JSON 参数字符串
|
||||
- 标签参数(`<parameter name="...">...`)
|
||||
- key/value 风格子标签
|
||||
|
||||
但**输出端仍统一转换为 OpenAI 兼容 JSON 事件/对象**(`message.tool_calls`、`delta.tool_calls`、`response.function_call_arguments.*`)。
|
||||
## 3) 不应再假设的行为
|
||||
|
||||
## 4) 关于“是否可以封装成 XML 再喂给模型”
|
||||
以下说法在当前实现中已不成立:
|
||||
|
||||
结论:**可以做,而且当前解析器已经能兼容 XML 作为输入格式之一**,但代码里并没有 `toolcall.prefer_xml_output` 这个开关。现有可调配置只有:
|
||||
1. “纯 JSON `tool_calls` 片段会被直接当作可执行工具调用解析”。
|
||||
2. “存在 `toolcall.mode` / `toolcall.early_emit_confidence` 等可配置开关可以改变解析策略”。
|
||||
|
||||
- `toolcall.mode`:`feature_match` / `off`
|
||||
- `toolcall.early_emit_confidence`:`high` / `low` / `off`
|
||||
当前策略在代码中固定为:
|
||||
|
||||
推荐思路仍然是“输入兼容层 + 输出按客户端协议渲染”:
|
||||
- 特征匹配开启(feature-match on)
|
||||
- 高置信度早发开启(early emit on)
|
||||
- policy 拒绝字段保留但未启用
|
||||
|
||||
1. **Prompt 约束层**:如果你要尝试 XML-first,可以在系统提示词里约束模型输出规范 XML tool block(例如 `<tool_calls><tool_call>...</tool_call></tool_calls>`)。
|
||||
2. **解析兼容层**:继续在 parser 中同时接受 JSON / XML / ANTML / invoke / text-kv。
|
||||
3. **协议归一层**:无论模型输出什么格式,统一落到内部 `ParsedToolCall`。
|
||||
4. **对外渲染层**:根据客户端请求协议渲染(OpenAI / Claude / Gemini 各自格式)。
|
||||
## 4) 流式与防泄漏语义
|
||||
|
||||
这样可以同时获得:
|
||||
在流式链路中(OpenAI / Claude / Gemini 统一内核):
|
||||
|
||||
- 减少模型端 JSON 转义/引号错误;
|
||||
- 不破坏现有 SDK / 客户端生态;
|
||||
- 逐步灰度(按模型、按租户、按请求开关)。
|
||||
- 工具调用片段会被优先提取为结构化增量输出;
|
||||
- 已识别的工具调用原始片段不会作为普通文本再次回流;
|
||||
- fenced code block 中的示例内容按文本处理,不作为可执行工具调用。
|
||||
|
||||
## 5) 落地建议(低风险迭代)
|
||||
## 5) 落地建议(按当前实现)
|
||||
|
||||
- 继续使用现有的 `toolcall.mode=feature_match` 和 `toolcall.early_emit_confidence=high` 作为默认策略。
|
||||
- 如果要试 XML-first,把它放在 prompt 层或上游模板层,不要假设代码里已有专门的 XML 输出开关。
|
||||
- 增加观测指标:
|
||||
- `toolcall_parse_source`(json/xml/markup/textkv);
|
||||
- `toolcall_parse_success_rate`;
|
||||
- `toolcall_malformed_rate`;
|
||||
- `toolcall_repair_rate`。
|
||||
- 先在 `responses` 链路灰度,再扩展 `chat.completions`。
|
||||
1. Prompt 里优先约束模型输出 XML/Markup 工具块。
|
||||
2. 执行器侧继续做工具名白名单与参数 schema 校验(不要依赖 parser 代替安全策略)。
|
||||
3. 需要兼容历史“纯 JSON tool_calls”模型输出时,请在上游模板层把输出规范化为 XML/Markup 风格再进入 DS2API。
|
||||
|
||||
## 6) 兼容性提醒
|
||||
## 6) 回归验证建议
|
||||
|
||||
- 上游模型若输出混合文本 + XML,仍可能出现“半结构化”噪声,需要依赖现有 sieve 增量消费策略。
|
||||
- XML 不等于安全:仍需做 tool 名、参数 schema、执行权限的服务端校验。
|
||||
可直接运行:
|
||||
|
||||
```bash
|
||||
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/toolcall/
|
||||
node --test tests/node/stream-tool-sieve.test.js
|
||||
```
|
||||
|
||||
重点覆盖:
|
||||
|
||||
- `<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml 变体
|
||||
- 参数 JSON 修复与解析
|
||||
- 流式增量下的工具调用提取与文本防泄漏
|
||||
|
||||
@@ -138,77 +138,6 @@ func TestHandleClaudeStreamRealtimeThinkingDelta(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeToolSafety(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
|
||||
`data: {"p":"response/content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"search"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
for _, f := range findClaudeFrames(frames, "content_block_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["type"] == "text_delta" && strings.Contains(asString(delta["text"]), `"tool_calls"`) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in text delta: body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
foundToolUse := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" {
|
||||
foundToolUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundToolUse {
|
||||
t.Fatalf("expected tool_use block in stream, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
foundToolUseStop := false
|
||||
for _, f := range findClaudeFrames(frames, "message_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["stop_reason"] == "tool_use" {
|
||||
foundToolUseStop = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundToolUseStop {
|
||||
t.Fatalf("expected stop_reason=tool_use, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeToolDetectionFromThinkingFallback(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
|
||||
`data: {"p":"response/thinking_content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, true, false, []string{"search"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" && contentBlock["name"] == "search" {
|
||||
foundToolUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundToolUse {
|
||||
t.Fatalf("expected tool_use block from thinking fallback, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeSkipsThinkingFallbackWhenFinalTextExists(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
|
||||
@@ -96,7 +96,7 @@ func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) {
|
||||
if !containsStr(content, "<tool_calls>") || !containsStr(content, "<tool_name>search_web</tool_name>") {
|
||||
t.Fatalf("expected assistant content to include XML tool call history, got %q", content)
|
||||
}
|
||||
if !containsStr(content, `<parameters>{"query":"latest"}</parameters>`) {
|
||||
if !containsStr(content, "<parameters>\n <query><![CDATA[latest]]></query>\n </parameters>") {
|
||||
t.Fatalf("expected assistant content to include serialized parameters, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
250
internal/adapter/openai/chat_history.go
Normal file
250
internal/adapter/openai/chat_history.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/prompt"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
const adminWebUISourceHeader = "X-Ds2-Source"
|
||||
const adminWebUISourceValue = "admin-webui-api-tester"
|
||||
|
||||
type chatHistorySession struct {
|
||||
store *chathistory.Store
|
||||
entryID string
|
||||
startedAt time.Time
|
||||
lastPersist time.Time
|
||||
finalPrompt string
|
||||
startParams chathistory.StartParams
|
||||
disabled bool
|
||||
}
|
||||
|
||||
func startChatHistory(store *chathistory.Store, r *http.Request, a *auth.RequestAuth, stdReq util.StandardRequest) *chatHistorySession {
|
||||
if store == nil || r == nil || a == nil {
|
||||
return nil
|
||||
}
|
||||
if !store.Enabled() {
|
||||
return nil
|
||||
}
|
||||
if !shouldCaptureChatHistory(r) {
|
||||
return nil
|
||||
}
|
||||
entry, err := store.Start(chathistory.StartParams{
|
||||
CallerID: strings.TrimSpace(a.CallerID),
|
||||
AccountID: strings.TrimSpace(a.AccountID),
|
||||
Model: strings.TrimSpace(stdReq.ResponseModel),
|
||||
Stream: stdReq.Stream,
|
||||
UserInput: extractSingleUserInput(stdReq.Messages),
|
||||
Messages: extractAllMessages(stdReq.Messages),
|
||||
FinalPrompt: stdReq.FinalPrompt,
|
||||
})
|
||||
startParams := chathistory.StartParams{
|
||||
CallerID: strings.TrimSpace(a.CallerID),
|
||||
AccountID: strings.TrimSpace(a.AccountID),
|
||||
Model: strings.TrimSpace(stdReq.ResponseModel),
|
||||
Stream: stdReq.Stream,
|
||||
UserInput: extractSingleUserInput(stdReq.Messages),
|
||||
Messages: extractAllMessages(stdReq.Messages),
|
||||
FinalPrompt: stdReq.FinalPrompt,
|
||||
}
|
||||
session := &chatHistorySession{
|
||||
store: store,
|
||||
entryID: entry.ID,
|
||||
startedAt: time.Now(),
|
||||
lastPersist: time.Now(),
|
||||
finalPrompt: stdReq.FinalPrompt,
|
||||
startParams: startParams,
|
||||
}
|
||||
if err != nil {
|
||||
if entry.ID == "" {
|
||||
config.Logger.Warn("[chat_history] start failed", "error", err)
|
||||
return nil
|
||||
}
|
||||
config.Logger.Warn("[chat_history] start persisted in memory after write failure", "error", err)
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
func shouldCaptureChatHistory(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
if isVercelStreamPrepareRequest(r) || isVercelStreamReleaseRequest(r) {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(r.Header.Get(adminWebUISourceHeader)) != adminWebUISourceValue
|
||||
}
|
||||
|
||||
func extractSingleUserInput(messages []any) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
msg, ok := messages[i].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
if role != "user" {
|
||||
continue
|
||||
}
|
||||
if normalized := strings.TrimSpace(prompt.NormalizeContent(msg["content"])); normalized != "" {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractAllMessages(messages []any) []chathistory.Message {
|
||||
out := make([]chathistory.Message, 0, len(messages))
|
||||
for _, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
content := strings.TrimSpace(prompt.NormalizeContent(msg["content"]))
|
||||
if role == "" || content == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, chathistory.Message{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) progress(thinking, content string) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
if now.Sub(s.lastPersist) < 250*time.Millisecond {
|
||||
return
|
||||
}
|
||||
s.lastPersist = now
|
||||
s.persistUpdate(chathistory.UpdateParams{
|
||||
Status: "streaming",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
StatusCode: http.StatusOK,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) success(statusCode int, thinking, content, finishReason string, usage map[string]any) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
s.persistUpdate(chathistory.UpdateParams{
|
||||
Status: "success",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
StatusCode: statusCode,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
Completed: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) error(statusCode int, message, finishReason, thinking, content string) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
s.persistUpdate(chathistory.UpdateParams{
|
||||
Status: "error",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
Error: message,
|
||||
StatusCode: statusCode,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
FinishReason: finishReason,
|
||||
Completed: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) stopped(thinking, content, finishReason string) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
s.persistUpdate(chathistory.UpdateParams{
|
||||
Status: "stopped",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
StatusCode: http.StatusOK,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
FinishReason: finishReason,
|
||||
Usage: openaifmt.BuildChatUsage(s.finalPrompt, thinking, content),
|
||||
Completed: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) retryMissingEntry() bool {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return false
|
||||
}
|
||||
entry, err := s.store.Start(s.startParams)
|
||||
if errors.Is(err, chathistory.ErrDisabled) {
|
||||
s.disabled = true
|
||||
return false
|
||||
}
|
||||
if entry.ID == "" {
|
||||
if err != nil {
|
||||
config.Logger.Warn("[chat_history] recreate missing entry failed", "error", err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
s.entryID = entry.ID
|
||||
if err != nil {
|
||||
config.Logger.Warn("[chat_history] recreate missing entry persisted in memory after write failure", "error", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) persistUpdate(params chathistory.UpdateParams) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
if _, err := s.store.Update(s.entryID, params); err != nil {
|
||||
s.handlePersistError(params, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) handlePersistError(params chathistory.UpdateParams, err error) {
|
||||
if err == nil || s == nil {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, chathistory.ErrDisabled) {
|
||||
s.disabled = true
|
||||
return
|
||||
}
|
||||
if isChatHistoryMissingError(err) {
|
||||
if s.retryMissingEntry() {
|
||||
if _, retryErr := s.store.Update(s.entryID, params); retryErr != nil {
|
||||
if errors.Is(retryErr, chathistory.ErrDisabled) || isChatHistoryMissingError(retryErr) {
|
||||
s.disabled = true
|
||||
return
|
||||
}
|
||||
config.Logger.Warn("[chat_history] retry after missing entry failed", "error", retryErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
s.disabled = true
|
||||
return
|
||||
}
|
||||
config.Logger.Warn("[chat_history] update failed", "error", err)
|
||||
}
|
||||
|
||||
func isChatHistoryMissingError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(err.Error()), "not found")
|
||||
}
|
||||
273
internal/adapter/openai/chat_history_test.go
Normal file
273
internal/adapter/openai/chat_history_test.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func newTestChatHistoryStore(t *testing.T) *chathistory.Store {
|
||||
t.Helper()
|
||||
store := chathistory.New(filepath.Join(t.TempDir(), "chat_history.json"))
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("chat history store unavailable: %v", err)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func blockChatHistoryDetailDir(t *testing.T, detailDir string) func() {
|
||||
t.Helper()
|
||||
blockedDir := detailDir + ".blocked"
|
||||
if err := os.RemoveAll(blockedDir); err != nil {
|
||||
t.Fatalf("remove blocked detail dir failed: %v", err)
|
||||
}
|
||||
if err := os.Rename(detailDir, blockedDir); err != nil {
|
||||
t.Fatalf("move detail dir aside failed: %v", err)
|
||||
}
|
||||
if err := os.RemoveAll(detailDir); err != nil {
|
||||
t.Fatalf("remove blocked detail path failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(detailDir, []byte("blocked"), 0o644); err != nil {
|
||||
t.Fatalf("write blocked detail path failed: %v", err)
|
||||
}
|
||||
var once sync.Once
|
||||
return func() {
|
||||
t.Helper()
|
||||
once.Do(func() {
|
||||
if err := os.RemoveAll(detailDir); err != nil {
|
||||
t.Fatalf("remove blocking detail path failed: %v", err)
|
||||
}
|
||||
if err := os.Rename(blockedDir, detailDir); err != nil {
|
||||
t.Fatalf("restore detail dir failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsNonStreamPersistsHistory(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello world"}`, `data: [DONE]`)},
|
||||
ChatHistory: historyStore,
|
||||
}
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"system","content":"be precise"},{"role":"user","content":"hi there"},{"role":"assistant","content":"previous answer"}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one history item, got %d", len(snapshot.Items))
|
||||
}
|
||||
item := snapshot.Items[0]
|
||||
if item.Status != "success" || item.UserInput != "hi there" {
|
||||
t.Fatalf("unexpected persisted history summary: %#v", item)
|
||||
}
|
||||
full, err := historyStore.Get(item.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("expected detail item, got %v", err)
|
||||
}
|
||||
if full.Content != "hello world" {
|
||||
t.Fatalf("expected detail content persisted, got %#v", full)
|
||||
}
|
||||
if len(full.Messages) != 3 {
|
||||
t.Fatalf("expected all request messages persisted, got %#v", full.Messages)
|
||||
}
|
||||
if full.FinalPrompt == "" {
|
||||
t.Fatalf("expected final prompt to be persisted")
|
||||
}
|
||||
if item.CallerID != "caller:test" {
|
||||
t.Fatalf("expected caller hash persisted in summary, got %#v", item.CallerID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartChatHistoryRecoversFromTransientWriteFailure(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
restore := blockChatHistoryDetailDir(t, historyStore.DetailDir())
|
||||
t.Cleanup(restore)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
a := &auth.RequestAuth{
|
||||
CallerID: "caller:test",
|
||||
AccountID: "acct:test",
|
||||
}
|
||||
stdReq := util.StandardRequest{
|
||||
ResponseModel: "deepseek-chat",
|
||||
Stream: true,
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
FinalPrompt: "hello",
|
||||
}
|
||||
|
||||
session := startChatHistory(historyStore, req, a, stdReq)
|
||||
if session == nil {
|
||||
t.Fatalf("expected session even when initial persistence fails")
|
||||
}
|
||||
if session.disabled {
|
||||
t.Fatalf("expected session to remain active after transient start failure")
|
||||
}
|
||||
if session.entryID == "" {
|
||||
t.Fatalf("expected session entry id to be retained")
|
||||
}
|
||||
if err := historyStore.Err(); err != nil {
|
||||
t.Fatalf("transient start failure should not latch store error: %v", err)
|
||||
}
|
||||
|
||||
session.lastPersist = time.Now().Add(-time.Second)
|
||||
session.progress("thinking", "partial")
|
||||
if session.disabled {
|
||||
t.Fatalf("expected session to remain active after transient update failure")
|
||||
}
|
||||
if session.entryID == "" {
|
||||
t.Fatalf("expected session entry id to remain set after update failure")
|
||||
}
|
||||
if err := historyStore.Err(); err != nil {
|
||||
t.Fatalf("transient update failure should not latch store error: %v", err)
|
||||
}
|
||||
|
||||
restore()
|
||||
|
||||
session.success(http.StatusOK, "thinking", "final answer", "stop", map[string]any{"total_tokens": 7})
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed after restore: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one persisted item after restore, got %#v", snapshot.Items)
|
||||
}
|
||||
full, err := historyStore.Get(session.entryID)
|
||||
if err != nil {
|
||||
t.Fatalf("get restored entry failed: %v", err)
|
||||
}
|
||||
if full.Status != "success" || full.Content != "final answer" {
|
||||
t.Fatalf("expected restored entry to persist final success, got %#v", full)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamContextCancelledMarksHistoryStopped(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
entry, err := historyStore.Start(chathistory.StartParams{
|
||||
CallerID: "caller:test",
|
||||
Model: "deepseek-chat",
|
||||
Stream: true,
|
||||
UserInput: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start history failed: %v", err)
|
||||
}
|
||||
session := &chatHistorySession{
|
||||
store: historyStore,
|
||||
entryID: entry.ID,
|
||||
startedAt: time.Now(),
|
||||
lastPersist: time.Now(),
|
||||
finalPrompt: "hello",
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil).WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
resp := makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, `data: [DONE]`)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-stop", "deepseek-chat", "prompt", false, false, nil, session)
|
||||
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one history item, got %d", len(snapshot.Items))
|
||||
}
|
||||
full, err := historyStore.Get(snapshot.Items[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get detail failed: %v", err)
|
||||
}
|
||||
if full.Status != "stopped" {
|
||||
t.Fatalf("expected stopped status, got %#v", full)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsSkipsAdminWebUISource(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello world"}`, `data: [DONE]`)},
|
||||
ChatHistory: historyStore,
|
||||
}
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi there"}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set(adminWebUISourceHeader, adminWebUISourceValue)
|
||||
rec := httptest.NewRecorder()
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 0 {
|
||||
t.Fatalf("expected admin webui source to be skipped, got %#v", snapshot.Items)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsSkipsHistoryWhenDisabled(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
if _, err := historyStore.SetLimit(chathistory.DisabledLimit); err != nil {
|
||||
t.Fatalf("disable history store failed: %v", err)
|
||||
}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello world"}`, `data: [DONE]`)},
|
||||
ChatHistory: historyStore,
|
||||
}
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi there"}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 0 {
|
||||
t.Fatalf("expected disabled history to stay empty, got %#v", snapshot.Items)
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,14 @@ type chatStreamRuntime struct {
|
||||
streamToolNames map[int]string
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
|
||||
finalThinking string
|
||||
finalText string
|
||||
finalFinishReason string
|
||||
finalUsage map[string]any
|
||||
finalErrorStatus int
|
||||
finalErrorMessage string
|
||||
finalErrorCode string
|
||||
}
|
||||
|
||||
func newChatStreamRuntime(
|
||||
@@ -99,6 +107,9 @@ func (s *chatStreamRuntime) sendDone() {
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) sendFailedChunk(status int, message, code string) {
|
||||
s.finalErrorStatus = status
|
||||
s.finalErrorMessage = message
|
||||
s.finalErrorCode = code
|
||||
s.sendChunk(map[string]any{
|
||||
"status_code": status,
|
||||
"error": map[string]any{
|
||||
@@ -111,9 +122,16 @@ func (s *chatStreamRuntime) sendFailedChunk(status int, message, code string) {
|
||||
s.sendDone()
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) resetStreamToolCallState() {
|
||||
s.streamToolCallIDs = map[int]string{}
|
||||
s.streamToolNames = map[int]string{}
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
|
||||
s.finalThinking = finalThinking
|
||||
s.finalText = finalText
|
||||
detected := toolcall.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
|
||||
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
|
||||
finishReason = "tool_calls"
|
||||
@@ -153,6 +171,7 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)},
|
||||
nil,
|
||||
))
|
||||
s.resetStreamToolCallState()
|
||||
}
|
||||
if evt.Content == "" {
|
||||
continue
|
||||
@@ -197,6 +216,8 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
return
|
||||
}
|
||||
usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText)
|
||||
s.finalFinishReason = finishReason
|
||||
s.finalUsage = usage
|
||||
s.sendChunk(openaifmt.BuildChatStreamChunk(
|
||||
s.completionID,
|
||||
s.created,
|
||||
@@ -294,6 +315,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
s.firstChunkSent = true
|
||||
}
|
||||
newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta))
|
||||
s.resetStreamToolCallState()
|
||||
continue
|
||||
}
|
||||
if evt.Content != "" {
|
||||
|
||||
31
internal/adapter/openai/citation_links.go
Normal file
31
internal/adapter/openai/citation_links.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var citationMarkerPattern = regexp.MustCompile(`(?i)\[citation:\s*(\d+)\]`)
|
||||
|
||||
func replaceCitationMarkersWithLinks(text string, links map[int]string) string {
|
||||
if strings.TrimSpace(text) == "" || len(links) == 0 {
|
||||
return text
|
||||
}
|
||||
return citationMarkerPattern.ReplaceAllStringFunc(text, func(match string) string {
|
||||
sub := citationMarkerPattern.FindStringSubmatch(match)
|
||||
if len(sub) < 2 {
|
||||
return match
|
||||
}
|
||||
idx, err := strconv.Atoi(strings.TrimSpace(sub[1]))
|
||||
if err != nil || idx <= 0 {
|
||||
return match
|
||||
}
|
||||
url := strings.TrimSpace(links[idx])
|
||||
if url == "" {
|
||||
return match
|
||||
}
|
||||
return fmt.Sprintf("[%d](%s)", idx, url)
|
||||
})
|
||||
}
|
||||
28
internal/adapter/openai/citation_links_test.go
Normal file
28
internal/adapter/openai/citation_links_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestReplaceCitationMarkersWithLinks(t *testing.T) {
|
||||
raw := "这是一条更新[citation:1],更多信息见[citation:2]。"
|
||||
links := map[int]string{
|
||||
1: "https://example.com/news-1",
|
||||
2: "https://example.com/news-2",
|
||||
}
|
||||
|
||||
got := replaceCitationMarkersWithLinks(raw, links)
|
||||
want := "这是一条更新[1](https://example.com/news-1),更多信息见[2](https://example.com/news-2)。"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceCitationMarkersWithLinksKeepsUnknownIndex(t *testing.T) {
|
||||
raw := "只有一个来源[citation:1],未知来源[citation:3]。"
|
||||
links := map[int]string{1: "https://example.com/a"}
|
||||
|
||||
got := replaceCitationMarkersWithLinks(raw, links)
|
||||
want := "只有一个来源[1](https://example.com/a),未知来源[citation:3]。"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
@@ -34,6 +34,8 @@ type ConfigReader interface {
|
||||
EmbeddingsProvider() string
|
||||
AutoDeleteMode() string
|
||||
AutoDeleteSessions() bool
|
||||
HistorySplitEnabled() bool
|
||||
HistorySplitTriggerAfterTurns() int
|
||||
}
|
||||
|
||||
var _ AuthResolver = (*auth.Resolver)(nil)
|
||||
|
||||
@@ -3,13 +3,15 @@ package openai
|
||||
import "testing"
|
||||
|
||||
type mockOpenAIConfig struct {
|
||||
aliases map[string]string
|
||||
wideInput bool
|
||||
autoDeleteMode string
|
||||
toolMode string
|
||||
earlyEmit string
|
||||
responsesTTL int
|
||||
embedProv string
|
||||
aliases map[string]string
|
||||
wideInput bool
|
||||
autoDeleteMode string
|
||||
toolMode string
|
||||
earlyEmit string
|
||||
responsesTTL int
|
||||
embedProv string
|
||||
historySplitEnabled bool
|
||||
historySplitTurns int
|
||||
}
|
||||
|
||||
func (m mockOpenAIConfig) ModelAliases() map[string]string { return m.aliases }
|
||||
@@ -27,7 +29,14 @@ func (m mockOpenAIConfig) AutoDeleteMode() string {
|
||||
}
|
||||
return m.autoDeleteMode
|
||||
}
|
||||
func (m mockOpenAIConfig) AutoDeleteSessions() bool { return false }
|
||||
func (m mockOpenAIConfig) AutoDeleteSessions() bool { return false }
|
||||
func (m mockOpenAIConfig) HistorySplitEnabled() bool { return m.historySplitEnabled }
|
||||
func (m mockOpenAIConfig) HistorySplitTriggerAfterTurns() int {
|
||||
if m.historySplitTurns <= 0 {
|
||||
return 1
|
||||
}
|
||||
return m.historySplitTurns
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
|
||||
cfg := mockOpenAIConfig{
|
||||
|
||||
@@ -63,32 +63,50 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
stdReq, err = h.applyHistorySplit(r.Context(), a, stdReq)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
historySession := startChatHistory(h.ChatHistory, r, a, stdReq)
|
||||
|
||||
sessionID, err = h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
}
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusInternalServerError, "Failed to get completion.", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
if stdReq.Stream {
|
||||
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
|
||||
return
|
||||
}
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
h.handleNonStream(w, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
|
||||
}
|
||||
|
||||
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
|
||||
@@ -124,30 +142,52 @@ func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAu
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if historySession != nil {
|
||||
historySession.error(resp.StatusCode, string(body), "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
_ = ctx
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
|
||||
stripReferenceMarkers := h.compatStripReferenceMarkers()
|
||||
finalThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers)
|
||||
finalText := cleanVisibleOutput(result.Text, stripReferenceMarkers)
|
||||
if writeUpstreamEmptyOutputError(w, finalText, result.ContentFilter) {
|
||||
if searchEnabled {
|
||||
finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks)
|
||||
}
|
||||
if shouldWriteUpstreamEmptyOutputError(finalText) {
|
||||
status, message, code := upstreamEmptyOutputDetail(result.ContentFilter, finalText, finalThinking)
|
||||
if historySession != nil {
|
||||
historySession.error(status, message, code, finalThinking, finalText)
|
||||
}
|
||||
writeUpstreamEmptyOutputError(w, finalText, result.ContentFilter)
|
||||
return
|
||||
}
|
||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
if choices, ok := respBody["choices"].([]map[string]any); ok && len(choices) > 0 {
|
||||
if fr, _ := choices[0]["finish_reason"].(string); strings.TrimSpace(fr) != "" {
|
||||
finishReason = fr
|
||||
}
|
||||
}
|
||||
if historySession != nil {
|
||||
historySession.success(http.StatusOK, finalThinking, finalText, finishReason, openaifmt.BuildChatUsage(finalPrompt, finalThinking, finalText))
|
||||
}
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if historySession != nil {
|
||||
historySession.error(resp.StatusCode, string(body), "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
@@ -198,13 +238,32 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
OnKeepAlive: func() {
|
||||
streamRuntime.sendKeepAlive()
|
||||
},
|
||||
OnParsed: streamRuntime.onParsed,
|
||||
OnParsed: func(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
decision := streamRuntime.onParsed(parsed)
|
||||
if historySession != nil {
|
||||
historySession.progress(streamRuntime.thinking.String(), streamRuntime.text.String())
|
||||
}
|
||||
return decision
|
||||
},
|
||||
OnFinalize: func(reason streamengine.StopReason, _ error) {
|
||||
if string(reason) == "content_filter" {
|
||||
streamRuntime.finalize("content_filter")
|
||||
} else {
|
||||
streamRuntime.finalize("stop")
|
||||
}
|
||||
if historySession == nil {
|
||||
return
|
||||
}
|
||||
streamRuntime.finalize("stop")
|
||||
if streamRuntime.finalErrorMessage != "" {
|
||||
historySession.error(streamRuntime.finalErrorStatus, streamRuntime.finalErrorMessage, streamRuntime.finalErrorCode, streamRuntime.thinking.String(), streamRuntime.text.String())
|
||||
return
|
||||
}
|
||||
historySession.success(http.StatusOK, streamRuntime.finalThinking, streamRuntime.finalText, streamRuntime.finalFinishReason, streamRuntime.finalUsage)
|
||||
},
|
||||
OnContextDone: func() {
|
||||
if historySession != nil {
|
||||
historySession.stopped(streamRuntime.thinking.String(), streamRuntime.text.String(), string(streamengine.StopReasonContextCancelled))
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
@@ -25,9 +26,10 @@ const (
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
ChatHistory *chathistory.Store
|
||||
|
||||
leaseMu sync.Mutex
|
||||
streamLeases map[string]streamLease
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -59,21 +57,6 @@ func parseSSEDataFrames(t *testing.T, body string) ([]map[string]any, bool) {
|
||||
return frames, done
|
||||
}
|
||||
|
||||
func streamHasRawToolJSONContent(frames []map[string]any) bool {
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
content, _ := delta["content"].(string)
|
||||
if strings.Contains(content, `"tool_calls"`) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func streamHasToolCallsDelta(frames []map[string]any) bool {
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
@@ -101,180 +84,7 @@ func streamFinishReason(frames []map[string]any) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func streamToolCallArgumentChunks(frames []map[string]any) []string {
|
||||
out := make([]string, 0, 4)
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
fn, _ := tcm["function"].(map[string]any)
|
||||
if args, ok := fn["arguments"].(string); ok && args != "" {
|
||||
out = append(out, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
if len(choices) != 1 {
|
||||
t.Fatalf("unexpected choices: %#v", out["choices"])
|
||||
}
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if msg["content"] != nil {
|
||||
t.Fatalf("expected content nil, got %#v", msg["content"])
|
||||
}
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %#v", msg["tool_calls"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/thinking_content","v":"先想一下"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if msg["reasoning_content"] != "先想一下" {
|
||||
t.Fatalf("expected reasoning_content, got %#v", msg["reasoning_content"])
|
||||
}
|
||||
if msg["content"] != nil {
|
||||
t.Fatalf("expected content nil, got %#v", msg["content"])
|
||||
}
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamUnknownToolIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected tool_calls for unknown schema name, got %#v", msg["tool_calls"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamEmbeddedToolCallExamplePromotesToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"请勿执行。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected one tool_call field for embedded example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected raw tool_calls json stripped from content, got %#v", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}",
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] == "tool_calls" {
|
||||
t.Fatalf("expected fenced example to remain content-only, got finish_reason=%#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 0 {
|
||||
t.Fatalf("expected no tool_call field for fenced example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected fenced example content preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
// Backward-compatible alias for historical test name used in CI logs.
|
||||
func TestHandleNonStreamFencedToolCallExamplePromotesToolCall(t *testing.T) {
|
||||
TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t)
|
||||
}
|
||||
|
||||
func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -283,7 +93,7 @@ func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid-empty", "deepseek-chat", "prompt", false, nil)
|
||||
h.handleNonStream(rec, resp, "cid-empty", "deepseek-chat", "prompt", false, false, nil, nil)
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected status 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -302,7 +112,7 @@ func TestHandleNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutp
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid-empty-filtered", "deepseek-chat", "prompt", false, nil)
|
||||
h.handleNonStream(rec, resp, "cid-empty-filtered", "deepseek-chat", "prompt", false, false, nil, nil)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400 for filtered upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -321,7 +131,7 @@ func TestHandleNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid-thinking-only", "deepseek-reasoner", "prompt", true, nil)
|
||||
h.handleNonStream(rec, resp, "cid-thinking-only", "deepseek-reasoner", "prompt", true, false, nil, nil)
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected status 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -332,193 +142,6 @@ func TestHandleNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
|
||||
`data: {"p":"response/content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid3", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
foundToolIndex := false
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
if _, ok := tcm["index"].(float64); ok {
|
||||
foundToolIndex = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundToolIndex {
|
||||
t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallLargeArgumentsStillIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
large := strings.Repeat("a", 9000)
|
||||
payload := fmt.Sprintf(`{"tool_calls":[{"name":"search","input":{"q":"%s"}}]}`, large)
|
||||
splitAt := len(payload) / 2
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[:splitAt]),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[splitAt:]),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid3-large", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/thinking_content","v":"思考中"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid4", "deepseek-reasoner", "prompt", true, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
foundToolIndex := false
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
if _, ok := tcm["index"].(float64); ok {
|
||||
foundToolIndex = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundToolIndex {
|
||||
t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
hasThinkingDelta := false
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if _, ok := delta["reasoning_content"]; ok {
|
||||
hasThinkingDelta = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasThinkingDelta {
|
||||
t.Fatalf("expected reasoning_content delta in reasoner stream: %s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamUnknownToolEmitsToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid5", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for unknown schema name, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name: %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamUnknownToolNoArgsEmitsToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\"}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid5b", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name (no args): %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -529,7 +152,7 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid6", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
h.handleStream(rec, req, resp, "cid6", "deepseek-chat", "prompt", false, false, []string{"search"}, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -557,287 +180,6 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"请勿执行。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"我将调用工具。"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7b", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "我将调用工具。") {
|
||||
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
||||
}
|
||||
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7c", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "接下来我会继续说明。") {
|
||||
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
|
||||
}
|
||||
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamFencedToolCallSnippetPromotesToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "下面是调用示例:\n```json\n"),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\n仅示例,不要执行。"),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7f", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for fenced snippet, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected raw fenced tool_calls snippet stripped from content, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), "```json") || strings.Contains(got, "\n```\n") {
|
||||
t.Fatalf("expected consumed fenced tool payload to not leave empty code fence, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamStandaloneToolCallAfterClosedFenceKeepsFence(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "先给一个代码示例:\n```text\nhello\n```\n"),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7g", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for standalone payload, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "```") {
|
||||
t.Fatalf("expected closed fence before standalone tool json to be preserved, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
spaces := strings.Repeat(" ", 200)
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{`+spaces+`"}`,
|
||||
`data: {"p":"response/content","v":"\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"后置正文C。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid8", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "后置正文C。") {
|
||||
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"前置正文D。"}`,
|
||||
`data: {"p":"response/content","v":"{'tool_calls':[{'name':'search','input':{'q':'go'}}]}"}`,
|
||||
`data: {"p":"response/content","v":"后置正文E。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid9", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta for invalid json, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") {
|
||||
t.Fatalf("expected pre/post plain text to remain, got=%q", content.String())
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -847,7 +189,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid10", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
h.handleStream(rec, req, resp, "cid10", "deepseek-chat", "prompt", false, false, []string{"search"}, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -872,107 +214,50 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallArgumentsEmitAsSingleCompletedChunk(t *testing.T) {
|
||||
func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
|
||||
`data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`,
|
||||
`data: {"p":"response/content","v":"前置文本\n<tool_calls>\n <tool_call>\n <tool_name>read_file</tool_name>\n <parameters>{\"path\":\"README.MD\"}</parameters>\n </tool_call>\n</tool_calls>"}`,
|
||||
`data: {"p":"response/content","v":"中间文本\n<tool_calls>\n <tool_call>\n <tool_name>search</tool_name>\n <parameters>{\"q\":\"golang\"}</parameters>\n </tool_call>\n</tool_calls>"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
h.handleStream(rec, req, resp, "cid-multi", "deepseek-chat", "prompt", false, false, []string{"read_file", "search"}, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
argChunks := streamToolCallArgumentChunks(frames)
|
||||
if len(argChunks) == 0 {
|
||||
t.Fatalf("expected tool call arguments chunk, got=%v body=%s", argChunks, rec.Body.String())
|
||||
}
|
||||
joined := strings.Join(argChunks, "")
|
||||
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
|
||||
t.Fatalf("unexpected merged arguments stream: %q", joined)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`,
|
||||
`data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
foundSearch := false
|
||||
foundEval := false
|
||||
foundIndex1 := false
|
||||
toolCallsDeltaLens := make([]int, 0, 2)
|
||||
ids := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{})
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
rawToolCalls, hasToolCalls := delta["tool_calls"]
|
||||
if !hasToolCalls {
|
||||
continue
|
||||
}
|
||||
toolCalls, _ := rawToolCalls.([]any)
|
||||
toolCallsDeltaLens = append(toolCallsDeltaLens, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 {
|
||||
foundIndex1 = true
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, rawCall := range toolCalls {
|
||||
call, _ := rawCall.(map[string]any)
|
||||
id := asString(call["id"])
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
fn, _ := tcm["function"].(map[string]any)
|
||||
name, _ := fn["name"].(string)
|
||||
switch name {
|
||||
case "search_web":
|
||||
foundSearch = true
|
||||
case "eval_javascript":
|
||||
foundEval = true
|
||||
case "search_webeval_javascript":
|
||||
t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String())
|
||||
}
|
||||
if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String())
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundSearch || !foundEval {
|
||||
t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String())
|
||||
|
||||
if len(ids) != 2 {
|
||||
t.Fatalf("expected two distinct tool call ids, got %#v body=%s", ids, rec.Body.String())
|
||||
}
|
||||
if len(toolCallsDeltaLens) != 1 || toolCallsDeltaLens[0] != 2 {
|
||||
t.Fatalf("expected exactly one tool_calls delta with two calls, got lens=%v body=%s", toolCallsDeltaLens, rec.Body.String())
|
||||
}
|
||||
if !foundIndex1 {
|
||||
t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
if ids[0] == ids[1] {
|
||||
t.Fatalf("expected distinct tool call ids across blocks, got %#v body=%s", ids, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
289
internal/adapter/openai/history_split.go
Normal file
289
internal/adapter/openai/history_split.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
historySplitFilename = "HISTORY.txt"
|
||||
historySplitContentType = "text/plain; charset=utf-8"
|
||||
historySplitPurpose = "assistants"
|
||||
)
|
||||
|
||||
func (h *Handler) applyHistorySplit(ctx context.Context, a *auth.RequestAuth, stdReq util.StandardRequest) (util.StandardRequest, error) {
|
||||
if h == nil || h.DS == nil || h.Store == nil || a == nil {
|
||||
return stdReq, nil
|
||||
}
|
||||
if !h.Store.HistorySplitEnabled() {
|
||||
return stdReq, nil
|
||||
}
|
||||
|
||||
promptMessages, historyMessages := splitOpenAIHistoryMessages(stdReq.Messages, h.Store.HistorySplitTriggerAfterTurns())
|
||||
if len(historyMessages) == 0 {
|
||||
return stdReq, nil
|
||||
}
|
||||
|
||||
reasoningContent := extractHistorySplitReasoningContent(historyMessages)
|
||||
historyText := buildOpenAIHistoryTranscript(historyMessages)
|
||||
if strings.TrimSpace(historyText) == "" {
|
||||
return stdReq, errors.New("history split produced empty transcript")
|
||||
}
|
||||
|
||||
result, err := h.DS.UploadFile(ctx, a, deepseek.UploadFileRequest{
|
||||
Filename: historySplitFilename,
|
||||
ContentType: historySplitContentType,
|
||||
Purpose: historySplitPurpose,
|
||||
Data: []byte(historyText),
|
||||
}, 3)
|
||||
if err != nil {
|
||||
return stdReq, fmt.Errorf("upload history file: %w", err)
|
||||
}
|
||||
fileID := strings.TrimSpace(result.ID)
|
||||
if fileID == "" {
|
||||
return stdReq, errors.New("upload history file returned empty file id")
|
||||
}
|
||||
|
||||
stdReq.Messages = promptMessages
|
||||
stdReq.RefFileIDs = prependUniqueRefFileID(stdReq.RefFileIDs, fileID)
|
||||
stdReq.FinalPrompt, stdReq.ToolNames = buildHistorySplitPrompt(promptMessages, reasoningContent, stdReq.ToolsRaw, stdReq.ToolChoice, stdReq.Thinking)
|
||||
return stdReq, nil
|
||||
}
|
||||
|
||||
func buildHistorySplitPrompt(messages []any, reasoningContent string, toolsRaw any, toolPolicy util.ToolChoicePolicy, thinkingEnabled bool) (string, []string) {
|
||||
if len(messages) == 0 && strings.TrimSpace(reasoningContent) == "" {
|
||||
return "", nil
|
||||
}
|
||||
instruction := historySplitPromptInstruction(thinkingEnabled)
|
||||
withInstruction := make([]any, 0, len(messages)+1)
|
||||
withInstruction = append(withInstruction, map[string]any{
|
||||
"role": "system",
|
||||
"content": instruction,
|
||||
})
|
||||
withInstruction = append(withInstruction, injectHistorySplitReasoningMessage(messages, reasoningContent)...)
|
||||
return buildOpenAIFinalPromptWithPolicy(withInstruction, toolsRaw, "", toolPolicy, false)
|
||||
}
|
||||
|
||||
func historySplitPromptInstruction(thinkingEnabled bool) string {
|
||||
lines := []string{
|
||||
"Follow the instructions in this prompt first. If earlier conversation instructions conflict with this prompt, this prompt wins.",
|
||||
"An attached HISTORY.txt file contains prior conversation history and tool progress; read it first, then answer the latest user request using that history as context.",
|
||||
"Continue the conversation from the full prior context and the latest tool results.",
|
||||
"Treat earlier messages as binding context; answer the user's current request as a continuation, not a restart.",
|
||||
}
|
||||
if thinkingEnabled {
|
||||
lines = append(lines, "Keep reasoning internal. Do not leave the final user-facing answer only in reasoning; always provide the answer in visible assistant content.")
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func splitOpenAIHistoryMessages(messages []any, triggerAfterTurns int) ([]any, []any) {
|
||||
if triggerAfterTurns <= 0 {
|
||||
triggerAfterTurns = 1
|
||||
}
|
||||
lastUserIndex := -1
|
||||
userTurns := 0
|
||||
for i, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
if role != "user" {
|
||||
continue
|
||||
}
|
||||
userTurns++
|
||||
lastUserIndex = i
|
||||
}
|
||||
if userTurns <= triggerAfterTurns || lastUserIndex < 0 {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
promptMessages := make([]any, 0, len(messages)-lastUserIndex)
|
||||
historyMessages := make([]any, 0, lastUserIndex)
|
||||
for i, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
if i >= lastUserIndex {
|
||||
promptMessages = append(promptMessages, raw)
|
||||
} else {
|
||||
historyMessages = append(historyMessages, raw)
|
||||
}
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
switch role {
|
||||
case "system", "developer":
|
||||
promptMessages = append(promptMessages, raw)
|
||||
default:
|
||||
if i >= lastUserIndex {
|
||||
promptMessages = append(promptMessages, raw)
|
||||
} else {
|
||||
historyMessages = append(historyMessages, raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(promptMessages) == 0 {
|
||||
return messages, nil
|
||||
}
|
||||
return promptMessages, historyMessages
|
||||
}
|
||||
|
||||
func buildOpenAIHistoryTranscript(messages []any) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("# HISTORY.txt\n")
|
||||
b.WriteString("Prior conversation history and tool progress.\n\n")
|
||||
|
||||
entry := 0
|
||||
for _, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
content := buildOpenAIHistoryEntry(role, msg)
|
||||
if strings.TrimSpace(content) == "" {
|
||||
continue
|
||||
}
|
||||
entry++
|
||||
fmt.Fprintf(&b, "=== %d. %s ===\n%s\n\n", entry, strings.ToUpper(roleLabelForHistory(role)), content)
|
||||
}
|
||||
return strings.TrimSpace(b.String()) + "\n"
|
||||
}
|
||||
|
||||
func buildOpenAIHistoryEntry(role string, msg map[string]any) string {
|
||||
switch role {
|
||||
case "assistant":
|
||||
return strings.TrimSpace(buildAssistantHistoryContent(msg))
|
||||
case "tool", "function":
|
||||
return strings.TrimSpace(buildToolHistoryContent(msg))
|
||||
case "user":
|
||||
return strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"]))
|
||||
default:
|
||||
return strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"]))
|
||||
}
|
||||
}
|
||||
|
||||
func buildAssistantHistoryContent(msg map[string]any) string {
|
||||
return strings.TrimSpace(buildAssistantContentForPrompt(msg))
|
||||
}
|
||||
|
||||
func buildToolHistoryContent(msg map[string]any) string {
|
||||
content := strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"]))
|
||||
parts := make([]string, 0, 2)
|
||||
if name := strings.TrimSpace(asString(msg["name"])); name != "" {
|
||||
parts = append(parts, "name="+name)
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(msg["tool_call_id"])); callID != "" {
|
||||
parts = append(parts, "tool_call_id="+callID)
|
||||
}
|
||||
header := ""
|
||||
if len(parts) > 0 {
|
||||
header = "[" + strings.Join(parts, " ") + "]"
|
||||
}
|
||||
switch {
|
||||
case header != "" && content != "":
|
||||
return header + "\n" + content
|
||||
case header != "":
|
||||
return header
|
||||
default:
|
||||
return content
|
||||
}
|
||||
}
|
||||
|
||||
func extractHistorySplitReasoningContent(messages []any) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
msg, ok := messages[i].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
if role != "assistant" {
|
||||
continue
|
||||
}
|
||||
reasoning := strings.TrimSpace(normalizeOpenAIReasoningContentForPrompt(msg["reasoning_content"]))
|
||||
if reasoning == "" {
|
||||
reasoning = strings.TrimSpace(extractOpenAIReasoningContentFromMessage(msg["content"]))
|
||||
}
|
||||
if reasoning != "" {
|
||||
return reasoning
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func injectHistorySplitReasoningMessage(messages []any, reasoningContent string) []any {
|
||||
reasoningContent = strings.TrimSpace(reasoningContent)
|
||||
if reasoningContent == "" {
|
||||
return messages
|
||||
}
|
||||
reasoningMsg := map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": reasoningContent,
|
||||
}
|
||||
lastUserIndex := lastOpenAIUserMessageIndex(messages)
|
||||
if lastUserIndex < 0 {
|
||||
out := make([]any, 0, len(messages)+1)
|
||||
out = append(out, reasoningMsg)
|
||||
out = append(out, messages...)
|
||||
return out
|
||||
}
|
||||
out := make([]any, 0, len(messages)+1)
|
||||
for i, raw := range messages {
|
||||
if i == lastUserIndex {
|
||||
out = append(out, reasoningMsg)
|
||||
}
|
||||
out = append(out, raw)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func lastOpenAIUserMessageIndex(messages []any) int {
|
||||
last := -1
|
||||
for i, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(asString(msg["role"]))) == "user" {
|
||||
last = i
|
||||
}
|
||||
}
|
||||
return last
|
||||
}
|
||||
|
||||
func roleLabelForHistory(role string) string {
|
||||
role = strings.ToLower(strings.TrimSpace(role))
|
||||
switch role {
|
||||
case "function":
|
||||
return "tool"
|
||||
case "":
|
||||
return "unknown"
|
||||
default:
|
||||
return role
|
||||
}
|
||||
}
|
||||
|
||||
func prependUniqueRefFileID(existing []string, fileID string) []string {
|
||||
fileID = strings.TrimSpace(fileID)
|
||||
if fileID == "" {
|
||||
return existing
|
||||
}
|
||||
out := make([]string, 0, len(existing)+1)
|
||||
out = append(out, fileID)
|
||||
for _, id := range existing {
|
||||
trimmed := strings.TrimSpace(id)
|
||||
if trimmed == "" || strings.EqualFold(trimmed, fileID) {
|
||||
continue
|
||||
}
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
return out
|
||||
}
|
||||
353
internal/adapter/openai/history_split_test.go
Normal file
353
internal/adapter/openai/history_split_test.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func historySplitTestMessages() []any {
|
||||
toolCalls := []any{
|
||||
map[string]any{
|
||||
"name": "search",
|
||||
"arguments": map[string]any{"query": "docs"},
|
||||
},
|
||||
}
|
||||
return []any{
|
||||
map[string]any{"role": "system", "content": "system instructions"},
|
||||
map[string]any{"role": "user", "content": "first user turn"},
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "hidden reasoning",
|
||||
"tool_calls": toolCalls,
|
||||
},
|
||||
map[string]any{
|
||||
"role": "tool",
|
||||
"name": "search",
|
||||
"tool_call_id": "call-1",
|
||||
"content": "tool result",
|
||||
},
|
||||
map[string]any{"role": "user", "content": "latest user turn"},
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIHistoryTranscriptPreservesOrderAndToolHistory(t *testing.T) {
|
||||
promptMessages, historyMessages := splitOpenAIHistoryMessages(historySplitTestMessages(), 1)
|
||||
if len(promptMessages) != 2 {
|
||||
t.Fatalf("expected 2 prompt messages, got %d", len(promptMessages))
|
||||
}
|
||||
if len(historyMessages) != 3 {
|
||||
t.Fatalf("expected 3 history messages, got %d", len(historyMessages))
|
||||
}
|
||||
|
||||
transcript := buildOpenAIHistoryTranscript(historyMessages)
|
||||
if !strings.Contains(transcript, "first user turn") {
|
||||
t.Fatalf("expected user history in transcript, got %s", transcript)
|
||||
}
|
||||
if !strings.Contains(transcript, "<tool_calls>") {
|
||||
t.Fatalf("expected assistant tool_calls in transcript, got %s", transcript)
|
||||
}
|
||||
if !strings.Contains(transcript, "tool_call_id=call-1") {
|
||||
t.Fatalf("expected tool call id in transcript, got %s", transcript)
|
||||
}
|
||||
if !strings.Contains(transcript, "[reasoning_content]") {
|
||||
t.Fatalf("expected reasoning block in HISTORY.txt, got %s", transcript)
|
||||
}
|
||||
if !strings.Contains(transcript, "hidden reasoning") {
|
||||
t.Fatalf("expected reasoning text in HISTORY.txt, got %s", transcript)
|
||||
}
|
||||
|
||||
userIdx := strings.Index(transcript, "=== 1. USER ===")
|
||||
assistantIdx := strings.Index(transcript, "=== 2. ASSISTANT ===")
|
||||
toolIdx := strings.Index(transcript, "=== 3. TOOL ===")
|
||||
if userIdx < 0 || assistantIdx < 0 || toolIdx < 0 {
|
||||
t.Fatalf("expected ordered role sections, got %s", transcript)
|
||||
}
|
||||
if userIdx >= assistantIdx || assistantIdx >= toolIdx {
|
||||
t.Fatalf("expected USER -> ASSISTANT -> TOOL order, got %s", transcript)
|
||||
}
|
||||
if reasoningIdx := strings.Index(transcript, "[reasoning_content]"); reasoningIdx < 0 || reasoningIdx > strings.Index(transcript, "<tool_calls>") {
|
||||
t.Fatalf("expected reasoning block before tool calls, got %s", transcript)
|
||||
}
|
||||
reasoning := extractHistorySplitReasoningContent(historyMessages)
|
||||
if reasoning != "hidden reasoning" {
|
||||
t.Fatalf("expected latest assistant reasoning to be extracted, got %q", reasoning)
|
||||
}
|
||||
|
||||
finalPrompt, _ := buildHistorySplitPrompt(promptMessages, reasoning, nil, util.DefaultToolChoicePolicy(), false)
|
||||
if !strings.Contains(finalPrompt, "latest user turn") {
|
||||
t.Fatalf("expected latest user turn in final prompt, got %s", finalPrompt)
|
||||
}
|
||||
if strings.Contains(finalPrompt, "first user turn") {
|
||||
t.Fatalf("expected earlier history to be removed from final prompt, got %s", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "[reasoning_content]") || !strings.Contains(finalPrompt, "hidden reasoning") {
|
||||
t.Fatalf("expected latest assistant reasoning to be attached to prompt, got %s", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "HISTORY.txt") {
|
||||
t.Fatalf("expected history instruction in final prompt, got %s", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "Follow the instructions in this prompt first") {
|
||||
t.Fatalf("expected stronger prompt override in final prompt, got %s", finalPrompt)
|
||||
}
|
||||
if strings.Index(finalPrompt, "Follow the instructions in this prompt first") > strings.Index(finalPrompt, "Continue the conversation") {
|
||||
t.Fatalf("expected history split instruction before continuity instructions, got %s", finalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitOpenAIHistoryMessagesUsesLatestUserTurn(t *testing.T) {
|
||||
toolCalls := []any{
|
||||
map[string]any{
|
||||
"name": "search",
|
||||
"arguments": map[string]any{"query": "docs"},
|
||||
},
|
||||
}
|
||||
messages := []any{
|
||||
map[string]any{"role": "system", "content": "system instructions"},
|
||||
map[string]any{"role": "user", "content": "first user turn"},
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": toolCalls,
|
||||
},
|
||||
map[string]any{
|
||||
"role": "tool",
|
||||
"name": "search",
|
||||
"tool_call_id": "call-1",
|
||||
"content": "tool result",
|
||||
},
|
||||
map[string]any{"role": "user", "content": "middle user turn"},
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "middle assistant turn",
|
||||
},
|
||||
map[string]any{"role": "user", "content": "latest user turn"},
|
||||
}
|
||||
|
||||
promptMessages, historyMessages := splitOpenAIHistoryMessages(messages, 1)
|
||||
if len(promptMessages) == 0 || len(historyMessages) == 0 {
|
||||
t.Fatalf("expected both prompt and history messages, got prompt=%d history=%d", len(promptMessages), len(historyMessages))
|
||||
}
|
||||
reasoning := extractHistorySplitReasoningContent(historyMessages)
|
||||
if reasoning != "" {
|
||||
t.Fatalf("expected no reasoning in this fixture, got %q", reasoning)
|
||||
}
|
||||
|
||||
promptText, _ := buildHistorySplitPrompt(promptMessages, reasoning, nil, util.DefaultToolChoicePolicy(), false)
|
||||
if !strings.Contains(promptText, "latest user turn") {
|
||||
t.Fatalf("expected latest user turn in prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Contains(promptText, "middle user turn") {
|
||||
t.Fatalf("expected middle user turn to be split into history, got %s", promptText)
|
||||
}
|
||||
|
||||
historyText := buildOpenAIHistoryTranscript(historyMessages)
|
||||
if !strings.Contains(historyText, "middle user turn") {
|
||||
t.Fatalf("expected middle user turn in HISTORY.txt, got %s", historyText)
|
||||
}
|
||||
if strings.Contains(historyText, "latest user turn") {
|
||||
t.Fatalf("expected latest user turn to remain in prompt, got %s", historyText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHistorySplitSkipsFirstTurn(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{
|
||||
wideInput: true,
|
||||
historySplitEnabled: true,
|
||||
historySplitTurns: 1,
|
||||
},
|
||||
DS: ds,
|
||||
}
|
||||
req := map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
|
||||
out, err := h.applyHistorySplit(context.Background(), &auth.RequestAuth{DeepSeekToken: "token"}, stdReq)
|
||||
if err != nil {
|
||||
t.Fatalf("apply history split failed: %v", err)
|
||||
}
|
||||
if len(ds.uploadCalls) != 0 {
|
||||
t.Fatalf("expected no upload on first turn, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
if out.FinalPrompt != stdReq.FinalPrompt {
|
||||
t.Fatalf("expected prompt unchanged on first turn")
|
||||
}
|
||||
if len(out.RefFileIDs) != len(stdReq.RefFileIDs) {
|
||||
t.Fatalf("expected ref files unchanged on first turn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsHistorySplitUploadsHistoryAndKeepsLatestPrompt(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{
|
||||
wideInput: true,
|
||||
historySplitEnabled: true,
|
||||
historySplitTurns: 1,
|
||||
},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: ds,
|
||||
}
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": historySplitTestMessages(),
|
||||
"stream": false,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(string(reqBody)))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(ds.uploadCalls) != 1 {
|
||||
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
upload := ds.uploadCalls[0]
|
||||
if upload.Filename != "HISTORY.txt" {
|
||||
t.Fatalf("unexpected upload filename: %q", upload.Filename)
|
||||
}
|
||||
if upload.ContentType != "text/plain; charset=utf-8" {
|
||||
t.Fatalf("unexpected content type: %q", upload.ContentType)
|
||||
}
|
||||
if upload.Purpose != "assistants" {
|
||||
t.Fatalf("unexpected purpose: %q", upload.Purpose)
|
||||
}
|
||||
historyText := string(upload.Data)
|
||||
if !strings.Contains(historyText, "first user turn") || !strings.Contains(historyText, "tool result") {
|
||||
t.Fatalf("expected older turns in HISTORY.txt, got %s", historyText)
|
||||
}
|
||||
if strings.Contains(historyText, "latest user turn") {
|
||||
t.Fatalf("expected latest turn to remain in prompt, got %s", historyText)
|
||||
}
|
||||
if ds.completionReq == nil {
|
||||
t.Fatal("expected completion payload to be captured")
|
||||
}
|
||||
promptText, _ := ds.completionReq["prompt"].(string)
|
||||
if !strings.Contains(promptText, "latest user turn") {
|
||||
t.Fatalf("expected latest turn in completion prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Contains(promptText, "first user turn") {
|
||||
t.Fatalf("expected historical turns removed from completion prompt, got %s", promptText)
|
||||
}
|
||||
if !strings.Contains(promptText, "[reasoning_content]") || !strings.Contains(promptText, "hidden reasoning") {
|
||||
t.Fatalf("expected latest assistant reasoning to be attached to completion prompt, got %s", promptText)
|
||||
}
|
||||
if !strings.Contains(promptText, "HISTORY.txt") {
|
||||
t.Fatalf("expected history instruction in completion prompt, got %s", promptText)
|
||||
}
|
||||
if !strings.Contains(promptText, "Follow the instructions in this prompt first") {
|
||||
t.Fatalf("expected stronger prompt override in completion prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Index(promptText, "Follow the instructions in this prompt first") > strings.Index(promptText, "Continue the conversation") {
|
||||
t.Fatalf("expected history split instruction before continuity instructions, got %s", promptText)
|
||||
}
|
||||
refIDs, _ := ds.completionReq["ref_file_ids"].([]any)
|
||||
if len(refIDs) == 0 || refIDs[0] != "file-inline-1" {
|
||||
t.Fatalf("expected uploaded history file to be first ref_file_id, got %#v", ds.completionReq["ref_file_ids"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesHistorySplitUploadsHistoryAndKeepsLatestPrompt(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{
|
||||
wideInput: true,
|
||||
historySplitEnabled: true,
|
||||
historySplitTurns: 1,
|
||||
},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: ds,
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": historySplitTestMessages(),
|
||||
"stream": false,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(string(reqBody)))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(ds.uploadCalls) != 1 {
|
||||
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
if ds.completionReq == nil {
|
||||
t.Fatal("expected completion payload to be captured")
|
||||
}
|
||||
promptText, _ := ds.completionReq["prompt"].(string)
|
||||
if !strings.Contains(promptText, "latest user turn") {
|
||||
t.Fatalf("expected latest turn in completion prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Contains(promptText, "first user turn") {
|
||||
t.Fatalf("expected historical turns removed from completion prompt, got %s", promptText)
|
||||
}
|
||||
if !strings.Contains(promptText, "[reasoning_content]") || !strings.Contains(promptText, "hidden reasoning") {
|
||||
t.Fatalf("expected latest assistant reasoning to be attached to completion prompt, got %s", promptText)
|
||||
}
|
||||
if !strings.Contains(promptText, "Follow the instructions in this prompt first") {
|
||||
t.Fatalf("expected stronger prompt override in completion prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Index(promptText, "Follow the instructions in this prompt first") > strings.Index(promptText, "Continue the conversation") {
|
||||
t.Fatalf("expected history split instruction before continuity instructions, got %s", promptText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsHistorySplitUploadFailureReturnsInternalServerError(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{uploadErr: context.DeadlineExceeded}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{
|
||||
wideInput: true,
|
||||
historySplitEnabled: true,
|
||||
historySplitTurns: 1,
|
||||
},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: ds,
|
||||
}
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": historySplitTestMessages(),
|
||||
"stream": false,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(string(reqBody)))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if ds.completionReq != nil {
|
||||
t.Fatalf("did not expect completion payload on upload failure")
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"ds2api/internal/prompt"
|
||||
)
|
||||
|
||||
const assistantReasoningLabel = "reasoning_content"
|
||||
|
||||
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
|
||||
_ = traceID
|
||||
out := make([]map[string]any, 0, len(raw))
|
||||
@@ -55,17 +57,95 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
||||
|
||||
func buildAssistantContentForPrompt(msg map[string]any) string {
|
||||
content := strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"]))
|
||||
toolHistory := prompt.FormatToolCallsForPrompt(msg["tool_calls"])
|
||||
switch {
|
||||
case content == "" && toolHistory == "":
|
||||
return ""
|
||||
case content == "":
|
||||
return toolHistory
|
||||
case toolHistory == "":
|
||||
return content
|
||||
default:
|
||||
return content + "\n\n" + toolHistory
|
||||
reasoning := strings.TrimSpace(normalizeOpenAIReasoningContentForPrompt(msg["reasoning_content"]))
|
||||
if reasoning == "" {
|
||||
reasoning = strings.TrimSpace(extractOpenAIReasoningContentFromMessage(msg["content"]))
|
||||
}
|
||||
toolHistory := prompt.FormatToolCallsForPrompt(msg["tool_calls"])
|
||||
parts := make([]string, 0, 3)
|
||||
if reasoning != "" {
|
||||
parts = append(parts, formatPromptLabeledBlock(assistantReasoningLabel, reasoning))
|
||||
}
|
||||
if content != "" {
|
||||
parts = append(parts, content)
|
||||
}
|
||||
if toolHistory != "" {
|
||||
parts = append(parts, toolHistory)
|
||||
}
|
||||
switch len(parts) {
|
||||
case 0:
|
||||
return ""
|
||||
case 1:
|
||||
return parts[0]
|
||||
default:
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpenAIReasoningContentForPrompt(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return x
|
||||
case []any:
|
||||
return strings.Join(extractOpenAIReasoningPartsFromItems(x), "\n")
|
||||
case map[string]any:
|
||||
return extractOpenAIReasoningTextFromItem(x)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningContentFromMessage(v any) string {
|
||||
switch x := v.(type) {
|
||||
case []any:
|
||||
return strings.Join(extractOpenAIReasoningPartsFromItems(x), "\n")
|
||||
case map[string]any:
|
||||
return extractOpenAIReasoningTextFromItem(x)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningPartsFromItems(items []any) []string {
|
||||
parts := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
if text := extractOpenAIReasoningTextFromItemMap(item); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningTextFromItemMap(item any) string {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return extractOpenAIReasoningTextFromItem(m)
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningTextFromItem(m map[string]any) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(asString(m["type"]))) {
|
||||
case "reasoning", "thinking":
|
||||
for _, key := range []string{"text", "thinking", "content"} {
|
||||
if text := strings.TrimSpace(asString(m[key])); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func formatPromptLabeledBlock(label, text string) string {
|
||||
label = strings.TrimSpace(label)
|
||||
text = strings.TrimSpace(text)
|
||||
if label == "" {
|
||||
return text
|
||||
}
|
||||
return "[" + label + "]\n" + text + "\n[/" + label + "]"
|
||||
}
|
||||
|
||||
func buildToolContentForPrompt(msg map[string]any) string {
|
||||
|
||||
@@ -296,3 +296,31 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantArrayContentFallbackWhenTextE
|
||||
t.Fatalf("expected content fallback text preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_AssistantReasoningContentPreserved(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "visible answer",
|
||||
"reasoning_content": "internal reasoning",
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized assistant message, got %#v", normalized)
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(content, "[reasoning_content]") {
|
||||
t.Fatalf("expected labeled reasoning block in assistant content, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "internal reasoning") {
|
||||
t.Fatalf("expected reasoning text in assistant content, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "visible answer") {
|
||||
t.Fatalf("expected visible answer in assistant content, got %q", content)
|
||||
}
|
||||
if reasoningIdx := strings.Index(content, "[reasoning_content]"); reasoningIdx < 0 || reasoningIdx > strings.Index(content, "visible answer") {
|
||||
t.Fatalf("expected reasoning block before visible answer, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
|
||||
}
|
||||
|
||||
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "", false)
|
||||
if !strings.Contains(finalPrompt, "Remember: Output ONLY the <tool_calls>...</tool_calls> XML block when calling tools.") {
|
||||
if !strings.Contains(finalPrompt, "Remember: The ONLY valid way to use tools is the <tool_calls> XML block at the end of your response.") {
|
||||
t.Fatalf("vercel prepare finalPrompt missing final tool-call anchor instruction: %q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "TOOL CALL FORMAT") {
|
||||
@@ -87,3 +87,17 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
|
||||
t.Fatalf("vercel prepare finalPrompt should not require fenced tool calls: %q", finalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIFinalPromptWithThinkingAddsContinuationContract(t *testing.T) {
|
||||
messages := []any{
|
||||
map[string]any{"role": "user", "content": "继续回答上一个问题"},
|
||||
}
|
||||
|
||||
finalPrompt, _ := buildOpenAIFinalPrompt(messages, nil, "", true)
|
||||
if !strings.Contains(finalPrompt, "Continue the conversation from the full prior context") {
|
||||
t.Fatalf("expected continuation contract in thinking prompt, got=%q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "final user-facing answer only in reasoning") {
|
||||
t.Fatalf("expected visible-answer contract in thinking prompt, got=%q", finalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,6 +85,11 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
stdReq, err = h.applyHistorySplit(r.Context(), a, stdReq)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
@@ -112,10 +117,10 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
return
|
||||
}
|
||||
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -126,6 +131,9 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
stripReferenceMarkers := h.compatStripReferenceMarkers()
|
||||
sanitizedThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers)
|
||||
sanitizedText := cleanVisibleOutput(result.Text, stripReferenceMarkers)
|
||||
if searchEnabled {
|
||||
sanitizedText = replaceCitationMarkersWithLinks(sanitizedText, result.CitationLinks)
|
||||
}
|
||||
if writeUpstreamEmptyOutputError(w, sanitizedText, result.ContentFilter) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -128,7 +128,7 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
|
||||
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true, true)
|
||||
}
|
||||
|
||||
textParsed := toolcall.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
|
||||
@@ -224,7 +224,7 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
s.emitTextDelta(trimmed)
|
||||
continue
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, trimmed, s.toolNames), true)
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, trimmed, s.toolNames), true, true)
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (s *responsesStreamRuntime) sendDone() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
|
||||
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool, resetAfterToolCalls bool) {
|
||||
for _, evt := range events {
|
||||
if emitContent && evt.Content != "" {
|
||||
s.emitTextDelta(evt.Content)
|
||||
@@ -56,6 +56,9 @@ func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEven
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.emitFunctionCallDoneEvents(evt.ToolCalls)
|
||||
if resetAfterToolCalls {
|
||||
s.resetStreamToolCallState()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,6 +152,16 @@ func (s *responsesStreamRuntime) ensureToolCallID(callIndex int) string {
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) resetStreamToolCallState() {
|
||||
s.streamToolCallIDs = map[int]string{}
|
||||
s.functionItemIDs = map[int]string{}
|
||||
s.functionOutputIDs = map[int]int{}
|
||||
s.functionArgs = map[int]string{}
|
||||
s.functionDone = map[int]bool{}
|
||||
s.functionAdded = map[int]bool{}
|
||||
s.functionNames = map[int]string{}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureFunctionOutputIndex(callIndex int) int {
|
||||
if idx, ok := s.functionOutputIDs[callIndex]; ok {
|
||||
return idx
|
||||
|
||||
@@ -12,149 +12,6 @@ import (
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
rawToolJSON := `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`
|
||||
streamBody := sseLine(rawToolJSON) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
outputText, _ := responseObj["output_text"].(string)
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected empty output_text for tool_calls response, got output_text=%q", outputText)
|
||||
}
|
||||
output, _ := responseObj["output"].([]any)
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
|
||||
}
|
||||
hasFunctionCall := false
|
||||
hasLegacyWrapper := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m["type"] == "function_call" {
|
||||
hasFunctionCall = true
|
||||
}
|
||||
if m["type"] == "tool_calls" {
|
||||
hasLegacyWrapper = true
|
||||
}
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatalf("expected function_call item, got %#v", responseObj["output"])
|
||||
}
|
||||
if hasLegacyWrapper {
|
||||
t.Fatalf("did not expect legacy tool_calls wrapper, got %#v", responseObj["output"])
|
||||
}
|
||||
if strings.Contains(outputText, `"tool_calls"`) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.output_item.added") {
|
||||
t.Fatalf("expected response.output_item.added event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.output_item.done") {
|
||||
t.Fatalf("expected response.output_item.done event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
|
||||
}
|
||||
if strings.Contains(body, "event: response.output_tool_call.delta") || strings.Contains(body, "event: response.output_tool_call.done") {
|
||||
t.Fatalf("legacy response.output_tool_call.* event must not appear, body=%s", body)
|
||||
}
|
||||
|
||||
addedPayloads := extractAllSSEEventPayloads(body, "response.output_item.added")
|
||||
hasFunctionCallAdded := false
|
||||
for _, payload := range addedPayloads {
|
||||
item, _ := payload["item"].(map[string]any)
|
||||
if item == nil || asString(item["type"]) != "function_call" {
|
||||
continue
|
||||
}
|
||||
hasFunctionCallAdded = true
|
||||
if asString(item["arguments"]) != "" {
|
||||
t.Fatalf("expected in-progress function_call.arguments to start empty string, got %#v", item["arguments"])
|
||||
}
|
||||
}
|
||||
if !hasFunctionCallAdded {
|
||||
t.Fatalf("expected function_call output_item.added payload, body=%s", body)
|
||||
}
|
||||
|
||||
donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done")
|
||||
if !ok {
|
||||
t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body)
|
||||
}
|
||||
doneCallID := strings.TrimSpace(asString(donePayload["call_id"]))
|
||||
if doneCallID == "" {
|
||||
t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload)
|
||||
}
|
||||
completed, ok := extractSSEEventPayload(body, "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed payload, body=%s", body)
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
var completedCallID string
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil || m["type"] != "function_call" {
|
||||
continue
|
||||
}
|
||||
completedCallID = strings.TrimSpace(asString(m["call_id"]))
|
||||
if completedCallID != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
if completedCallID == "" {
|
||||
t.Fatalf("expected function_call.call_id in completed output, output=%#v", output)
|
||||
}
|
||||
if completedCallID != doneCallID {
|
||||
t.Fatalf("expected completed call_id to match stream done call_id, done=%q completed=%q", doneCallID, completedCallID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
@@ -181,51 +38,6 @@ func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
|
||||
sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
if len(donePayloads) != 2 {
|
||||
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
|
||||
}
|
||||
seenNames := map[string]string{}
|
||||
for _, payload := range donePayloads {
|
||||
name := strings.TrimSpace(asString(payload["name"]))
|
||||
callID := strings.TrimSpace(asString(payload["call_id"]))
|
||||
if name != "search_web" && name != "eval_javascript" {
|
||||
t.Fatalf("unexpected tool name in done payload: %#v", payload)
|
||||
}
|
||||
if callID == "" {
|
||||
t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
|
||||
}
|
||||
seenNames[name] = callID
|
||||
}
|
||||
if seenNames["search_web"] == seenNames["eval_javascript"] {
|
||||
t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
@@ -297,120 +109,54 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamThinkingAndMixedToolExampleEmitsFunctionCall(t *testing.T) {
|
||||
func TestHandleResponsesStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(path, value string) string {
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": path,
|
||||
"v": value,
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("response/thinking_content", "thinking...") +
|
||||
sseLine("response/content", "先读取文件。") +
|
||||
sseLine("response/content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
|
||||
streamBody := sseLine("前置文本\n<tool_calls>\n <tool_call>\n <tool_name>read_file</tool_name>\n <parameters>{\"path\":\"README.MD\"}</parameters>\n </tool_call>\n</tool_calls>") +
|
||||
sseLine("中间文本\n<tool_calls>\n <tool_call>\n <tool_name>search</tool_name>\n <parameters>{\"q\":\"golang\"}</parameters>\n </tool_call>\n</tool_calls>") +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file", "search"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
|
||||
if len(addedPayloads) < 1 {
|
||||
t.Fatalf("expected at least one output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
|
||||
body := rec.Body.String()
|
||||
doneEvents := extractSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
if len(doneEvents) < 2 {
|
||||
t.Fatalf("expected at least two function call done events, got %d body=%s", len(doneEvents), body)
|
||||
}
|
||||
|
||||
completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed payload, body=%s", rec.Body.String())
|
||||
}
|
||||
responseObj, _ := completedPayload["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
hasMessage := false
|
||||
hasFunctionCall := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
ids := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{})
|
||||
for _, payload := range doneEvents {
|
||||
callID := asString(payload["call_id"])
|
||||
if callID == "" {
|
||||
continue
|
||||
}
|
||||
if asString(m["type"]) == "message" {
|
||||
hasMessage = true
|
||||
if _, ok := seen[callID]; ok {
|
||||
continue
|
||||
}
|
||||
if asString(m["type"]) == "function_call" {
|
||||
hasFunctionCall = true
|
||||
}
|
||||
}
|
||||
if !hasMessage {
|
||||
t.Fatalf("expected message output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatalf("expected function_call output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
seen[callID] = struct{}{}
|
||||
ids = append(ids, callID)
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
if len(ids) != 2 {
|
||||
t.Fatalf("expected two distinct call ids, got %#v body=%s", ids, body)
|
||||
}
|
||||
policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, policy, "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected function_call events for tool_choice=none, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamMalformedToolJSONFallsBackToText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
// invalid JSON (NaN) should remain plain text in strict mode.
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "event: response.function_call_arguments.delta") || strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("did not expect function_call events for malformed payload in strict mode, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.output_text.delta") {
|
||||
t.Fatalf("expected response.output_text.delta for malformed payload, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("expected response.completed event, body=%s", body)
|
||||
if ids[0] == ids[1] {
|
||||
t.Fatalf("expected distinct call ids across blocks, got %#v body=%s", ids, body)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -448,76 +194,6 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(path, value string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": path,
|
||||
"v": value,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
|
||||
sseLine("response/content", "plain text only") +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
|
||||
}
|
||||
if strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("did not expect response.completed after failure, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
t.Fatalf("expected response.failed event, body=%s", body)
|
||||
}
|
||||
if strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("did not expect response.completed, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamFailsWhenUpstreamHasOnlyThinking(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
@@ -556,32 +232,6 @@ func TestHandleResponsesStreamFailsWhenUpstreamHasOnlyThinking(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamAllowsUnknownToolName(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"not_in_schema","input":{"q":"go"}}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected function_call events for unknown tool, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -597,7 +247,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -624,7 +274,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -635,36 +285,6 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" +
|
||||
`data: [DONE]` + "\n",
|
||||
)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, policy, "")
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for tool_choice=none handling, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
output, _ := out["output"].([]any)
|
||||
foundFunctionCall := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m != nil && m["type"] == "function_call" {
|
||||
foundFunctionCall = true
|
||||
}
|
||||
}
|
||||
if !foundFunctionCall {
|
||||
t.Fatalf("expected function_call output item for tool_choice=none, got %#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -676,7 +296,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T)
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -698,7 +318,7 @@ func TestHandleResponsesNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWi
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for filtered empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -720,7 +340,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testin
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, nil, util.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -757,10 +377,10 @@ func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
|
||||
func extractSSEEventPayloads(body, targetEvent string) []map[string]any {
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
matched := false
|
||||
out := make([]map[string]any, 0, 2)
|
||||
out := make([]map[string]any, 0, 4)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
|
||||
@@ -35,6 +35,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
|
||||
ResolvedModel: resolvedModel,
|
||||
ResponseModel: responseModel,
|
||||
Messages: messagesRaw,
|
||||
ToolsRaw: req["tools"],
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
ToolChoice: toolPolicy,
|
||||
@@ -90,6 +91,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
|
||||
ResolvedModel: resolvedModel,
|
||||
ResponseModel: model,
|
||||
Messages: messagesRaw,
|
||||
ToolsRaw: req["tools"],
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
ToolChoice: toolPolicy,
|
||||
|
||||
@@ -146,53 +146,6 @@ func TestResponsesStreamStatusCapturedAs200(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
|
||||
statuses := make([]int, 0, 1)
|
||||
content, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": "我来调用工具\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}",
|
||||
})
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse("data: "+string(content), "data: [DONE]")},
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
r.Use(captureStatusMiddleware(&statuses))
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","input":"请调用工具","tools":[{"type":"function","function":{"name":"read_file","description":"read","parameters":{"type":"object","properties":{"path":{"type":"string"}}}}}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(statuses) != 1 || statuses[0] != http.StatusOK {
|
||||
t.Fatalf("expected captured status 200, got %#v", statuses)
|
||||
}
|
||||
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
|
||||
}
|
||||
outputText, _ := out["output_text"].(string)
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
|
||||
}
|
||||
output, _ := out["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected one output item, got %#v", output)
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected function_call output item, got %#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsStreamContentFilterStopsNormallyWithoutLeak(t *testing.T) {
|
||||
statuses := make([]int, 0, 1)
|
||||
h := &Handler{
|
||||
|
||||
@@ -60,7 +60,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
if pending == "" {
|
||||
break
|
||||
}
|
||||
start := findToolSegmentStart(pending)
|
||||
start := findToolSegmentStart(state, pending)
|
||||
if start >= 0 {
|
||||
prefix := pending[:start]
|
||||
if prefix != "" {
|
||||
@@ -74,7 +74,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
continue
|
||||
}
|
||||
|
||||
safe, hold := splitSafeContentForToolDetection(pending)
|
||||
safe, hold := splitSafeContentForToolDetection(state, pending)
|
||||
if safe == "" {
|
||||
break
|
||||
}
|
||||
@@ -114,14 +114,10 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
||||
} else {
|
||||
content := state.capture.String()
|
||||
if content != "" {
|
||||
// If the captured text looks like an incomplete XML tool call block,
|
||||
// swallow it to prevent leaking raw XML tags to the client.
|
||||
if hasOpenXMLToolTag(content) {
|
||||
// Drop it silently — incomplete tool call.
|
||||
} else {
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
}
|
||||
// If capture never resolved into a real tool call, release the
|
||||
// buffered text instead of swallowing it.
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
}
|
||||
}
|
||||
state.capture.Reset()
|
||||
@@ -130,100 +126,57 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
||||
}
|
||||
if state.pending.Len() > 0 {
|
||||
content := state.pending.String()
|
||||
// Safety: if pending contains XML tool tag fragments (e.g. "tool_calls>"
|
||||
// from a split closing tag), swallow them instead of leaking.
|
||||
if hasOpenXMLToolTag(content) || looksLikeXMLToolTagFragment(content) {
|
||||
// Drop it — likely an incomplete tool call fragment.
|
||||
} else {
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
}
|
||||
// If pending never resolved into a real tool call, release it as text.
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
state.pending.Reset()
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
func splitSafeContentForToolDetection(s string) (safe, hold string) {
|
||||
func splitSafeContentForToolDetection(state *toolStreamSieveState, s string) (safe, hold string) {
|
||||
if s == "" {
|
||||
return "", ""
|
||||
}
|
||||
suspiciousStart := findSuspiciousPrefixStart(s)
|
||||
if suspiciousStart < 0 {
|
||||
return s, ""
|
||||
}
|
||||
if suspiciousStart > 0 {
|
||||
return s[:suspiciousStart], s[suspiciousStart:]
|
||||
}
|
||||
// If suspicious content starts at position 0, keep holding until we can
|
||||
// parse a complete tool JSON block or reach stream flush.
|
||||
return "", s
|
||||
}
|
||||
|
||||
func findSuspiciousPrefixStart(s string) int {
|
||||
start := -1
|
||||
indices := []int{
|
||||
strings.LastIndex(s, "{"),
|
||||
strings.LastIndex(s, "["),
|
||||
strings.LastIndex(s, "```"),
|
||||
}
|
||||
for _, idx := range indices {
|
||||
if idx > start {
|
||||
start = idx
|
||||
if xmlIdx := findPartialXMLToolTagStart(s); xmlIdx >= 0 {
|
||||
if insideCodeFenceWithState(state, s[:xmlIdx]) {
|
||||
return s, ""
|
||||
}
|
||||
if xmlIdx > 0 {
|
||||
return s[:xmlIdx], s[xmlIdx:]
|
||||
}
|
||||
return "", s
|
||||
}
|
||||
// Also check for partial XML tool tag at end of string.
|
||||
if xmlIdx := findPartialXMLToolTagStart(s); xmlIdx >= 0 && xmlIdx > start {
|
||||
start = xmlIdx
|
||||
}
|
||||
return start
|
||||
return s, ""
|
||||
}
|
||||
|
||||
func findToolSegmentStart(s string) int {
|
||||
func findToolSegmentStart(state *toolStreamSieveState, s string) int {
|
||||
if s == "" {
|
||||
return -1
|
||||
}
|
||||
lower := strings.ToLower(s)
|
||||
keywords := []string{"tool_calls", "\"function\"", "function.name:", "\"tool_use\""}
|
||||
bestKeyIdx := -1
|
||||
for _, kw := range keywords {
|
||||
idx := strings.Index(lower, kw)
|
||||
if idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx) {
|
||||
bestKeyIdx = idx
|
||||
offset := 0
|
||||
for {
|
||||
bestKeyIdx := -1
|
||||
matchedTag := ""
|
||||
for _, tag := range xmlToolTagsToDetect {
|
||||
idx := strings.Index(lower[offset:], tag)
|
||||
if idx >= 0 {
|
||||
idx += offset
|
||||
if bestKeyIdx < 0 || idx < bestKeyIdx {
|
||||
bestKeyIdx = idx
|
||||
matchedTag = tag
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if fnKeyIdx := findQuotedFunctionCallKeyStart(s); fnKeyIdx >= 0 && (bestKeyIdx < 0 || fnKeyIdx < bestKeyIdx) {
|
||||
bestKeyIdx = fnKeyIdx
|
||||
}
|
||||
// Also detect XML tool call tags.
|
||||
for _, tag := range xmlToolTagsToDetect {
|
||||
idx := strings.Index(lower, tag)
|
||||
if idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx) {
|
||||
bestKeyIdx = idx
|
||||
if bestKeyIdx < 0 {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
if bestKeyIdx < 0 {
|
||||
return -1
|
||||
}
|
||||
// For XML tags, the '<' is itself the segment start.
|
||||
if bestKeyIdx < len(s) && s[bestKeyIdx] == '<' {
|
||||
if fenceStart, ok := openFenceStartBefore(s, bestKeyIdx); ok {
|
||||
return fenceStart
|
||||
if !insideCodeFenceWithState(state, s[:bestKeyIdx]) {
|
||||
return bestKeyIdx
|
||||
}
|
||||
return bestKeyIdx
|
||||
offset = bestKeyIdx + len(matchedTag)
|
||||
}
|
||||
start := strings.LastIndex(s[:bestKeyIdx], "{")
|
||||
if start < 0 {
|
||||
start = bestKeyIdx
|
||||
}
|
||||
// If the keyword matched inside an XML tag (e.g. "tool_calls" in "<tool_calls>"),
|
||||
// back up past the '<' to capture the full tag.
|
||||
if start > 0 && s[start-1] == '<' {
|
||||
start--
|
||||
}
|
||||
if fenceStart, ok := openFenceStartBefore(s, start); ok {
|
||||
return fenceStart
|
||||
}
|
||||
return start
|
||||
}
|
||||
|
||||
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []toolcall.ParsedToolCall, suffix string, ready bool) {
|
||||
@@ -232,7 +185,7 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
// Try XML tool call extraction first.
|
||||
// XML tool call extraction only.
|
||||
if xmlPrefix, xmlCalls, xmlSuffix, xmlReady := consumeXMLToolCapture(captured, toolNames); xmlReady {
|
||||
return xmlPrefix, xmlCalls, xmlSuffix, true
|
||||
}
|
||||
@@ -240,45 +193,5 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
if hasOpenXMLToolTag(captured) {
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
lower := strings.ToLower(captured)
|
||||
keyIdx := -1
|
||||
keywords := []string{"tool_calls", "\"function\"", "function.name:", "\"tool_use\""}
|
||||
for _, kw := range keywords {
|
||||
idx := strings.Index(lower, kw)
|
||||
if idx >= 0 && (keyIdx < 0 || idx < keyIdx) {
|
||||
keyIdx = idx
|
||||
}
|
||||
}
|
||||
if fnKeyIdx := findQuotedFunctionCallKeyStart(captured); fnKeyIdx >= 0 && (keyIdx < 0 || fnKeyIdx < keyIdx) {
|
||||
keyIdx = fnKeyIdx
|
||||
}
|
||||
|
||||
if keyIdx < 0 {
|
||||
return "", nil, "", false
|
||||
}
|
||||
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
start = keyIdx
|
||||
}
|
||||
obj, end, ok := extractJSONObjectFrom(captured, start)
|
||||
if !ok {
|
||||
return "", nil, "", false
|
||||
}
|
||||
prefixPart := captured[:start]
|
||||
suffixPart := captured[end:]
|
||||
parsed := toolcall.ParseStandaloneToolCallsDetailed(obj, toolNames)
|
||||
if len(parsed.Calls) == 0 {
|
||||
if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {
|
||||
// Parsed as tool-call payload but rejected by schema/policy:
|
||||
// consume it to avoid leaking raw tool_calls JSON to user content.
|
||||
return prefixPart, nil, suffixPart, true
|
||||
}
|
||||
// If it has obvious keywords but failed to parse even after loose repair,
|
||||
// we still might want to intercept it if it looks like an attempt at tool call.
|
||||
// For now, keep the original logic but rely on loose JSON repair.
|
||||
return captured, nil, "", true
|
||||
}
|
||||
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
|
||||
return prefixPart, parsed.Calls, suffixPart, true
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func findQuotedFunctionCallKeyStart(s string) int {
|
||||
lower := strings.ToLower(s)
|
||||
quotedIdx := findFunctionCallKeyStart(lower, `"functioncall"`)
|
||||
bareIdx := findFunctionCallKeyStart(lower, "functioncall")
|
||||
|
||||
// Prefer the quoted JSON key whenever we have a structural match.
|
||||
// Bare-key detection is only for loose payloads where the quoted form
|
||||
// is absent.
|
||||
if quotedIdx >= 0 {
|
||||
return quotedIdx
|
||||
}
|
||||
return bareIdx
|
||||
}
|
||||
|
||||
func findFunctionCallKeyStart(lower, key string) int {
|
||||
for from := 0; from < len(lower); {
|
||||
rel := strings.Index(lower[from:], key)
|
||||
if rel < 0 {
|
||||
return -1
|
||||
}
|
||||
idx := from + rel
|
||||
if isInsideJSONString(lower, idx) {
|
||||
from = idx + 1
|
||||
continue
|
||||
}
|
||||
if !hasJSONObjectContextPrefix(lower[:idx]) {
|
||||
from = idx + 1
|
||||
continue
|
||||
}
|
||||
if !hasJSONKeyBoundary(lower, idx, len(key)) {
|
||||
from = idx + 1
|
||||
continue
|
||||
}
|
||||
j := idx + len(key)
|
||||
for j < len(lower) && (lower[j] == ' ' || lower[j] == '\t' || lower[j] == '\r' || lower[j] == '\n') {
|
||||
j++
|
||||
}
|
||||
if j < len(lower) && lower[j] == ':' {
|
||||
k := j + 1
|
||||
for k < len(lower) && (lower[k] == ' ' || lower[k] == '\t' || lower[k] == '\r' || lower[k] == '\n') {
|
||||
k++
|
||||
}
|
||||
if k < len(lower) && lower[k] != '{' {
|
||||
from = idx + 1
|
||||
continue
|
||||
}
|
||||
return idx
|
||||
}
|
||||
from = idx + 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func isInsideJSONString(s string, idx int) bool {
|
||||
inString := false
|
||||
escaped := false
|
||||
for i := 0; i < idx; i++ {
|
||||
c := s[i]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if c == '\\' && inString {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if c == '"' {
|
||||
inString = !inString
|
||||
}
|
||||
}
|
||||
return inString
|
||||
}
|
||||
|
||||
func hasJSONObjectContextPrefix(prefix string) bool {
|
||||
return strings.LastIndex(prefix, "{") >= 0
|
||||
}
|
||||
|
||||
func hasJSONKeyBoundary(s string, idx, keyLen int) bool {
|
||||
if idx > 0 {
|
||||
prev := s[idx-1]
|
||||
if isLowerAlphaNumeric(prev) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if end := idx + keyLen; end < len(s) {
|
||||
next := s[end]
|
||||
if isLowerAlphaNumeric(next) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isLowerAlphaNumeric(b byte) bool {
|
||||
return (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') || b == '_'
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFindQuotedFunctionCallKeyStart_PrefersEarlierBareKey(t *testing.T) {
|
||||
input := `{functionCall:{"name":"a","arguments":"{}"},"message":"literal text: \"functionCall\": not a key"}`
|
||||
|
||||
got := findQuotedFunctionCallKeyStart(input)
|
||||
want := 1
|
||||
if got != want {
|
||||
t.Fatalf("findQuotedFunctionCallKeyStart() = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindQuotedFunctionCallKeyStart_PrefersEarlierQuotedKey(t *testing.T) {
|
||||
input := `{"functionCall":{"name":"a","arguments":"{}"},"note":"functionCall appears in prose"}`
|
||||
|
||||
got := findQuotedFunctionCallKeyStart(input)
|
||||
want := 1
|
||||
if got != want {
|
||||
t.Fatalf("findQuotedFunctionCallKeyStart() = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
@@ -2,48 +2,6 @@ package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
if start < 0 || start >= len(text) || text[start] != '{' {
|
||||
return "", 0, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
end := i + 1
|
||||
return text[start:end], end, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func trimWrappingJSONFence(prefix, suffix string) (string, string) {
|
||||
trimmedPrefix := strings.TrimRight(prefix, " \t\r\n")
|
||||
fenceIdx := strings.LastIndex(trimmedPrefix, "```")
|
||||
@@ -67,18 +25,3 @@ func trimWrappingJSONFence(prefix, suffix string) (string, string) {
|
||||
consumedLeading := len(suffix) - len(trimmedSuffix)
|
||||
return trimmedPrefix[:fenceIdx], suffix[consumedLeading+3:]
|
||||
}
|
||||
|
||||
func openFenceStartBefore(s string, pos int) (int, bool) {
|
||||
if pos <= 0 || pos > len(s) {
|
||||
return -1, false
|
||||
}
|
||||
segment := s[:pos]
|
||||
lastFence := strings.LastIndex(segment, "```")
|
||||
if lastFence < 0 {
|
||||
return -1, false
|
||||
}
|
||||
if strings.Count(segment, "```")%2 == 1 {
|
||||
return lastFence, true
|
||||
}
|
||||
return -1, false
|
||||
}
|
||||
|
||||
@@ -6,19 +6,21 @@ import (
|
||||
)
|
||||
|
||||
type toolStreamSieveState struct {
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
recentTextTail string
|
||||
pendingToolRaw string
|
||||
pendingToolCalls []toolcall.ParsedToolCall
|
||||
disableDeltas bool
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
toolArgsSent int
|
||||
toolArgsString bool
|
||||
toolArgsDone bool
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
codeFenceStack []int
|
||||
codeFencePendingTicks int
|
||||
codeFenceLineStart bool
|
||||
pendingToolRaw string
|
||||
pendingToolCalls []toolcall.ParsedToolCall
|
||||
disableDeltas bool
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
toolArgsSent int
|
||||
toolArgsString bool
|
||||
toolArgsDone bool
|
||||
}
|
||||
|
||||
type toolStreamEvent struct {
|
||||
@@ -33,9 +35,6 @@ type toolCallDelta struct {
|
||||
Arguments string
|
||||
}
|
||||
|
||||
// Keep in sync with JS TOOL_SIEVE_CONTEXT_TAIL_LIMIT.
|
||||
const toolSieveContextTailLimit = 2048
|
||||
|
||||
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
||||
s.disableDeltas = false
|
||||
s.toolNameSent = false
|
||||
@@ -47,19 +46,112 @@ func (s *toolStreamSieveState) resetIncrementalToolState() {
|
||||
}
|
||||
|
||||
func (s *toolStreamSieveState) noteText(content string) {
|
||||
if content == "" {
|
||||
if !hasMeaningfulText(content) {
|
||||
return
|
||||
}
|
||||
s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
|
||||
updateCodeFenceState(s, content)
|
||||
}
|
||||
|
||||
func appendTail(prev, next string, max int) string {
|
||||
if max <= 0 {
|
||||
return ""
|
||||
}
|
||||
combined := prev + next
|
||||
if len(combined) <= max {
|
||||
return combined
|
||||
}
|
||||
return combined[len(combined)-max:]
|
||||
func hasMeaningfulText(text string) bool {
|
||||
return strings.TrimSpace(text) != ""
|
||||
}
|
||||
|
||||
func insideCodeFenceWithState(state *toolStreamSieveState, text string) bool {
|
||||
if state == nil {
|
||||
return insideCodeFence(text)
|
||||
}
|
||||
simulated := simulateCodeFenceState(
|
||||
state.codeFenceStack,
|
||||
state.codeFencePendingTicks,
|
||||
state.codeFenceLineStart,
|
||||
text,
|
||||
)
|
||||
return len(simulated.stack) > 0
|
||||
}
|
||||
|
||||
func insideCodeFence(text string) bool {
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
return len(simulateCodeFenceState(nil, 0, true, text).stack) > 0
|
||||
}
|
||||
|
||||
func updateCodeFenceState(state *toolStreamSieveState, text string) {
|
||||
if state == nil || !hasMeaningfulText(text) {
|
||||
return
|
||||
}
|
||||
next := simulateCodeFenceState(
|
||||
state.codeFenceStack,
|
||||
state.codeFencePendingTicks,
|
||||
state.codeFenceLineStart,
|
||||
text,
|
||||
)
|
||||
state.codeFenceStack = next.stack
|
||||
state.codeFencePendingTicks = next.pendingTicks
|
||||
state.codeFenceLineStart = next.lineStart
|
||||
}
|
||||
|
||||
type codeFenceSimulation struct {
|
||||
stack []int
|
||||
pendingTicks int
|
||||
lineStart bool
|
||||
}
|
||||
|
||||
func simulateCodeFenceState(stack []int, pendingTicks int, lineStart bool, text string) codeFenceSimulation {
|
||||
chunk := text
|
||||
nextStack := append([]int(nil), stack...)
|
||||
ticks := pendingTicks
|
||||
atLineStart := lineStart
|
||||
|
||||
flushTicks := func() {
|
||||
if ticks > 0 {
|
||||
if atLineStart && ticks >= 3 {
|
||||
applyFenceMarker(&nextStack, ticks)
|
||||
}
|
||||
atLineStart = false
|
||||
ticks = 0
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(chunk); i++ {
|
||||
ch := chunk[i]
|
||||
if ch == '`' {
|
||||
ticks++
|
||||
continue
|
||||
}
|
||||
flushTicks()
|
||||
switch ch {
|
||||
case '\n', '\r':
|
||||
atLineStart = true
|
||||
case ' ', '\t':
|
||||
if atLineStart {
|
||||
continue
|
||||
}
|
||||
atLineStart = false
|
||||
default:
|
||||
atLineStart = false
|
||||
}
|
||||
}
|
||||
|
||||
return codeFenceSimulation{
|
||||
stack: nextStack,
|
||||
pendingTicks: ticks,
|
||||
lineStart: atLineStart,
|
||||
}
|
||||
}
|
||||
|
||||
func applyFenceMarker(stack *[]int, ticks int) {
|
||||
if stack == nil || ticks <= 0 {
|
||||
return
|
||||
}
|
||||
if len(*stack) == 0 {
|
||||
*stack = append(*stack, ticks)
|
||||
return
|
||||
}
|
||||
top := (*stack)[len(*stack)-1]
|
||||
if ticks >= top {
|
||||
*stack = (*stack)[:len(*stack)-1]
|
||||
return
|
||||
}
|
||||
*stack = append(*stack, ticks)
|
||||
}
|
||||
|
||||
@@ -26,8 +26,8 @@ var xmlToolCallTagPairs = []struct{ open, close string }{
|
||||
{"<invoke", "</invoke>"},
|
||||
{"<tool_use", "</tool_use>"},
|
||||
// Agent-style: these are XML "tool call" patterns from coding agents.
|
||||
// They get captured → parsed. If parsing fails, the block is consumed
|
||||
// (swallowed) to prevent raw XML from leaking to the client.
|
||||
// They get captured → parsed. If parsing fails, the raw XML is preserved
|
||||
// so the caller can still see the original text.
|
||||
{"<attempt_completion", "</attempt_completion>"},
|
||||
{"<ask_followup_question", "</ask_followup_question>"},
|
||||
{"<new_task", "</new_task>"},
|
||||
@@ -73,31 +73,12 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
|
||||
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
|
||||
return prefixPart, parsed, suffixPart, true
|
||||
}
|
||||
// If this block does not look like an executable tool-call payload,
|
||||
// pass it through as normal content (e.g. user-requested XML snippets).
|
||||
if !looksLikeExecutableXMLToolCallBlock(xmlBlock, pair.open) {
|
||||
return prefixPart + xmlBlock, nil, suffixPart, true
|
||||
}
|
||||
// Looks like XML tool syntax but failed to parse — consume it to avoid leak.
|
||||
return prefixPart, nil, suffixPart, true
|
||||
// If this block failed to become a tool call, pass it through as text.
|
||||
return prefixPart + xmlBlock, nil, suffixPart, true
|
||||
}
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
func looksLikeExecutableXMLToolCallBlock(xmlBlock, openTag string) bool {
|
||||
lower := strings.ToLower(xmlBlock)
|
||||
// Agent wrapper tags are always treated as internal tool-call wrappers.
|
||||
switch openTag {
|
||||
case "<attempt_completion", "<ask_followup_question", "<new_task":
|
||||
return true
|
||||
}
|
||||
return strings.Contains(lower, "<tool_name") ||
|
||||
strings.Contains(lower, "<parameters") ||
|
||||
strings.Contains(lower, `"tool"`) ||
|
||||
strings.Contains(lower, `"tool_name"`) ||
|
||||
strings.Contains(lower, `"name"`)
|
||||
}
|
||||
|
||||
// hasOpenXMLToolTag returns true if captured text contains an XML tool opening tag
|
||||
// whose SPECIFIC closing tag has not appeared yet.
|
||||
func hasOpenXMLToolTag(captured string) bool {
|
||||
@@ -137,32 +118,3 @@ func findPartialXMLToolTagStart(s string) int {
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// looksLikeXMLToolTagFragment returns true if s looks like a fragment from a
|
||||
// split XML tool call tag — for example "tool_calls>" or "/tool_call>\n".
|
||||
// These fragments arise when '<' was consumed separately and the tail remains.
|
||||
func looksLikeXMLToolTagFragment(s string) bool {
|
||||
trimmed := strings.TrimSpace(s)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
// Check for closing tag tails like "tool_calls>" or "/tool_calls>"
|
||||
fragments := []string{
|
||||
"tool_calls>", "tool_call>", "/tool_calls>", "/tool_call>",
|
||||
"function_calls>", "function_call>", "/function_calls>", "/function_call>",
|
||||
"invoke>", "/invoke>", "tool_use>", "/tool_use>",
|
||||
"tool_name>", "/tool_name>", "parameters>", "/parameters>",
|
||||
// Agent-style tag fragments
|
||||
"attempt_completion>", "/attempt_completion>",
|
||||
"ask_followup_question>", "/ask_followup_question>",
|
||||
"new_task>", "/new_task>",
|
||||
"result>", "/result>",
|
||||
}
|
||||
for _, f := range fragments {
|
||||
if strings.Contains(lower, f) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -42,6 +42,49 @@ func TestProcessToolSieveInterceptsXMLToolCallWithoutLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveHandlesLongXMLToolCall(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
const toolName = "write_to_file"
|
||||
payload := strings.Repeat("x", 4096)
|
||||
splitAt := len(payload) / 2
|
||||
chunks := []string{
|
||||
"<tool_calls>\n <tool_call>\n <tool_name>" + toolName + "</tool_name>\n <parameters>\n <content><![CDATA[",
|
||||
payload[:splitAt],
|
||||
payload[splitAt:],
|
||||
"]]></content>\n </parameters>\n </tool_call>\n</tool_calls>",
|
||||
}
|
||||
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
events = append(events, processToolSieveChunk(&state, c, []string{toolName})...)
|
||||
}
|
||||
events = append(events, flushToolSieve(&state, []string{toolName})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
var gotPayload any
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent.WriteString(evt.Content)
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 && gotPayload == nil {
|
||||
gotPayload = evt.ToolCalls[0].Input["content"]
|
||||
}
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if toolCalls != 1 {
|
||||
t.Fatalf("expected one long XML tool call, got %d events=%#v", toolCalls, events)
|
||||
}
|
||||
if textContent.Len() != 0 {
|
||||
t.Fatalf("expected no leaked text for long XML tool call, got %q", textContent.String())
|
||||
}
|
||||
got, _ := gotPayload.(string)
|
||||
if got != payload {
|
||||
t.Fatalf("expected long XML payload to survive intact, got len=%d want=%d", len(got), len(payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveXMLWithLeadingText(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Model outputs some prose then an XML tool call.
|
||||
@@ -121,6 +164,105 @@ func TestProcessToolSieveNonToolXMLKeepsSuffixForToolParsing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSievePassesThroughMalformedExecutableXMLBlock(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
chunk := `<tool_call><parameters>{"path":"README.md"}</parameters></tool_call>`
|
||||
events := processToolSieveChunk(&state, chunk, []string{"read_file"})
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
for _, evt := range events {
|
||||
textContent.WriteString(evt.Content)
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if toolCalls != 0 {
|
||||
t.Fatalf("expected malformed executable-looking XML to stay text, got %d events=%#v", toolCalls, events)
|
||||
}
|
||||
if textContent.String() != chunk {
|
||||
t.Fatalf("expected malformed executable-looking XML to pass through unchanged, got %q", textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSievePassesThroughFencedXMLToolCallExamples(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
input := strings.Join([]string{
|
||||
"Before first example.\n```",
|
||||
"xml\n<tool_call><tool_name>read_file</tool_name><parameters>{\"path\":\"README.md\"}</parameters></tool_call>\n```\n",
|
||||
"Between examples.\n```xml\n",
|
||||
"<tool_call><tool_name>search</tool_name><parameters>{\"q\":\"golang\"}</parameters></tool_call>\n",
|
||||
"```\nAfter examples.",
|
||||
}, "")
|
||||
|
||||
chunks := []string{
|
||||
"Before first example.\n```",
|
||||
"xml\n<tool_call><tool_name>read_file</tool_name><parameters>{\"path\":\"README.md\"}</parameters></tool_call>\n```\n",
|
||||
"Between examples.\n```xml\n",
|
||||
"<tool_call><tool_name>search</tool_name><parameters>{\"q\":\"golang\"}</parameters></tool_call>\n",
|
||||
"```\nAfter examples.",
|
||||
}
|
||||
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
events = append(events, processToolSieveChunk(&state, c, []string{"read_file", "search"})...)
|
||||
}
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file", "search"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent.WriteString(evt.Content)
|
||||
}
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if toolCalls != 0 {
|
||||
t.Fatalf("expected fenced XML examples to stay text, got %d tool calls events=%#v", toolCalls, events)
|
||||
}
|
||||
if textContent.String() != input {
|
||||
t.Fatalf("expected fenced XML examples to pass through unchanged, got %q", textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveKeepsPartialXMLTagInsideFencedExample(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
input := strings.Join([]string{
|
||||
"Example:\n```xml\n<tool_ca",
|
||||
"ll><tool_name>read_file</tool_name><parameters>{\"path\":\"README.md\"}</parameters></tool_call>\n```\n",
|
||||
"Done.",
|
||||
}, "")
|
||||
|
||||
chunks := []string{
|
||||
"Example:\n```xml\n<tool_ca",
|
||||
"ll><tool_name>read_file</tool_name><parameters>{\"path\":\"README.md\"}</parameters></tool_call>\n```\n",
|
||||
"Done.",
|
||||
}
|
||||
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
events = append(events, processToolSieveChunk(&state, c, []string{"read_file"})...)
|
||||
}
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent.WriteString(evt.Content)
|
||||
}
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if toolCalls != 0 {
|
||||
t.Fatalf("expected partial fenced XML to stay text, got %d tool calls events=%#v", toolCalls, events)
|
||||
}
|
||||
if textContent.String() != input {
|
||||
t.Fatalf("expected partial fenced XML to pass through unchanged, got %q", textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSievePartialXMLTagHeldBack(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Chunk ends with a partial XML tool tag.
|
||||
@@ -147,15 +289,16 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
|
||||
want int
|
||||
}{
|
||||
{"tool_calls_tag", "some text <tool_calls>\n", 10},
|
||||
{"gemini_function_call_json", `some text {"functionCall":{"name":"search","args":{"q":"latest"}}}`, 10},
|
||||
{"tool_call_tag", "prefix <tool_call>\n", 7},
|
||||
{"invoke_tag", "text <invoke name=\"foo\">body</invoke>", 5},
|
||||
{"xml_inside_code_fence", "```xml\n<tool_call><tool_name>read_file</tool_name></tool_call>\n```", -1},
|
||||
{"function_call_tag", "<function_call name=\"foo\">body</function_call>", 0},
|
||||
{"no_xml", "just plain text", -1},
|
||||
{"gemini_json_no_detect", `some text {"functionCall":{"name":"search"}}`, -1},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := findToolSegmentStart(tc.input)
|
||||
got := findToolSegmentStart(nil, tc.input)
|
||||
if got != tc.want {
|
||||
t.Fatalf("findToolSegmentStart(%q) = %d, want %d", tc.input, got, tc.want)
|
||||
}
|
||||
@@ -163,81 +306,6 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolSegmentStartIgnoresFunctionCallProse(t *testing.T) {
|
||||
input := "Please explain the functionCall API field and how clients should parse it."
|
||||
if got := findToolSegmentStart(input); got != -1 {
|
||||
t.Fatalf("expected no tool segment start for prose, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolSegmentStartDetectsQuotedFunctionCallKey(t *testing.T) {
|
||||
input := `prefix {"functionCall": {"name":"search_web","args":{"query":"x"}}}`
|
||||
want := strings.Index(input, "{")
|
||||
if got := findToolSegmentStart(input); got != want {
|
||||
t.Fatalf("expected JSON object start %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolSegmentStartDetectsLooseFunctionCallKey(t *testing.T) {
|
||||
input := `prefix {functionCall: {"name":"search_web","args":{"query":"x"}}}`
|
||||
want := strings.Index(input, "{")
|
||||
if got := findToolSegmentStart(input); got != want {
|
||||
t.Fatalf("expected JSON object start %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolSegmentStartPrefersQuotedFunctionCallOverEarlierBareProse(t *testing.T) {
|
||||
input := `prefix {note} functionCall: docs hint {"functionCall":{"name":"search_web","args":{"query":"x"}}}`
|
||||
want := strings.Index(input, `{"functionCall"`)
|
||||
if got := findToolSegmentStart(input); got != want {
|
||||
t.Fatalf("expected quoted functionCall JSON start %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolSegmentStartIgnoresLooseFunctionCallProse(t *testing.T) {
|
||||
input := "Please explain why functionCall: is used in documentation examples."
|
||||
if got := findToolSegmentStart(input); got != -1 {
|
||||
t.Fatalf("expected no tool segment start for prose, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveDoesNotBufferFunctionCallProse(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
chunk := "Please explain the functionCall API field and keep streaming this sentence."
|
||||
events := processToolSieveChunk(&state, chunk, []string{"search_web"})
|
||||
var text string
|
||||
for _, evt := range events {
|
||||
text += evt.Content
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
t.Fatalf("expected no tool calls for prose, got %#v", evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
if text != chunk {
|
||||
t.Fatalf("expected prose to pass through immediately, got %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveDetectsGeminiFunctionCallPayload(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
events := processToolSieveChunk(&state, `{"functionCall":{"name":"search_web","args":{"query":"latest"}}}`, []string{"search_web"})
|
||||
events = append(events, flushToolSieve(&state, []string{"search_web"})...)
|
||||
|
||||
var textContent string
|
||||
var toolCalls int
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent += evt.Content
|
||||
}
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
if toolCalls != 1 {
|
||||
t.Fatalf("expected one tool call from functionCall payload, got events=%#v", events)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(textContent), "functioncall") {
|
||||
t.Fatalf("functionCall json leaked into text content: %q", textContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPartialXMLToolTagStart(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
@@ -344,8 +412,8 @@ func TestProcessToolSieveTokenByTokenXMLNoLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Test that flushToolSieve on incomplete XML does NOT leak the raw XML content.
|
||||
func TestFlushToolSieveIncompleteXMLDoesNotLeak(t *testing.T) {
|
||||
// Test that flushToolSieve on incomplete XML falls back to raw text.
|
||||
func TestFlushToolSieveIncompleteXMLFallsBackToText(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// XML block starts but stream ends before completion.
|
||||
chunks := []string{
|
||||
@@ -367,8 +435,8 @@ func TestFlushToolSieveIncompleteXMLDoesNotLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(textContent, "<tool_call") {
|
||||
t.Fatalf("incomplete XML leaked on flush: %q", textContent)
|
||||
if textContent != strings.Join(chunks, "") {
|
||||
t.Fatalf("expected incomplete XML to fall back to raw text, got %q", textContent)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -405,10 +473,10 @@ func TestOpeningXMLTagNotLeakedAsContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveInterceptsAttemptCompletionLeak(t *testing.T) {
|
||||
func TestProcessToolSieveFallsBackToRawAttemptCompletion(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Simulate an agent outputting attempt_completion XML tag
|
||||
// which shouldn't leak to text output, even if it fails to parse as a valid tool.
|
||||
// Simulate an agent outputting attempt_completion XML tag.
|
||||
// If it does not parse as a tool call, it should fall back to raw text.
|
||||
chunks := []string{
|
||||
"Done with task.\n",
|
||||
"<attempt_completion>\n",
|
||||
@@ -432,7 +500,7 @@ func TestProcessToolSieveInterceptsAttemptCompletionLeak(t *testing.T) {
|
||||
t.Fatalf("expected leading text to be emitted, got %q", textContent)
|
||||
}
|
||||
|
||||
if strings.Contains(textContent, "<attempt_completion>") || strings.Contains(textContent, "result>") {
|
||||
t.Fatalf("agent XML tag content leaked to text: %q", textContent)
|
||||
if textContent != strings.Join(chunks, "") {
|
||||
t.Fatalf("expected agent XML to fall back to raw text, got %q", textContent)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,14 +2,26 @@ package openai
|
||||
|
||||
import "net/http"
|
||||
|
||||
func shouldWriteUpstreamEmptyOutputError(text string) bool {
|
||||
return text == ""
|
||||
}
|
||||
|
||||
func upstreamEmptyOutputDetail(contentFilter bool, text, thinking string) (int, string, string) {
|
||||
_ = text
|
||||
if contentFilter {
|
||||
return http.StatusBadRequest, "Upstream content filtered the response and returned no output.", "content_filter"
|
||||
}
|
||||
if thinking != "" {
|
||||
return http.StatusTooManyRequests, "Upstream model returned reasoning without visible output.", "upstream_empty_output"
|
||||
}
|
||||
return http.StatusTooManyRequests, "Upstream model returned empty output.", "upstream_empty_output"
|
||||
}
|
||||
|
||||
func writeUpstreamEmptyOutputError(w http.ResponseWriter, text string, contentFilter bool) bool {
|
||||
if text != "" {
|
||||
if !shouldWriteUpstreamEmptyOutputError(text) {
|
||||
return false
|
||||
}
|
||||
if contentFilter {
|
||||
writeOpenAIErrorWithCode(w, http.StatusBadRequest, "Upstream content filtered the response and returned no output.", "content_filter")
|
||||
return true
|
||||
}
|
||||
writeOpenAIErrorWithCode(w, http.StatusTooManyRequests, "Upstream model returned empty output.", "upstream_empty_output")
|
||||
status, message, code := upstreamEmptyOutputDetail(contentFilter, text, "")
|
||||
writeOpenAIErrorWithCode(w, status, message, code)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -33,6 +33,8 @@ type ConfigStore interface {
|
||||
RuntimeGlobalMaxInflight(defaultSize int) int
|
||||
RuntimeTokenRefreshIntervalHours() int
|
||||
AutoDeleteMode() string
|
||||
HistorySplitEnabled() bool
|
||||
HistorySplitTriggerAfterTurns() int
|
||||
CompatStripReferenceMarkers() bool
|
||||
AutoDeleteSessions() bool
|
||||
}
|
||||
|
||||
@@ -2,13 +2,16 @@ package admin
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/chathistory"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigStore
|
||||
Pool PoolController
|
||||
DS DeepSeekCaller
|
||||
OpenAI OpenAIChatCaller
|
||||
Store ConfigStore
|
||||
Pool PoolController
|
||||
DS DeepSeekCaller
|
||||
OpenAI OpenAIChatCaller
|
||||
ChatHistory *chathistory.Store
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
@@ -25,6 +28,7 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
pr.Post("/config/import", h.configImport)
|
||||
pr.Get("/config/export", h.configExport)
|
||||
pr.Post("/keys", h.addKey)
|
||||
pr.Put("/keys/{key}", h.updateKey)
|
||||
pr.Delete("/keys/{key}", h.deleteKey)
|
||||
pr.Get("/proxies", h.listProxies)
|
||||
pr.Post("/proxies", h.addProxy)
|
||||
@@ -33,6 +37,7 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
pr.Post("/proxies/test", h.testProxy)
|
||||
pr.Get("/accounts", h.listAccounts)
|
||||
pr.Post("/accounts", h.addAccount)
|
||||
pr.Put("/accounts/{identifier}", h.updateAccount)
|
||||
pr.Delete("/accounts/{identifier}", h.deleteAccount)
|
||||
pr.Put("/accounts/{identifier}/proxy", h.updateAccountProxy)
|
||||
pr.Get("/queue/status", h.queueStatus)
|
||||
@@ -50,6 +55,11 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
pr.Get("/export", h.exportConfig)
|
||||
pr.Get("/dev/captures", h.getDevCaptures)
|
||||
pr.Delete("/dev/captures", h.clearDevCaptures)
|
||||
pr.Get("/chat-history", h.getChatHistory)
|
||||
pr.Get("/chat-history/{id}", h.getChatHistoryItem)
|
||||
pr.Delete("/chat-history", h.clearChatHistory)
|
||||
pr.Delete("/chat-history/{id}", h.deleteChatHistoryItem)
|
||||
pr.Put("/chat-history/settings", h.updateChatHistorySettings)
|
||||
pr.Get("/version", h.getVersion)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -21,8 +21,8 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
if pageSize < 1 {
|
||||
pageSize = 1
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
if pageSize > 5000 {
|
||||
pageSize = 5000
|
||||
}
|
||||
accounts := h.Store.Snapshot().Accounts
|
||||
reverseAccounts(accounts)
|
||||
@@ -32,6 +32,8 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
for _, acc := range accounts {
|
||||
id := strings.ToLower(acc.Identifier())
|
||||
if strings.Contains(id, q) ||
|
||||
strings.Contains(strings.ToLower(acc.Name), q) ||
|
||||
strings.Contains(strings.ToLower(acc.Remark), q) ||
|
||||
strings.Contains(strings.ToLower(acc.Email), q) ||
|
||||
strings.Contains(strings.ToLower(acc.Mobile), q) {
|
||||
filtered = append(filtered, acc)
|
||||
@@ -66,6 +68,8 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
items = append(items, map[string]any{
|
||||
"identifier": acc.Identifier(),
|
||||
"name": acc.Name,
|
||||
"remark": acc.Remark,
|
||||
"email": acc.Email,
|
||||
"mobile": acc.Mobile,
|
||||
"proxy_id": acc.ProxyID,
|
||||
@@ -112,6 +116,46 @@ func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||
}
|
||||
|
||||
func (h *Handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
identifier := chi.URLParam(r, "identifier")
|
||||
if decoded, err := url.PathUnescape(identifier); err == nil {
|
||||
identifier = decoded
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
name, nameOK := fieldStringOptional(req, "name")
|
||||
remark, remarkOK := fieldStringOptional(req, "remark")
|
||||
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
for i, acc := range c.Accounts {
|
||||
if !accountMatchesIdentifier(acc, identifier) {
|
||||
continue
|
||||
}
|
||||
if nameOK {
|
||||
c.Accounts[i].Name = name
|
||||
}
|
||||
if remarkOK {
|
||||
c.Accounts[i].Remark = remark
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return newRequestError("账号不存在")
|
||||
})
|
||||
if err != nil {
|
||||
if detail, ok := requestErrorDetail(err); ok {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": detail})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
identifier := chi.URLParam(r, "identifier")
|
||||
if decoded, err := url.PathUnescape(identifier); err == nil {
|
||||
|
||||
88
internal/admin/handler_accounts_crud_test.go
Normal file
88
internal/admin/handler_accounts_crud_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func TestListAccountsPageSizeCapIs5000(t *testing.T) {
|
||||
accounts := make([]string, 0, 150)
|
||||
for i := range 150 {
|
||||
accounts = append(accounts, fmt.Sprintf(`{"email":"u%d@example.com","password":"pwd"}`, i))
|
||||
}
|
||||
raw := fmt.Sprintf(`{"accounts":[%s]}`, strings.Join(accounts, ","))
|
||||
router := newHTTPAdminHarness(t, raw, &testingDSMock{})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, adminReq(http.MethodGet, "/accounts?page=1&page_size=200", nil))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
items, _ := payload["items"].([]any)
|
||||
if len(items) != 150 {
|
||||
t.Fatalf("expected all 150 accounts with page_size=200, got %d", len(items))
|
||||
}
|
||||
if ps, _ := payload["page_size"].(float64); ps != 200 {
|
||||
t.Fatalf("expected page_size=200 in response, got %v", payload["page_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAccountsPageSizeAbove5000ClampedTo5000(t *testing.T) {
|
||||
router := newHTTPAdminHarness(t, `{"accounts":[{"email":"u@example.com","password":"pwd"}]}`, &testingDSMock{})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, adminReq(http.MethodGet, "/accounts?page=1&page_size=9999", nil))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if ps, _ := payload["page_size"].(float64); ps != 5000 {
|
||||
t.Fatalf("expected page_size clamped to 5000, got %v", payload["page_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAccountMetadataPreservesCredentials(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[{"email":"u@example.com","name":"old name","remark":"old remark","password":"secret"}]
|
||||
}`)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Put("/admin/accounts/{identifier}", h.updateAccount)
|
||||
|
||||
body := []byte(`{"name":"new name","remark":"new remark"}`)
|
||||
req := httptest.NewRequest(http.MethodPut, "/admin/accounts/u@example.com", strings.NewReader(string(body)))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.Accounts) != 1 {
|
||||
t.Fatalf("unexpected accounts after update: %#v", snap.Accounts)
|
||||
}
|
||||
acc := snap.Accounts[0]
|
||||
if acc.Email != "u@example.com" {
|
||||
t.Fatalf("identifier changed unexpectedly: %#v", acc)
|
||||
}
|
||||
if acc.Name != "new name" || acc.Remark != "new remark" {
|
||||
t.Fatalf("metadata update did not persist: %#v", acc)
|
||||
}
|
||||
if acc.Password != "secret" {
|
||||
t.Fatalf("password should be preserved, got %#v", acc)
|
||||
}
|
||||
}
|
||||
134
internal/admin/handler_chat_history.go
Normal file
134
internal/admin/handler_chat_history.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/chathistory"
|
||||
)
|
||||
|
||||
func (h *Handler) getChatHistory(w http.ResponseWriter, r *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{
|
||||
"detail": err.Error(),
|
||||
"path": store.Path(),
|
||||
})
|
||||
return
|
||||
}
|
||||
etag := chathistory.ListETag(snapshot.Revision)
|
||||
w.Header().Set("ETag", etag)
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
if strings.TrimSpace(r.Header.Get("If-None-Match")) == etag {
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"version": snapshot.Version,
|
||||
"limit": snapshot.Limit,
|
||||
"revision": snapshot.Revision,
|
||||
"items": snapshot.Items,
|
||||
"path": store.Path(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) getChatHistoryItem(w http.ResponseWriter, r *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(chi.URLParam(r, "id"))
|
||||
if id == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "history id is required"})
|
||||
return
|
||||
}
|
||||
item, err := store.Get(id)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(strings.ToLower(err.Error()), "not found") {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
writeJSON(w, status, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
etag := chathistory.DetailETag(item.ID, item.Revision)
|
||||
w.Header().Set("ETag", etag)
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
if strings.TrimSpace(r.Header.Get("If-None-Match")) == etag {
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"item": item,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) clearChatHistory(w http.ResponseWriter, _ *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
if err := store.Clear(); err != nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": err.Error(), "path": store.Path()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteChatHistoryItem(w http.ResponseWriter, r *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(chi.URLParam(r, "id"))
|
||||
if id == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "history id is required"})
|
||||
return
|
||||
}
|
||||
if err := store.Delete(id); err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(strings.ToLower(err.Error()), "not found") {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
writeJSON(w, status, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true})
|
||||
}
|
||||
|
||||
func (h *Handler) updateChatHistorySettings(w http.ResponseWriter, r *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
var body struct {
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
snapshot, err := store.SetLimit(body.Limit)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"limit": snapshot.Limit,
|
||||
"revision": snapshot.Revision,
|
||||
"items": snapshot.Items,
|
||||
})
|
||||
}
|
||||
176
internal/admin/handler_chat_history_test.go
Normal file
176
internal/admin/handler_chat_history_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func newChatHistoryAdminHarness(t *testing.T) (*Handler, *chathistory.Store) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatalf("write config failed: %v", err)
|
||||
}
|
||||
t.Setenv("DS2API_CONFIG_PATH", configPath)
|
||||
t.Setenv("DS2API_ADMIN_KEY", "admin")
|
||||
t.Setenv("DS2API_CONFIG_JSON", "")
|
||||
store, err := config.LoadStoreWithError()
|
||||
if err != nil {
|
||||
t.Fatalf("load config store failed: %v", err)
|
||||
}
|
||||
historyStore := chathistory.New(filepath.Join(dir, "chat_history.json"))
|
||||
return &Handler{Store: store, ChatHistory: historyStore}, historyStore
|
||||
}
|
||||
|
||||
func TestGetChatHistoryAndUpdateSettings(t *testing.T) {
|
||||
h, historyStore := newChatHistoryAdminHarness(t)
|
||||
entry, err := historyStore.Start(chathistory.StartParams{
|
||||
CallerID: "caller:test",
|
||||
AccountID: "user@example.com",
|
||||
Model: "deepseek-chat",
|
||||
UserInput: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start history failed: %v", err)
|
||||
}
|
||||
if _, err := historyStore.Update(entry.ID, chathistory.UpdateParams{
|
||||
Status: "success",
|
||||
Content: "world",
|
||||
Completed: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("update history failed: %v", err)
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/chat-history", nil)
|
||||
req.Header.Set("Authorization", "Bearer admin")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("decode payload failed: %v", err)
|
||||
}
|
||||
items, _ := payload["items"].([]any)
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("expected one history item, got %#v", payload)
|
||||
}
|
||||
if rec.Header().Get("ETag") == "" {
|
||||
t.Fatalf("expected list etag header")
|
||||
}
|
||||
|
||||
notModifiedReq := httptest.NewRequest(http.MethodGet, "/chat-history", nil)
|
||||
notModifiedReq.Header.Set("Authorization", "Bearer admin")
|
||||
notModifiedReq.Header.Set("If-None-Match", rec.Header().Get("ETag"))
|
||||
notModifiedRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(notModifiedRec, notModifiedReq)
|
||||
if notModifiedRec.Code != http.StatusNotModified {
|
||||
t.Fatalf("expected 304, got %d body=%s", notModifiedRec.Code, notModifiedRec.Body.String())
|
||||
}
|
||||
|
||||
itemReq := httptest.NewRequest(http.MethodGet, "/chat-history/"+entry.ID, nil)
|
||||
itemReq.Header.Set("Authorization", "Bearer admin")
|
||||
itemRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(itemRec, itemReq)
|
||||
if itemRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected item 200, got %d body=%s", itemRec.Code, itemRec.Body.String())
|
||||
}
|
||||
if itemRec.Header().Get("ETag") == "" {
|
||||
t.Fatalf("expected detail etag header")
|
||||
}
|
||||
|
||||
updateReq := httptest.NewRequest(http.MethodPut, "/chat-history/settings", bytes.NewReader([]byte(`{"limit":10}`)))
|
||||
updateReq.Header.Set("Authorization", "Bearer admin")
|
||||
updateRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(updateRec, updateReq)
|
||||
if updateRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 from settings update, got %d body=%s", updateRec.Code, updateRec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if snapshot.Limit != 10 {
|
||||
t.Fatalf("expected limit=10, got %d", snapshot.Limit)
|
||||
}
|
||||
|
||||
disableReq := httptest.NewRequest(http.MethodPut, "/chat-history/settings", bytes.NewReader([]byte(`{"limit":0}`)))
|
||||
disableReq.Header.Set("Authorization", "Bearer admin")
|
||||
disableRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(disableRec, disableReq)
|
||||
if disableRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 from disable update, got %d body=%s", disableRec.Code, disableRec.Body.String())
|
||||
}
|
||||
snapshot, err = historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot after disable failed: %v", err)
|
||||
}
|
||||
if snapshot.Limit != chathistory.DisabledLimit {
|
||||
t.Fatalf("expected limit=0, got %d", snapshot.Limit)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected history preserved when disabled, got %d", len(snapshot.Items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAndClearChatHistory(t *testing.T) {
|
||||
h, historyStore := newChatHistoryAdminHarness(t)
|
||||
entryA, err := historyStore.Start(chathistory.StartParams{UserInput: "a"})
|
||||
if err != nil {
|
||||
t.Fatalf("start A failed: %v", err)
|
||||
}
|
||||
if _, err := historyStore.Start(chathistory.StartParams{UserInput: "b"}); err != nil {
|
||||
t.Fatalf("start B failed: %v", err)
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, "/chat-history/"+entryA.ID, nil)
|
||||
deleteReq.Header.Set("Authorization", "Bearer admin")
|
||||
deleteRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(deleteRec, deleteReq)
|
||||
if deleteRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected delete 200, got %d body=%s", deleteRec.Code, deleteRec.Body.String())
|
||||
}
|
||||
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one item after delete, got %d", len(snapshot.Items))
|
||||
}
|
||||
|
||||
clearReq := httptest.NewRequest(http.MethodDelete, "/chat-history", nil)
|
||||
clearReq.Header.Set("Authorization", "Bearer admin")
|
||||
clearRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(clearRec, clearReq)
|
||||
if clearRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected clear 200, got %d body=%s", clearRec.Code, clearRec.Body.String())
|
||||
}
|
||||
|
||||
snapshot, err = historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 0 {
|
||||
t.Fatalf("expected empty items after clear, got %d", len(snapshot.Items))
|
||||
}
|
||||
}
|
||||
@@ -53,25 +53,12 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
next.Accounts = normalizeAndDedupeAccounts(next.Accounts)
|
||||
next.VercelSyncHash = c.VercelSyncHash
|
||||
next.VercelSyncTime = c.VercelSyncTime
|
||||
importedKeys = len(next.Keys)
|
||||
importedKeys = len(next.APIKeys)
|
||||
importedAccounts = len(next.Accounts)
|
||||
} else {
|
||||
existingKeys := map[string]struct{}{}
|
||||
for _, k := range next.Keys {
|
||||
existingKeys[k] = struct{}{}
|
||||
}
|
||||
for _, k := range incoming.Keys {
|
||||
key := strings.TrimSpace(k)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := existingKeys[key]; ok {
|
||||
continue
|
||||
}
|
||||
existingKeys[key] = struct{}{}
|
||||
next.Keys = append(next.Keys, key)
|
||||
importedKeys++
|
||||
}
|
||||
var changed int
|
||||
next.APIKeys, changed = mergeAPIKeysPreferStructured(next.APIKeys, incoming.APIKeys)
|
||||
importedKeys += changed
|
||||
|
||||
existingAccounts := map[string]struct{}{}
|
||||
for _, acc := range next.Accounts {
|
||||
|
||||
@@ -11,6 +11,7 @@ func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
safe := map[string]any{
|
||||
"keys": snap.Keys,
|
||||
"api_keys": snap.APIKeys,
|
||||
"accounts": []map[string]any{},
|
||||
"proxies": []map[string]any{},
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
@@ -37,6 +38,8 @@ func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
}
|
||||
accounts = append(accounts, map[string]any{
|
||||
"identifier": acc.Identifier(),
|
||||
"name": acc.Name,
|
||||
"remark": acc.Remark,
|
||||
"email": acc.Email,
|
||||
"mobile": acc.Mobile,
|
||||
"proxy_id": acc.ProxyID,
|
||||
|
||||
@@ -19,7 +19,9 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
old := h.Store.Snapshot()
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
if keys, ok := toStringSlice(req["keys"]); ok {
|
||||
if apiKeys, ok := toAPIKeys(req["api_keys"]); ok {
|
||||
c.APIKeys = apiKeys
|
||||
} else if keys, ok := toStringSlice(req["keys"]); ok {
|
||||
c.Keys = keys
|
||||
}
|
||||
if accountsRaw, ok := req["accounts"].([]any); ok {
|
||||
@@ -78,17 +80,19 @@ func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
key, _ := req["key"].(string)
|
||||
key = strings.TrimSpace(key)
|
||||
name := fieldString(req, "name")
|
||||
remark := fieldString(req, "remark")
|
||||
if key == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"})
|
||||
return
|
||||
}
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
for _, k := range c.Keys {
|
||||
if k == key {
|
||||
for _, item := range c.APIKeys {
|
||||
if item.Key == key {
|
||||
return fmt.Errorf("key 已存在")
|
||||
}
|
||||
}
|
||||
c.Keys = append(c.Keys, key)
|
||||
c.APIKeys = append(c.APIKeys, config.APIKey{Key: key, Name: name, Remark: remark})
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -98,12 +102,25 @@ func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
|
||||
key := chi.URLParam(r, "key")
|
||||
func (h *Handler) updateKey(w http.ResponseWriter, r *http.Request) {
|
||||
key := strings.TrimSpace(chi.URLParam(r, "key"))
|
||||
if key == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "key 不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
name, nameOK := fieldStringOptional(req, "name")
|
||||
remark, remarkOK := fieldStringOptional(req, "remark")
|
||||
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
idx := -1
|
||||
for i, k := range c.Keys {
|
||||
if k == key {
|
||||
for i, item := range c.APIKeys {
|
||||
if item.Key == key {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
@@ -111,7 +128,35 @@ func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("key 不存在")
|
||||
}
|
||||
c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...)
|
||||
if nameOK {
|
||||
c.APIKeys[idx].Name = name
|
||||
}
|
||||
if remarkOK {
|
||||
c.APIKeys[idx].Remark = remark
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
|
||||
key := chi.URLParam(r, "key")
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
idx := -1
|
||||
for i, item := range c.APIKeys {
|
||||
if item.Key == key {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("key 不存在")
|
||||
}
|
||||
c.APIKeys = append(c.APIKeys[:idx], c.APIKeys[idx+1:]...)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -129,20 +174,23 @@ func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
importedKeys, importedAccounts := 0, 0
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
if apiKeys, ok := toAPIKeys(req["api_keys"]); ok {
|
||||
var changed int
|
||||
c.APIKeys, changed = mergeAPIKeysPreferStructured(c.APIKeys, apiKeys)
|
||||
importedKeys += changed
|
||||
}
|
||||
if keys, ok := req["keys"].([]any); ok {
|
||||
existing := map[string]bool{}
|
||||
for _, k := range c.Keys {
|
||||
existing[k] = true
|
||||
}
|
||||
legacy := make([]config.APIKey, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
key := strings.TrimSpace(fmt.Sprintf("%v", k))
|
||||
if key == "" || existing[key] {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
c.Keys = append(c.Keys, key)
|
||||
existing[key] = true
|
||||
importedKeys++
|
||||
legacy = append(legacy, config.APIKey{Key: key})
|
||||
}
|
||||
var changed int
|
||||
c.APIKeys, changed = mergeAPIKeysPreferStructured(c.APIKeys, legacy)
|
||||
importedKeys += changed
|
||||
}
|
||||
if accounts, ok := req["accounts"].([]any); ok {
|
||||
existing := map[string]bool{}
|
||||
|
||||
76
internal/admin/handler_keys_test.go
Normal file
76
internal/admin/handler_keys_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func TestKeyEndpointsPreserveStructuredMetadata(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"api_keys":[{"key":"k1","name":"primary","remark":"prod"}]
|
||||
}`)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Post("/admin/keys", h.addKey)
|
||||
r.Put("/admin/keys/{key}", h.updateKey)
|
||||
r.Delete("/admin/keys/{key}", h.deleteKey)
|
||||
|
||||
addBody := []byte(`{"key":"k2","name":"secondary","remark":"staging"}`)
|
||||
addReq := httptest.NewRequest(http.MethodPost, "/admin/keys", bytes.NewReader(addBody))
|
||||
addRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(addRec, addReq)
|
||||
if addRec.Code != http.StatusOK {
|
||||
t.Fatalf("add status=%d body=%s", addRec.Code, addRec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after add: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("existing metadata was lost after add: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Name != "secondary" || snap.APIKeys[1].Remark != "staging" {
|
||||
t.Fatalf("new metadata was lost after add: %#v", snap.APIKeys[1])
|
||||
}
|
||||
|
||||
updateBody := map[string]any{
|
||||
"name": "primary-updated",
|
||||
"remark": "prod-updated",
|
||||
}
|
||||
updateBytes, _ := json.Marshal(updateBody)
|
||||
updateReq := httptest.NewRequest(http.MethodPut, "/admin/keys/k1", bytes.NewReader(updateBytes))
|
||||
updateRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(updateRec, updateReq)
|
||||
if updateRec.Code != http.StatusOK {
|
||||
t.Fatalf("update status=%d body=%s", updateRec.Code, updateRec.Body.String())
|
||||
}
|
||||
|
||||
snap = h.Store.Snapshot()
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after update: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Key != "k1" || snap.APIKeys[0].Name != "primary-updated" || snap.APIKeys[0].Remark != "prod-updated" {
|
||||
t.Fatalf("metadata update did not persist: %#v", snap.APIKeys[0])
|
||||
}
|
||||
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, "/admin/keys/k1", nil)
|
||||
deleteRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(deleteRec, deleteReq)
|
||||
if deleteRec.Code != http.StatusOK {
|
||||
t.Fatalf("delete status=%d body=%s", deleteRec.Code, deleteRec.Body.String())
|
||||
}
|
||||
|
||||
snap = h.Store.Snapshot()
|
||||
if len(snap.APIKeys) != 1 || snap.APIKeys[0].Key != "k2" {
|
||||
t.Fatalf("unexpected api keys after delete: %#v", snap.APIKeys)
|
||||
}
|
||||
if len(snap.Keys) != 1 || snap.Keys[0] != "k2" {
|
||||
t.Fatalf("unexpected legacy keys after delete: %#v", snap.Keys)
|
||||
}
|
||||
}
|
||||
@@ -21,16 +21,17 @@ func boolFrom(v any) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.CompatConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, *config.AutoDeleteConfig, map[string]string, map[string]string, error) {
|
||||
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.CompatConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, *config.AutoDeleteConfig, *config.HistorySplitConfig, map[string]string, map[string]string, error) {
|
||||
var (
|
||||
adminCfg *config.AdminConfig
|
||||
runtimeCfg *config.RuntimeConfig
|
||||
compatCfg *config.CompatConfig
|
||||
respCfg *config.ResponsesConfig
|
||||
embCfg *config.EmbeddingsConfig
|
||||
autoDeleteCfg *config.AutoDeleteConfig
|
||||
claudeMap map[string]string
|
||||
aliasMap map[string]string
|
||||
adminCfg *config.AdminConfig
|
||||
runtimeCfg *config.RuntimeConfig
|
||||
compatCfg *config.CompatConfig
|
||||
respCfg *config.ResponsesConfig
|
||||
embCfg *config.EmbeddingsConfig
|
||||
autoDeleteCfg *config.AutoDeleteConfig
|
||||
historySplitCfg *config.HistorySplitConfig
|
||||
claudeMap map[string]string
|
||||
aliasMap map[string]string
|
||||
)
|
||||
|
||||
if raw, ok := req["admin"].(map[string]any); ok {
|
||||
@@ -38,7 +39,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["jwt_expire_hours"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("admin.jwt_expire_hours", n, 1, 720, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.JWTExpireHours = n
|
||||
}
|
||||
@@ -50,33 +51,33 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["account_max_inflight"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("runtime.account_max_inflight", n, 1, 256, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.AccountMaxInflight = n
|
||||
}
|
||||
if v, exists := raw["account_max_queue"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("runtime.account_max_queue", n, 1, 200000, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.AccountMaxQueue = n
|
||||
}
|
||||
if v, exists := raw["global_max_inflight"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("runtime.global_max_inflight", n, 1, 200000, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.GlobalMaxInflight = n
|
||||
}
|
||||
if v, exists := raw["token_refresh_interval_hours"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("runtime.token_refresh_interval_hours", n, 1, 720, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.TokenRefreshIntervalHours = n
|
||||
}
|
||||
if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
|
||||
}
|
||||
runtimeCfg = cfg
|
||||
}
|
||||
@@ -99,7 +100,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["store_ttl_seconds"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("responses.store_ttl_seconds", n, 30, 86400, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.StoreTTLSeconds = n
|
||||
}
|
||||
@@ -111,7 +112,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["provider"]; exists {
|
||||
p := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if err := config.ValidateTrimmedString("embeddings.provider", p, false); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.Provider = p
|
||||
}
|
||||
@@ -147,7 +148,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["mode"]; exists {
|
||||
mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
if err := config.ValidateAutoDeleteMode(mode); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
if mode == "" {
|
||||
mode = "none"
|
||||
@@ -160,5 +161,24 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
autoDeleteCfg = cfg
|
||||
}
|
||||
|
||||
return adminCfg, runtimeCfg, compatCfg, respCfg, embCfg, autoDeleteCfg, claudeMap, aliasMap, nil
|
||||
if raw, ok := req["history_split"].(map[string]any); ok {
|
||||
cfg := &config.HistorySplitConfig{}
|
||||
if v, exists := raw["enabled"]; exists {
|
||||
b := boolFrom(v)
|
||||
cfg.Enabled = &b
|
||||
}
|
||||
if v, exists := raw["trigger_after_turns"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("history_split.trigger_after_turns", n, 1, 1000, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.TriggerAfterTurns = &n
|
||||
}
|
||||
if err := config.ValidateHistorySplitConfig(*cfg); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
historySplitCfg = cfg
|
||||
}
|
||||
|
||||
return adminCfg, runtimeCfg, compatCfg, respCfg, embCfg, autoDeleteCfg, historySplitCfg, claudeMap, aliasMap, nil
|
||||
}
|
||||
|
||||
@@ -26,10 +26,14 @@ func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
|
||||
"global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended),
|
||||
"token_refresh_interval_hours": h.Store.RuntimeTokenRefreshIntervalHours(),
|
||||
},
|
||||
"compat": snap.Compat,
|
||||
"responses": snap.Responses,
|
||||
"embeddings": snap.Embeddings,
|
||||
"auto_delete": snap.AutoDelete,
|
||||
"compat": snap.Compat,
|
||||
"responses": snap.Responses,
|
||||
"embeddings": snap.Embeddings,
|
||||
"auto_delete": snap.AutoDelete,
|
||||
"history_split": map[string]any{
|
||||
"enabled": h.Store.HistorySplitEnabled(),
|
||||
"trigger_after_turns": h.Store.HistorySplitTriggerAfterTurns(),
|
||||
},
|
||||
"claude_mapping": settingsClaudeMapping(snap),
|
||||
"model_aliases": snap.ModelAliases,
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
|
||||
@@ -47,6 +47,25 @@ func TestGetSettingsIncludesTokenRefreshInterval(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSettingsIncludesHistorySplitDefaults(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/settings", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
h.getSettings(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var body map[string]any
|
||||
_ = json.Unmarshal(rec.Body.Bytes(), &body)
|
||||
historySplit, _ := body["history_split"].(map[string]any)
|
||||
if got := boolFrom(historySplit["enabled"]); !got {
|
||||
t.Fatalf("expected history_split.enabled=true, body=%v", body)
|
||||
}
|
||||
if got := intFrom(historySplit["trigger_after_turns"]); got != 1 {
|
||||
t.Fatalf("expected history_split.trigger_after_turns=1, got %d body=%v", got, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsValidation(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
payload := map[string]any{
|
||||
@@ -154,6 +173,30 @@ func TestUpdateSettingsWithoutRuntimeSkipsMergedRuntimeValidation(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsHistorySplit(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
payload := map[string]any{
|
||||
"history_split": map[string]any{
|
||||
"enabled": false,
|
||||
"trigger_after_turns": 3,
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateSettings(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
snap := h.Store.Snapshot()
|
||||
if snap.HistorySplit.Enabled == nil || *snap.HistorySplit.Enabled {
|
||||
t.Fatalf("expected history_split.enabled=false, got %#v", snap.HistorySplit.Enabled)
|
||||
}
|
||||
if snap.HistorySplit.TriggerAfterTurns == nil || *snap.HistorySplit.TriggerAfterTurns != 3 {
|
||||
t.Fatalf("expected history_split.trigger_after_turns=3, got %#v", snap.HistorySplit.TriggerAfterTurns)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsAutoDeleteMode(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"],"auto_delete":{"sessions":true}}`)
|
||||
|
||||
@@ -234,6 +277,75 @@ func TestUpdateSettingsHotReloadTokenRefreshInterval(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConfigPreservesStructuredAPIKeysWhenBothFieldsPresent(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["legacy"],
|
||||
"api_keys":[{"key":"legacy","name":"primary","remark":"prod"}],
|
||||
"accounts":[]
|
||||
}`)
|
||||
|
||||
payload := map[string]any{
|
||||
"keys": []any{"legacy", "new-key"},
|
||||
"api_keys": []any{
|
||||
map[string]any{"key": "legacy", "name": "primary-updated", "remark": "prod-updated"},
|
||||
map[string]any{"key": "new-key", "name": "secondary", "remark": "staging"},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPost, "/admin/config", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateConfig(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.Keys) != 2 || snap.Keys[0] != "legacy" || snap.Keys[1] != "new-key" {
|
||||
t.Fatalf("unexpected keys after config update: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after config update: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary-updated" || snap.APIKeys[0].Remark != "prod-updated" {
|
||||
t.Fatalf("structured metadata for existing key was not preserved: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Name != "secondary" || snap.APIKeys[1].Remark != "staging" {
|
||||
t.Fatalf("structured metadata for new key was not preserved: %#v", snap.APIKeys[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConfigLegacyKeysPreserveStructuredMetadata(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"api_keys":[{"key":"legacy","name":"primary","remark":"prod"}],
|
||||
"accounts":[]
|
||||
}`)
|
||||
|
||||
payload := map[string]any{
|
||||
"keys": []any{"legacy", "new-key"},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPost, "/admin/config", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateConfig(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.Keys) != 2 || snap.Keys[0] != "legacy" || snap.Keys[1] != "new-key" {
|
||||
t.Fatalf("unexpected keys after legacy config update: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after legacy config update: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("existing structured metadata was lost: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Key != "new-key" || snap.APIKeys[1].Name != "" || snap.APIKeys[1].Remark != "" {
|
||||
t.Fatalf("new legacy key should remain metadata-free: %#v", snap.APIKeys[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsPasswordInvalidatesOldJWT(t *testing.T) {
|
||||
hash := authn.HashAdminPassword("old-password")
|
||||
h := newAdminTestHandler(t, `{"admin":{"password_hash":"`+hash+`"}}`)
|
||||
@@ -315,6 +427,113 @@ func TestConfigImportMergeAndReplace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigImportMergePreservesStructuredAPIKeys(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"api_keys":[{"key":"k1","name":"primary","remark":"prod"}]
|
||||
}`)
|
||||
|
||||
merge := map[string]any{
|
||||
"mode": "merge",
|
||||
"config": map[string]any{
|
||||
"api_keys": []any{
|
||||
map[string]any{"key": "k1", "name": "should-not-overwrite", "remark": "ignored"},
|
||||
map[string]any{"key": "k2", "name": "secondary", "remark": "staging"},
|
||||
},
|
||||
},
|
||||
}
|
||||
mergeBytes, _ := json.Marshal(merge)
|
||||
mergeReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(mergeBytes))
|
||||
mergeRec := httptest.NewRecorder()
|
||||
h.configImport(mergeRec, mergeReq)
|
||||
if mergeRec.Code != http.StatusOK {
|
||||
t.Fatalf("merge status=%d body=%s", mergeRec.Code, mergeRec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after structured merge: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("existing structured metadata was overwritten: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Name != "secondary" || snap.APIKeys[1].Remark != "staging" {
|
||||
t.Fatalf("new structured metadata was lost: %#v", snap.APIKeys[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigImportMergeUpgradesLegacyAPIKeys(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["legacy"],
|
||||
"accounts":[]
|
||||
}`)
|
||||
|
||||
merge := map[string]any{
|
||||
"mode": "merge",
|
||||
"config": map[string]any{
|
||||
"api_keys": []any{
|
||||
map[string]any{"key": "legacy", "name": "primary", "remark": "prod"},
|
||||
map[string]any{"key": "new-key", "name": "secondary", "remark": "staging"},
|
||||
},
|
||||
},
|
||||
}
|
||||
mergeBytes, _ := json.Marshal(merge)
|
||||
mergeReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(mergeBytes))
|
||||
mergeRec := httptest.NewRecorder()
|
||||
h.configImport(mergeRec, mergeReq)
|
||||
if mergeRec.Code != http.StatusOK {
|
||||
t.Fatalf("merge status=%d body=%s", mergeRec.Code, mergeRec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.Keys) != 2 || snap.Keys[0] != "legacy" || snap.Keys[1] != "new-key" {
|
||||
t.Fatalf("unexpected keys after legacy import merge: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after legacy import merge: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("legacy key metadata was not upgraded: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Name != "secondary" || snap.APIKeys[1].Remark != "staging" {
|
||||
t.Fatalf("new structured metadata was not preserved: %#v", snap.APIKeys[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchImportUpgradesLegacyAPIKeys(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["legacy"],
|
||||
"accounts":[]
|
||||
}`)
|
||||
|
||||
payload := map[string]any{
|
||||
"keys": []any{"legacy", "new-key"},
|
||||
"api_keys": []any{
|
||||
map[string]any{"key": "legacy", "name": "primary", "remark": "prod"},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPost, "/admin/import", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.batchImport(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.Keys) != 2 || snap.Keys[0] != "legacy" || snap.Keys[1] != "new-key" {
|
||||
t.Fatalf("unexpected keys after batch import: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after batch import: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("legacy key metadata was not upgraded: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Name != "" || snap.APIKeys[1].Remark != "" {
|
||||
t.Fatalf("new batch-imported key should stay metadata-free: %#v", snap.APIKeys[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigImportAppliesTokenRefreshInterval(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
adminCfg, runtimeCfg, compatCfg, responsesCfg, embeddingsCfg, autoDeleteCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
|
||||
adminCfg, runtimeCfg, compatCfg, responsesCfg, embeddingsCfg, autoDeleteCfg, historySplitCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
@@ -67,6 +67,14 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
c.AutoDelete.Mode = autoDeleteCfg.Mode
|
||||
c.AutoDelete.Sessions = autoDeleteCfg.Sessions
|
||||
}
|
||||
if historySplitCfg != nil {
|
||||
if historySplitCfg.Enabled != nil {
|
||||
c.HistorySplit.Enabled = historySplitCfg.Enabled
|
||||
}
|
||||
if historySplitCfg.TriggerAfterTurns != nil {
|
||||
c.HistorySplit.TriggerAfterTurns = historySplitCfg.TriggerAfterTurns
|
||||
}
|
||||
}
|
||||
if claudeMap != nil {
|
||||
c.ClaudeMapping = claudeMap
|
||||
c.ClaudeModelMap = nil
|
||||
|
||||
@@ -62,6 +62,8 @@ func toAccount(m map[string]any) config.Account {
|
||||
email := fieldString(m, "email")
|
||||
mobile := config.NormalizeMobileForStorage(fieldString(m, "mobile"))
|
||||
return config.Account{
|
||||
Name: fieldString(m, "name"),
|
||||
Remark: fieldString(m, "remark"),
|
||||
Email: email,
|
||||
Mobile: mobile,
|
||||
Password: fieldString(m, "password"),
|
||||
@@ -69,6 +71,116 @@ func toAccount(m map[string]any) config.Account {
|
||||
}
|
||||
}
|
||||
|
||||
func toAPIKeys(v any) ([]config.APIKey, bool) {
|
||||
arr, ok := v.([]any)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
out := make([]config.APIKey, 0, len(arr))
|
||||
seen := map[string]struct{}{}
|
||||
for _, item := range arr {
|
||||
switch x := item.(type) {
|
||||
case map[string]any:
|
||||
key := fieldString(x, "key")
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, config.APIKey{
|
||||
Key: key,
|
||||
Name: fieldString(x, "name"),
|
||||
Remark: fieldString(x, "remark"),
|
||||
})
|
||||
default:
|
||||
key := strings.TrimSpace(fmt.Sprintf("%v", item))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, config.APIKey{Key: key})
|
||||
}
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
func normalizeAPIKeyForStorage(item config.APIKey) config.APIKey {
|
||||
return config.APIKey{
|
||||
Key: strings.TrimSpace(item.Key),
|
||||
Name: strings.TrimSpace(item.Name),
|
||||
Remark: strings.TrimSpace(item.Remark),
|
||||
}
|
||||
}
|
||||
|
||||
func apiKeyHasMetadata(item config.APIKey) bool {
|
||||
return strings.TrimSpace(item.Name) != "" || strings.TrimSpace(item.Remark) != ""
|
||||
}
|
||||
|
||||
func mergeAPIKeysPreferStructured(existing, incoming []config.APIKey) ([]config.APIKey, int) {
|
||||
if len(existing) == 0 && len(incoming) == 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
merged := make([]config.APIKey, 0, len(existing)+len(incoming))
|
||||
index := make(map[string]int, len(existing)+len(incoming))
|
||||
for _, item := range existing {
|
||||
item = normalizeAPIKeyForStorage(item)
|
||||
if item.Key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := index[item.Key]; ok {
|
||||
continue
|
||||
}
|
||||
index[item.Key] = len(merged)
|
||||
merged = append(merged, item)
|
||||
}
|
||||
|
||||
imported := 0
|
||||
for _, item := range incoming {
|
||||
item = normalizeAPIKeyForStorage(item)
|
||||
if item.Key == "" {
|
||||
continue
|
||||
}
|
||||
if idx, ok := index[item.Key]; ok {
|
||||
keep := merged[idx]
|
||||
next := mergeAPIKeyRecord(keep, item)
|
||||
if next != keep {
|
||||
merged[idx] = next
|
||||
imported++
|
||||
}
|
||||
continue
|
||||
}
|
||||
index[item.Key] = len(merged)
|
||||
merged = append(merged, item)
|
||||
imported++
|
||||
}
|
||||
|
||||
if len(merged) == 0 {
|
||||
return nil, imported
|
||||
}
|
||||
return merged, imported
|
||||
}
|
||||
|
||||
func mergeAPIKeyRecord(existing, incoming config.APIKey) config.APIKey {
|
||||
existing = normalizeAPIKeyForStorage(existing)
|
||||
incoming = normalizeAPIKeyForStorage(incoming)
|
||||
if existing.Key == "" {
|
||||
return incoming
|
||||
}
|
||||
if apiKeyHasMetadata(existing) {
|
||||
return existing
|
||||
}
|
||||
if apiKeyHasMetadata(incoming) {
|
||||
return incoming
|
||||
}
|
||||
return existing
|
||||
}
|
||||
|
||||
func fieldString(m map[string]any, key string) string {
|
||||
v, ok := m[key]
|
||||
if !ok || v == nil {
|
||||
@@ -77,6 +189,14 @@ func fieldString(m map[string]any, key string) string {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
}
|
||||
|
||||
func fieldStringOptional(m map[string]any, key string) (string, bool) {
|
||||
v, ok := m[key]
|
||||
if !ok || v == nil {
|
||||
return "", false
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", v)), true
|
||||
}
|
||||
|
||||
func statusOr(v int, d int) int {
|
||||
if v == 0 {
|
||||
return d
|
||||
@@ -99,6 +219,8 @@ func accountMatchesIdentifier(acc config.Account, identifier string) bool {
|
||||
}
|
||||
|
||||
func normalizeAccountForStorage(acc config.Account) config.Account {
|
||||
acc.Name = strings.TrimSpace(acc.Name)
|
||||
acc.Remark = strings.TrimSpace(acc.Remark)
|
||||
acc.Email = strings.TrimSpace(acc.Email)
|
||||
acc.Mobile = config.NormalizeMobileForStorage(acc.Mobile)
|
||||
acc.ProxyID = strings.TrimSpace(acc.ProxyID)
|
||||
|
||||
766
internal/chathistory/store.go
Normal file
766
internal/chathistory/store.go
Normal file
@@ -0,0 +1,766 @@
|
||||
package chathistory
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
FileVersion = 2
|
||||
DisabledLimit = 0
|
||||
DefaultLimit = 20
|
||||
MaxLimit = 50
|
||||
defaultPreviewAt = 160
|
||||
)
|
||||
|
||||
var allowedLimits = map[int]struct{}{
|
||||
DisabledLimit: {},
|
||||
10: {},
|
||||
20: {},
|
||||
50: {},
|
||||
}
|
||||
|
||||
var ErrDisabled = errors.New("chat history disabled")
|
||||
|
||||
type Entry struct {
|
||||
ID string `json:"id"`
|
||||
Revision int64 `json:"revision"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
CompletedAt int64 `json:"completed_at,omitempty"`
|
||||
Status string `json:"status"`
|
||||
CallerID string `json:"caller_id,omitempty"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserInput string `json:"user_input,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
FinalPrompt string `json:"final_prompt,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
StatusCode int `json:"status_code,omitempty"`
|
||||
ElapsedMs int64 `json:"elapsed_ms,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Usage map[string]any `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type SummaryEntry struct {
|
||||
ID string `json:"id"`
|
||||
Revision int64 `json:"revision"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
CompletedAt int64 `json:"completed_at,omitempty"`
|
||||
Status string `json:"status"`
|
||||
CallerID string `json:"caller_id,omitempty"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserInput string `json:"user_input,omitempty"`
|
||||
Preview string `json:"preview,omitempty"`
|
||||
StatusCode int `json:"status_code,omitempty"`
|
||||
ElapsedMs int64 `json:"elapsed_ms,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
DetailRevision int64 `json:"detail_revision"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
Version int `json:"version"`
|
||||
Limit int `json:"limit"`
|
||||
Revision int64 `json:"revision"`
|
||||
Items []SummaryEntry `json:"items"`
|
||||
}
|
||||
|
||||
type StartParams struct {
|
||||
CallerID string
|
||||
AccountID string
|
||||
Model string
|
||||
Stream bool
|
||||
UserInput string
|
||||
Messages []Message
|
||||
FinalPrompt string
|
||||
}
|
||||
|
||||
type UpdateParams struct {
|
||||
Status string
|
||||
ReasoningContent string
|
||||
Content string
|
||||
Error string
|
||||
StatusCode int
|
||||
ElapsedMs int64
|
||||
FinishReason string
|
||||
Usage map[string]any
|
||||
Completed bool
|
||||
}
|
||||
|
||||
type detailEnvelope struct {
|
||||
Version int `json:"version"`
|
||||
Item Entry `json:"item"`
|
||||
}
|
||||
|
||||
type legacyFile struct {
|
||||
Version int `json:"version"`
|
||||
Limit int `json:"limit"`
|
||||
Items []Entry `json:"items"`
|
||||
}
|
||||
|
||||
type legacyProbe struct {
|
||||
Items []map[string]json.RawMessage `json:"items"`
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
detailDir string
|
||||
state File
|
||||
details map[string]Entry
|
||||
dirty map[string]struct{}
|
||||
deleted map[string]struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func New(path string) *Store {
|
||||
s := &Store{
|
||||
path: strings.TrimSpace(path),
|
||||
detailDir: strings.TrimSpace(path) + ".d",
|
||||
state: File{
|
||||
Version: FileVersion,
|
||||
Limit: DefaultLimit,
|
||||
Revision: 0,
|
||||
Items: []SummaryEntry{},
|
||||
},
|
||||
details: map[string]Entry{},
|
||||
dirty: map[string]struct{}{},
|
||||
deleted: map[string]struct{}{},
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.err = s.loadLocked()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Store) Path() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.path
|
||||
}
|
||||
|
||||
func (s *Store) DetailDir() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.detailDir
|
||||
}
|
||||
|
||||
func (s *Store) Err() error {
|
||||
if s == nil {
|
||||
return errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s *Store) Snapshot() (File, error) {
|
||||
if s == nil {
|
||||
return File{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return File{}, s.err
|
||||
}
|
||||
return cloneFile(s.state), nil
|
||||
}
|
||||
|
||||
func (s *Store) Enabled() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return false
|
||||
}
|
||||
return s.state.Limit != DisabledLimit
|
||||
}
|
||||
|
||||
func (s *Store) Get(id string) (Entry, error) {
|
||||
if s == nil {
|
||||
return Entry{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return Entry{}, s.err
|
||||
}
|
||||
item, ok := s.details[strings.TrimSpace(id)]
|
||||
if !ok {
|
||||
return Entry{}, errors.New("chat history entry not found")
|
||||
}
|
||||
return cloneEntry(item), nil
|
||||
}
|
||||
|
||||
func (s *Store) Start(params StartParams) (Entry, error) {
|
||||
if s == nil {
|
||||
return Entry{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return Entry{}, s.err
|
||||
}
|
||||
if s.state.Limit == DisabledLimit {
|
||||
return Entry{}, ErrDisabled
|
||||
}
|
||||
now := time.Now().UnixMilli()
|
||||
revision := s.nextRevisionLocked()
|
||||
entry := Entry{
|
||||
ID: "chat_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
Revision: revision,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Status: "streaming",
|
||||
CallerID: strings.TrimSpace(params.CallerID),
|
||||
AccountID: strings.TrimSpace(params.AccountID),
|
||||
Model: strings.TrimSpace(params.Model),
|
||||
Stream: params.Stream,
|
||||
UserInput: strings.TrimSpace(params.UserInput),
|
||||
Messages: cloneMessages(params.Messages),
|
||||
FinalPrompt: strings.TrimSpace(params.FinalPrompt),
|
||||
}
|
||||
s.details[entry.ID] = entry
|
||||
s.markDetailDirtyLocked(entry.ID)
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return cloneEntry(entry), err
|
||||
}
|
||||
return cloneEntry(entry), nil
|
||||
}
|
||||
|
||||
func (s *Store) Update(id string, params UpdateParams) (Entry, error) {
|
||||
if s == nil {
|
||||
return Entry{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return Entry{}, s.err
|
||||
}
|
||||
target := strings.TrimSpace(id)
|
||||
if target == "" {
|
||||
return Entry{}, errors.New("history id is required")
|
||||
}
|
||||
item, ok := s.details[target]
|
||||
if !ok {
|
||||
return Entry{}, errors.New("chat history entry not found")
|
||||
}
|
||||
now := time.Now().UnixMilli()
|
||||
item.Revision = s.nextRevisionLocked()
|
||||
item.UpdatedAt = now
|
||||
if params.Status != "" {
|
||||
item.Status = params.Status
|
||||
}
|
||||
item.ReasoningContent = params.ReasoningContent
|
||||
item.Content = params.Content
|
||||
item.Error = strings.TrimSpace(params.Error)
|
||||
item.StatusCode = params.StatusCode
|
||||
item.ElapsedMs = params.ElapsedMs
|
||||
item.FinishReason = strings.TrimSpace(params.FinishReason)
|
||||
if params.Usage != nil {
|
||||
item.Usage = cloneMap(params.Usage)
|
||||
}
|
||||
if params.Completed {
|
||||
item.CompletedAt = now
|
||||
}
|
||||
s.details[target] = item
|
||||
s.markDetailDirtyLocked(target)
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return Entry{}, err
|
||||
}
|
||||
return cloneEntry(item), nil
|
||||
}
|
||||
|
||||
func (s *Store) Delete(id string) error {
|
||||
if s == nil {
|
||||
return errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
target := strings.TrimSpace(id)
|
||||
if target == "" {
|
||||
return errors.New("history id is required")
|
||||
}
|
||||
if _, ok := s.details[target]; !ok {
|
||||
return errors.New("chat history entry not found")
|
||||
}
|
||||
s.markDetailDeletedLocked(target)
|
||||
delete(s.details, target)
|
||||
s.nextRevisionLocked()
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) Clear() error {
|
||||
if s == nil {
|
||||
return errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
for id := range s.details {
|
||||
s.markDetailDeletedLocked(id)
|
||||
}
|
||||
s.details = map[string]Entry{}
|
||||
s.nextRevisionLocked()
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) SetLimit(limit int) (File, error) {
|
||||
if s == nil {
|
||||
return File{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return File{}, s.err
|
||||
}
|
||||
if !isAllowedLimit(limit) {
|
||||
return File{}, fmt.Errorf("unsupported chat history limit: %d", limit)
|
||||
}
|
||||
s.state.Limit = limit
|
||||
s.nextRevisionLocked()
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
return cloneFile(s.state), nil
|
||||
}
|
||||
|
||||
func (s *Store) loadLocked() error {
|
||||
if strings.TrimSpace(s.path) == "" {
|
||||
return errors.New("chat history path is required")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(s.path), 0o755); err != nil && filepath.Dir(s.path) != "." {
|
||||
return fmt.Errorf("create chat history dir: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(s.detailDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create chat history detail dir: %w", err)
|
||||
}
|
||||
|
||||
raw, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
if saveErr := s.saveLocked(); saveErr != nil {
|
||||
config.Logger.Warn("[chat_history] bootstrap write failed", "path", s.path, "error", saveErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read chat history index: %w", err)
|
||||
}
|
||||
|
||||
legacy, legacyOK, legacyErr := parseLegacy(raw)
|
||||
if legacyErr != nil {
|
||||
return legacyErr
|
||||
}
|
||||
if legacyOK {
|
||||
s.loadLegacyLocked(legacy)
|
||||
if err := s.saveLocked(); err != nil {
|
||||
config.Logger.Warn("[chat_history] legacy migration writeback failed", "path", s.path, "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var state File
|
||||
if err := json.Unmarshal(raw, &state); err != nil {
|
||||
return fmt.Errorf("decode chat history index: %w", err)
|
||||
}
|
||||
if state.Version == 0 {
|
||||
state.Version = FileVersion
|
||||
}
|
||||
if !isAllowedLimit(state.Limit) {
|
||||
state.Limit = DefaultLimit
|
||||
}
|
||||
s.state = cloneFile(state)
|
||||
s.details = map[string]Entry{}
|
||||
for _, item := range state.Items {
|
||||
detail, err := readDetailFile(filepath.Join(s.detailDir, item.ID+".json"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.details[item.ID] = detail
|
||||
}
|
||||
s.rebuildIndexLocked()
|
||||
if saveErr := s.saveLocked(); saveErr != nil {
|
||||
config.Logger.Warn("[chat_history] index rewrite failed", "path", s.path, "error", saveErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) loadLegacyLocked(legacy legacyFile) {
|
||||
s.state.Version = FileVersion
|
||||
s.state.Limit = legacy.Limit
|
||||
if !isAllowedLimit(s.state.Limit) {
|
||||
s.state.Limit = DefaultLimit
|
||||
}
|
||||
s.details = map[string]Entry{}
|
||||
s.dirty = map[string]struct{}{}
|
||||
s.deleted = map[string]struct{}{}
|
||||
maxRevision := int64(0)
|
||||
for _, item := range legacy.Items {
|
||||
if strings.TrimSpace(item.ID) == "" {
|
||||
continue
|
||||
}
|
||||
item.Messages = cloneMessages(item.Messages)
|
||||
if item.Revision == 0 {
|
||||
if item.UpdatedAt > 0 {
|
||||
item.Revision = item.UpdatedAt
|
||||
} else {
|
||||
item.Revision = time.Now().UnixNano()
|
||||
}
|
||||
}
|
||||
if item.Revision > maxRevision {
|
||||
maxRevision = item.Revision
|
||||
}
|
||||
s.details[item.ID] = item
|
||||
s.markDetailDirtyLocked(item.ID)
|
||||
}
|
||||
s.state.Revision = maxRevision
|
||||
s.rebuildIndexLocked()
|
||||
}
|
||||
|
||||
func (s *Store) saveLocked() error {
|
||||
s.state.Version = FileVersion
|
||||
if !isAllowedLimit(s.state.Limit) {
|
||||
s.state.Limit = DefaultLimit
|
||||
}
|
||||
s.rebuildIndexLocked()
|
||||
|
||||
if err := os.MkdirAll(s.detailDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create chat history detail dir: %w", err)
|
||||
}
|
||||
for _, id := range sortedDetailIDs(s.deleted) {
|
||||
path := filepath.Join(s.detailDir, id+".json")
|
||||
if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("remove stale chat history detail: %w", err)
|
||||
}
|
||||
}
|
||||
for _, id := range sortedDetailIDs(s.dirty) {
|
||||
item, ok := s.details[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(s.detailDir, id+".json")
|
||||
payload, err := json.MarshalIndent(detailEnvelope{
|
||||
Version: FileVersion,
|
||||
Item: item,
|
||||
}, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode chat history detail: %w", err)
|
||||
}
|
||||
if err := writeFileAtomic(path, append(payload, '\n')); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
payload, err := json.MarshalIndent(s.state, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode chat history index: %w", err)
|
||||
}
|
||||
if err := writeFileAtomic(s.path, append(payload, '\n')); err != nil {
|
||||
return err
|
||||
}
|
||||
s.clearPendingDetailChangesLocked()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) rebuildIndexLocked() {
|
||||
summaries := make([]SummaryEntry, 0, len(s.details))
|
||||
for _, item := range s.details {
|
||||
summaries = append(summaries, summaryFromEntry(item))
|
||||
}
|
||||
sort.Slice(summaries, func(i, j int) bool {
|
||||
if summaries[i].UpdatedAt == summaries[j].UpdatedAt {
|
||||
return summaries[i].CreatedAt > summaries[j].CreatedAt
|
||||
}
|
||||
return summaries[i].UpdatedAt > summaries[j].UpdatedAt
|
||||
})
|
||||
if s.state.Limit < DisabledLimit || !isAllowedLimit(s.state.Limit) {
|
||||
s.state.Limit = DefaultLimit
|
||||
}
|
||||
if s.state.Limit == DisabledLimit {
|
||||
s.state.Items = summaries
|
||||
return
|
||||
}
|
||||
if len(summaries) > s.state.Limit {
|
||||
keep := make(map[string]struct{}, s.state.Limit)
|
||||
for _, item := range summaries[:s.state.Limit] {
|
||||
keep[item.ID] = struct{}{}
|
||||
}
|
||||
for id := range s.details {
|
||||
if _, ok := keep[id]; !ok {
|
||||
s.markDetailDeletedLocked(id)
|
||||
delete(s.details, id)
|
||||
}
|
||||
}
|
||||
summaries = summaries[:s.state.Limit]
|
||||
}
|
||||
s.state.Items = summaries
|
||||
}
|
||||
|
||||
func (s *Store) nextRevisionLocked() int64 {
|
||||
next := time.Now().UnixNano()
|
||||
if next <= s.state.Revision {
|
||||
next = s.state.Revision + 1
|
||||
}
|
||||
s.state.Revision = next
|
||||
return next
|
||||
}
|
||||
|
||||
func summaryFromEntry(item Entry) SummaryEntry {
|
||||
return SummaryEntry{
|
||||
ID: item.ID,
|
||||
Revision: item.Revision,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
CompletedAt: item.CompletedAt,
|
||||
Status: item.Status,
|
||||
CallerID: item.CallerID,
|
||||
AccountID: item.AccountID,
|
||||
Model: item.Model,
|
||||
Stream: item.Stream,
|
||||
UserInput: item.UserInput,
|
||||
Preview: buildPreview(item),
|
||||
StatusCode: item.StatusCode,
|
||||
ElapsedMs: item.ElapsedMs,
|
||||
FinishReason: item.FinishReason,
|
||||
DetailRevision: item.Revision,
|
||||
}
|
||||
}
|
||||
|
||||
func buildPreview(item Entry) string {
|
||||
candidate := strings.TrimSpace(item.Content)
|
||||
if candidate == "" {
|
||||
candidate = strings.TrimSpace(item.ReasoningContent)
|
||||
}
|
||||
if candidate == "" {
|
||||
candidate = strings.TrimSpace(item.Error)
|
||||
}
|
||||
if candidate == "" {
|
||||
candidate = strings.TrimSpace(item.UserInput)
|
||||
}
|
||||
if len(candidate) > defaultPreviewAt {
|
||||
return candidate[:defaultPreviewAt] + "..."
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
|
||||
func readDetailFile(path string) (Entry, error) {
|
||||
raw, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return Entry{}, fmt.Errorf("read chat history detail: %w", err)
|
||||
}
|
||||
var env detailEnvelope
|
||||
if err := json.Unmarshal(raw, &env); err != nil {
|
||||
return Entry{}, fmt.Errorf("decode chat history detail: %w", err)
|
||||
}
|
||||
return cloneEntry(env.Item), nil
|
||||
}
|
||||
|
||||
func parseLegacy(raw []byte) (legacyFile, bool, error) {
|
||||
var legacy legacyFile
|
||||
if err := json.Unmarshal(raw, &legacy); err != nil {
|
||||
return legacyFile{}, false, nil
|
||||
}
|
||||
if len(legacy.Items) == 0 {
|
||||
return legacy, false, nil
|
||||
}
|
||||
var probe legacyProbe
|
||||
if err := json.Unmarshal(raw, &probe); err == nil {
|
||||
for _, item := range probe.Items {
|
||||
if _, ok := item["detail_revision"]; ok {
|
||||
return legacy, false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return legacy, true, nil
|
||||
}
|
||||
|
||||
func writeFileAtomic(path string, body []byte) error {
|
||||
dir := filepath.Dir(path)
|
||||
if dir == "" {
|
||||
dir = "."
|
||||
}
|
||||
if dir != "." {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("create chat history dir: %w", err)
|
||||
}
|
||||
}
|
||||
tmpFile, err := os.CreateTemp(dir, ".chat-history-*.tmp")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp chat history: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
cleanup := func() error {
|
||||
if err := os.Remove(tmpPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("remove temp chat history: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
withCleanup := func(primary error, closeErr error) error {
|
||||
errs := []error{primary}
|
||||
if closeErr != nil {
|
||||
errs = append(errs, fmt.Errorf("close temp chat history: %w", closeErr))
|
||||
}
|
||||
if cleanupErr := cleanup(); cleanupErr != nil {
|
||||
errs = append(errs, cleanupErr)
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
if _, err := tmpFile.Write(body); err != nil {
|
||||
return withCleanup(fmt.Errorf("write temp chat history: %w", err), tmpFile.Close())
|
||||
}
|
||||
if err := tmpFile.Sync(); err != nil {
|
||||
return withCleanup(fmt.Errorf("sync temp chat history: %w", err), tmpFile.Close())
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
if cleanupErr := cleanup(); cleanupErr != nil {
|
||||
return errors.Join(fmt.Errorf("close temp chat history: %w", err), cleanupErr)
|
||||
}
|
||||
return fmt.Errorf("close temp chat history: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
if cleanupErr := cleanup(); cleanupErr != nil {
|
||||
return errors.Join(fmt.Errorf("promote temp chat history: %w", err), cleanupErr)
|
||||
}
|
||||
return fmt.Errorf("promote temp chat history: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ListETag(revision int64) string {
|
||||
return fmt.Sprintf(`W/"chat-history-list-%d"`, revision)
|
||||
}
|
||||
|
||||
func DetailETag(id string, revision int64) string {
|
||||
return fmt.Sprintf(`W/"chat-history-detail-%s-%d"`, strings.TrimSpace(id), revision)
|
||||
}
|
||||
|
||||
func isAllowedLimit(limit int) bool {
|
||||
_, ok := allowedLimits[limit]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *Store) markDetailDirtyLocked(id string) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
if s.dirty == nil {
|
||||
s.dirty = map[string]struct{}{}
|
||||
}
|
||||
if s.deleted == nil {
|
||||
s.deleted = map[string]struct{}{}
|
||||
}
|
||||
s.dirty[id] = struct{}{}
|
||||
delete(s.deleted, id)
|
||||
}
|
||||
|
||||
func (s *Store) markDetailDeletedLocked(id string) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
if s.dirty == nil {
|
||||
s.dirty = map[string]struct{}{}
|
||||
}
|
||||
if s.deleted == nil {
|
||||
s.deleted = map[string]struct{}{}
|
||||
}
|
||||
s.deleted[id] = struct{}{}
|
||||
delete(s.dirty, id)
|
||||
}
|
||||
|
||||
func (s *Store) clearPendingDetailChangesLocked() {
|
||||
s.dirty = map[string]struct{}{}
|
||||
s.deleted = map[string]struct{}{}
|
||||
}
|
||||
|
||||
func sortedDetailIDs(ids map[string]struct{}) []string {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(ids))
|
||||
for id := range ids {
|
||||
out = append(out, id)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneFile(in File) File {
|
||||
out := File{
|
||||
Version: in.Version,
|
||||
Limit: in.Limit,
|
||||
Revision: in.Revision,
|
||||
Items: make([]SummaryEntry, len(in.Items)),
|
||||
}
|
||||
copy(out.Items, in.Items)
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneEntry(item Entry) Entry {
|
||||
item.Usage = cloneMap(item.Usage)
|
||||
item.Messages = cloneMessages(item.Messages)
|
||||
return item
|
||||
}
|
||||
|
||||
func cloneMap(in map[string]any) map[string]any {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneMessages(messages []Message) []Message {
|
||||
if len(messages) == 0 {
|
||||
return []Message{}
|
||||
}
|
||||
out := make([]Message, len(messages))
|
||||
copy(out, messages)
|
||||
return out
|
||||
}
|
||||
483
internal/chathistory/store_test.go
Normal file
483
internal/chathistory/store_test.go
Normal file
@@ -0,0 +1,483 @@
|
||||
package chathistory
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func blockDetailDir(t *testing.T, detailDir string) func() {
|
||||
t.Helper()
|
||||
blockedDir := detailDir + ".blocked"
|
||||
if err := os.RemoveAll(blockedDir); err != nil {
|
||||
t.Fatalf("remove blocked detail dir failed: %v", err)
|
||||
}
|
||||
if err := os.Rename(detailDir, blockedDir); err != nil {
|
||||
t.Fatalf("move detail dir aside failed: %v", err)
|
||||
}
|
||||
if err := os.RemoveAll(detailDir); err != nil {
|
||||
t.Fatalf("remove blocked detail path failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(detailDir, []byte("blocked"), 0o644); err != nil {
|
||||
t.Fatalf("write blocked detail path failed: %v", err)
|
||||
}
|
||||
var once sync.Once
|
||||
return func() {
|
||||
t.Helper()
|
||||
once.Do(func() {
|
||||
if err := os.RemoveAll(detailDir); err != nil {
|
||||
t.Fatalf("remove blocking detail path failed: %v", err)
|
||||
}
|
||||
if err := os.Rename(blockedDir, detailDir); err != nil {
|
||||
t.Fatalf("restore detail dir failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreCreatesAndPersistsEntries(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
started, err := store.Start(StartParams{
|
||||
CallerID: "caller:abc",
|
||||
AccountID: "user@example.com",
|
||||
Model: "deepseek-chat",
|
||||
Stream: true,
|
||||
UserInput: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start entry failed: %v", err)
|
||||
}
|
||||
|
||||
updated, err := store.Update(started.ID, UpdateParams{
|
||||
Status: "success",
|
||||
ReasoningContent: "thinking",
|
||||
Content: "answer",
|
||||
StatusCode: 200,
|
||||
ElapsedMs: 321,
|
||||
FinishReason: "stop",
|
||||
Usage: map[string]any{"total_tokens": 9},
|
||||
Completed: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("update entry failed: %v", err)
|
||||
}
|
||||
if updated.Status != "success" || updated.Content != "answer" {
|
||||
t.Fatalf("unexpected updated entry: %#v", updated)
|
||||
}
|
||||
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if snapshot.Limit != DefaultLimit {
|
||||
t.Fatalf("unexpected default limit: %d", snapshot.Limit)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one item, got %d", len(snapshot.Items))
|
||||
}
|
||||
if snapshot.Items[0].CompletedAt == 0 {
|
||||
t.Fatalf("expected completed_at to be populated")
|
||||
}
|
||||
if snapshot.Items[0].Preview != "answer" {
|
||||
t.Fatalf("expected summary preview=answer, got %#v", snapshot.Items[0])
|
||||
}
|
||||
|
||||
reloaded := New(path)
|
||||
reloadedSnapshot, err := reloaded.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("reload snapshot failed: %v", err)
|
||||
}
|
||||
if len(reloadedSnapshot.Items) != 1 {
|
||||
t.Fatalf("unexpected reloaded summaries: %#v", reloadedSnapshot.Items)
|
||||
}
|
||||
full, err := reloaded.Get(started.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get detail failed: %v", err)
|
||||
}
|
||||
if full.Content != "answer" {
|
||||
t.Fatalf("expected detail content=answer, got %#v", full)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreTrimsToConfiguredLimit(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
if _, err := store.SetLimit(10); err != nil {
|
||||
t.Fatalf("set limit failed: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
entry, err := store.Start(StartParams{Model: "deepseek-chat", UserInput: "msg"})
|
||||
if err != nil {
|
||||
t.Fatalf("start %d failed: %v", i, err)
|
||||
}
|
||||
if _, err := store.Update(entry.ID, UpdateParams{Status: "success", Content: "ok", Completed: true}); err != nil {
|
||||
t.Fatalf("update %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 10 {
|
||||
t.Fatalf("expected 10 items, got %d", len(snapshot.Items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreDeleteClearAndLimitValidation(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
entry, err := store.Start(StartParams{UserInput: "hello"})
|
||||
if err != nil {
|
||||
t.Fatalf("start failed: %v", err)
|
||||
}
|
||||
if err := store.Delete(entry.ID); err != nil {
|
||||
t.Fatalf("delete failed: %v", err)
|
||||
}
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 0 {
|
||||
t.Fatalf("expected empty items after delete, got %d", len(snapshot.Items))
|
||||
}
|
||||
if _, err := store.SetLimit(999); err == nil {
|
||||
t.Fatalf("expected invalid limit error")
|
||||
}
|
||||
if err := store.Clear(); err != nil {
|
||||
t.Fatalf("clear failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreDisablePreservesHistoryAndBlocksNewEntries(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
entry, err := store.Start(StartParams{UserInput: "hello"})
|
||||
if err != nil {
|
||||
t.Fatalf("start failed: %v", err)
|
||||
}
|
||||
if _, err := store.Update(entry.ID, UpdateParams{Status: "success", Content: "world", Completed: true}); err != nil {
|
||||
t.Fatalf("update failed: %v", err)
|
||||
}
|
||||
|
||||
snapshot, err := store.SetLimit(DisabledLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("disable failed: %v", err)
|
||||
}
|
||||
if snapshot.Limit != DisabledLimit {
|
||||
t.Fatalf("expected disabled limit, got %d", snapshot.Limit)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected disabled mode to preserve summaries, got %d", len(snapshot.Items))
|
||||
}
|
||||
if store.Enabled() {
|
||||
t.Fatalf("expected store to report disabled")
|
||||
}
|
||||
if _, err := store.Start(StartParams{UserInput: "later"}); err != ErrDisabled {
|
||||
t.Fatalf("expected ErrDisabled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreConcurrentUpdatesKeepSplitFilesValid(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 8; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
entry, err := store.Start(StartParams{
|
||||
CallerID: "caller:test",
|
||||
Model: "deepseek-chat",
|
||||
UserInput: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("start failed: %v", err)
|
||||
return
|
||||
}
|
||||
_, err = store.Update(entry.ID, UpdateParams{
|
||||
Status: "success",
|
||||
Content: "answer",
|
||||
ElapsedMs: int64(idx),
|
||||
Completed: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("update failed: %v", err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 8 {
|
||||
t.Fatalf("expected 8 items, got %d", len(snapshot.Items))
|
||||
}
|
||||
|
||||
raw, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read index failed: %v", err)
|
||||
}
|
||||
var persisted File
|
||||
if err := json.Unmarshal(raw, &persisted); err != nil {
|
||||
t.Fatalf("persisted index is invalid json: %v", err)
|
||||
}
|
||||
if len(persisted.Items) != 8 {
|
||||
t.Fatalf("expected persisted items=8, got %d", len(persisted.Items))
|
||||
}
|
||||
|
||||
detailFiles, err := os.ReadDir(path + ".d")
|
||||
if err != nil {
|
||||
t.Fatalf("read detail dir failed: %v", err)
|
||||
}
|
||||
if len(detailFiles) != 8 {
|
||||
t.Fatalf("expected 8 detail files, got %d", len(detailFiles))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreAutoMigratesLegacyMonolith(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
legacy := legacyFile{
|
||||
Version: 1,
|
||||
Limit: 20,
|
||||
Items: []Entry{{
|
||||
ID: "chat_legacy",
|
||||
CreatedAt: 1,
|
||||
UpdatedAt: 2,
|
||||
Status: "success",
|
||||
UserInput: "hello",
|
||||
Content: "world",
|
||||
ReasoningContent: "thinking",
|
||||
}},
|
||||
}
|
||||
body, _ := json.MarshalIndent(legacy, "", " ")
|
||||
if err := os.WriteFile(path, body, 0o644); err != nil {
|
||||
t.Fatalf("write legacy file failed: %v", err)
|
||||
}
|
||||
|
||||
store := New(path)
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("expected legacy migration success, got %v", err)
|
||||
}
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one migrated summary, got %#v", snapshot.Items)
|
||||
}
|
||||
full, err := store.Get("chat_legacy")
|
||||
if err != nil {
|
||||
t.Fatalf("get migrated detail failed: %v", err)
|
||||
}
|
||||
if full.Content != "world" {
|
||||
t.Fatalf("expected migrated detail content preserved, got %#v", full)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreAutoMigratesMetadataOnlyLegacyMonolith(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
legacy := legacyFile{
|
||||
Version: 1,
|
||||
Limit: 20,
|
||||
Items: []Entry{{
|
||||
ID: "chat_metadata_only",
|
||||
Revision: 0,
|
||||
CreatedAt: 1,
|
||||
UpdatedAt: 2,
|
||||
Status: "error",
|
||||
CallerID: "caller:test",
|
||||
AccountID: "acct:test",
|
||||
Model: "deepseek-chat",
|
||||
Stream: true,
|
||||
UserInput: "hello",
|
||||
Error: "boom",
|
||||
StatusCode: 500,
|
||||
ElapsedMs: 12,
|
||||
FinishReason: "error",
|
||||
}},
|
||||
}
|
||||
body, _ := json.MarshalIndent(legacy, "", " ")
|
||||
if err := os.WriteFile(path, body, 0o644); err != nil {
|
||||
t.Fatalf("write legacy file failed: %v", err)
|
||||
}
|
||||
|
||||
store := New(path)
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("expected legacy metadata-only migration success, got %v", err)
|
||||
}
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one migrated summary, got %#v", snapshot.Items)
|
||||
}
|
||||
full, err := store.Get("chat_metadata_only")
|
||||
if err != nil {
|
||||
t.Fatalf("get migrated detail failed: %v", err)
|
||||
}
|
||||
if full.Error != "boom" || full.UserInput != "hello" {
|
||||
t.Fatalf("expected metadata-only legacy fields preserved, got %#v", full)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(store.DetailDir(), "chat_metadata_only.json")); err != nil {
|
||||
t.Fatalf("expected migrated detail file to exist: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreLegacyMigrationBestEffortWhenRewriteFails(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
longID := "chat_" + strings.Repeat("x", 320)
|
||||
legacy := legacyFile{
|
||||
Version: 1,
|
||||
Limit: 20,
|
||||
Items: []Entry{{
|
||||
ID: longID,
|
||||
CreatedAt: 1,
|
||||
UpdatedAt: 2,
|
||||
Status: "success",
|
||||
UserInput: "hello",
|
||||
Content: "world",
|
||||
}},
|
||||
}
|
||||
body, err := json.MarshalIndent(legacy, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("marshal legacy file failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(path, body, 0o644); err != nil {
|
||||
t.Fatalf("write legacy file failed: %v", err)
|
||||
}
|
||||
|
||||
store := New(path)
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("expected store to stay usable after migration writeback failure, got %v", err)
|
||||
}
|
||||
if !store.Enabled() {
|
||||
t.Fatal("expected store to remain enabled after best-effort migration")
|
||||
}
|
||||
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 || snapshot.Items[0].ID != longID {
|
||||
t.Fatalf("unexpected snapshot after best-effort migration: %#v", snapshot.Items)
|
||||
}
|
||||
full, err := store.Get(longID)
|
||||
if err != nil {
|
||||
t.Fatalf("get migrated detail failed: %v", err)
|
||||
}
|
||||
if full.Content != "world" {
|
||||
t.Fatalf("expected migrated content to stay in memory, got %#v", full)
|
||||
}
|
||||
if _, statErr := os.Stat(filepath.Join(store.DetailDir(), longID+".json")); statErr == nil {
|
||||
t.Fatal("expected detail write to fail for overlong legacy id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreTransientPersistenceFailureDoesNotLatch(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
first, err := store.Start(StartParams{UserInput: "first"})
|
||||
if err != nil {
|
||||
t.Fatalf("start first failed: %v", err)
|
||||
}
|
||||
restore := blockDetailDir(t, store.DetailDir())
|
||||
t.Cleanup(restore)
|
||||
|
||||
blocked, err := store.Start(StartParams{UserInput: "blocked"})
|
||||
if err == nil {
|
||||
t.Fatalf("expected start failure while detail dir is blocked")
|
||||
}
|
||||
if blocked.ID == "" {
|
||||
t.Fatalf("expected in-memory entry from failed start")
|
||||
}
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("transient start failure should not latch store error: %v", err)
|
||||
}
|
||||
if _, err := store.Update(first.ID, UpdateParams{Status: "success", Content: "one", Completed: true}); err == nil {
|
||||
t.Fatalf("expected update failure while detail dir is blocked")
|
||||
}
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("transient update failure should not latch store error: %v", err)
|
||||
}
|
||||
|
||||
restore()
|
||||
|
||||
if _, err := store.Update(blocked.ID, UpdateParams{Status: "success", Content: "two", Completed: true}); err != nil {
|
||||
t.Fatalf("update after restore failed: %v", err)
|
||||
}
|
||||
if _, err := store.Start(StartParams{UserInput: "later"}); err != nil {
|
||||
t.Fatalf("start after restore failed: %v", err)
|
||||
}
|
||||
full, err := store.Get(blocked.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get restored entry failed: %v", err)
|
||||
}
|
||||
if full.Content != "two" || full.Status != "success" {
|
||||
t.Fatalf("expected restored entry persisted, got %#v", full)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreWritesOnlyChangedDetailFiles(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
first, err := store.Start(StartParams{UserInput: "one"})
|
||||
if err != nil {
|
||||
t.Fatalf("start first failed: %v", err)
|
||||
}
|
||||
if _, err := store.Update(first.ID, UpdateParams{Status: "success", Content: "first", Completed: true}); err != nil {
|
||||
t.Fatalf("update first failed: %v", err)
|
||||
}
|
||||
second, err := store.Start(StartParams{UserInput: "two"})
|
||||
if err != nil {
|
||||
t.Fatalf("start second failed: %v", err)
|
||||
}
|
||||
if _, err := store.Update(second.ID, UpdateParams{Status: "success", Content: "second", Completed: true}); err != nil {
|
||||
t.Fatalf("update second failed: %v", err)
|
||||
}
|
||||
|
||||
firstPath := filepath.Join(store.DetailDir(), first.ID+".json")
|
||||
secondPath := filepath.Join(store.DetailDir(), second.ID+".json")
|
||||
beforeFirst, err := os.ReadFile(firstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read first detail before update failed: %v", err)
|
||||
}
|
||||
beforeSecond, err := os.ReadFile(secondPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read second detail before update failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := store.Update(first.ID, UpdateParams{Status: "success", Content: "first-updated", Completed: true}); err != nil {
|
||||
t.Fatalf("update first again failed: %v", err)
|
||||
}
|
||||
|
||||
afterFirst, err := os.ReadFile(firstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read first detail after update failed: %v", err)
|
||||
}
|
||||
afterSecond, err := os.ReadFile(secondPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read second detail after update failed: %v", err)
|
||||
}
|
||||
|
||||
if bytes.Equal(beforeFirst, afterFirst) {
|
||||
t.Fatalf("expected first detail file to change after update")
|
||||
}
|
||||
if !bytes.Equal(beforeSecond, afterSecond) {
|
||||
t.Fatalf("expected untouched detail file to remain byte-identical")
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,10 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"ds2api/internal/toolcall"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/sse"
|
||||
@@ -65,55 +63,6 @@ func TestGoCompatSSEFixtures(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoCompatToolcallFixtures(t *testing.T) {
|
||||
files, err := filepath.Glob(compatPath("fixtures", "toolcalls", "*.json"))
|
||||
if err != nil {
|
||||
t.Fatalf("glob toolcall fixtures failed: %v", err)
|
||||
}
|
||||
if len(files) == 0 {
|
||||
t.Fatal("no toolcall fixtures found")
|
||||
}
|
||||
for _, fixturePath := range files {
|
||||
name := trimExt(filepath.Base(fixturePath))
|
||||
expectedPath := compatPath("expected", "toolcalls_"+name+".json")
|
||||
|
||||
var fixture struct {
|
||||
Text string `json:"text"`
|
||||
ToolNames []string `json:"tool_names"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
mustLoadJSON(t, fixturePath, &fixture)
|
||||
|
||||
var expected struct {
|
||||
Calls []toolcall.ParsedToolCall `json:"calls"`
|
||||
SawToolCallSyntax bool `json:"sawToolCallSyntax"`
|
||||
RejectedByPolicy bool `json:"rejectedByPolicy"`
|
||||
RejectedToolNames []string `json:"rejectedToolNames"`
|
||||
}
|
||||
mustLoadJSON(t, expectedPath, &expected)
|
||||
|
||||
var got toolcall.ToolCallParseResult
|
||||
switch strings.ToLower(strings.TrimSpace(fixture.Mode)) {
|
||||
case "standalone":
|
||||
got = toolcall.ParseStandaloneToolCallsDetailed(fixture.Text, fixture.ToolNames)
|
||||
default:
|
||||
got = toolcall.ParseToolCallsDetailed(fixture.Text, fixture.ToolNames)
|
||||
}
|
||||
if got.Calls == nil {
|
||||
got.Calls = []toolcall.ParsedToolCall{}
|
||||
}
|
||||
if got.RejectedToolNames == nil {
|
||||
got.RejectedToolNames = []string{}
|
||||
}
|
||||
if !reflect.DeepEqual(got.Calls, expected.Calls) ||
|
||||
got.SawToolCallSyntax != expected.SawToolCallSyntax ||
|
||||
got.RejectedByPolicy != expected.RejectedByPolicy ||
|
||||
!reflect.DeepEqual(got.RejectedToolNames, expected.RejectedToolNames) {
|
||||
t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoCompatTokenFixtures(t *testing.T) {
|
||||
var fixture struct {
|
||||
Cases []struct {
|
||||
|
||||
@@ -17,6 +17,9 @@ func (c Config) MarshalJSON() ([]byte, error) {
|
||||
if len(c.Keys) > 0 {
|
||||
m["keys"] = c.Keys
|
||||
}
|
||||
if len(c.APIKeys) > 0 {
|
||||
m["api_keys"] = c.APIKeys
|
||||
}
|
||||
if len(c.Accounts) > 0 {
|
||||
m["accounts"] = c.Accounts
|
||||
}
|
||||
@@ -48,6 +51,9 @@ func (c Config) MarshalJSON() ([]byte, error) {
|
||||
m["embeddings"] = c.Embeddings
|
||||
}
|
||||
m["auto_delete"] = c.AutoDelete
|
||||
if c.HistorySplit.Enabled != nil || c.HistorySplit.TriggerAfterTurns != nil {
|
||||
m["history_split"] = c.HistorySplit
|
||||
}
|
||||
if c.VercelSyncHash != "" {
|
||||
m["_vercel_sync_hash"] = c.VercelSyncHash
|
||||
}
|
||||
@@ -69,6 +75,10 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
if err := json.Unmarshal(v, &c.Keys); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "api_keys":
|
||||
if err := json.Unmarshal(v, &c.APIKeys); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "accounts":
|
||||
if err := json.Unmarshal(v, &c.Accounts); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
@@ -115,6 +125,10 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
if err := json.Unmarshal(v, &c.AutoDelete); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "history_split":
|
||||
if err := json.Unmarshal(v, &c.HistorySplit); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "_vercel_sync_hash":
|
||||
if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
@@ -130,12 +144,14 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
c.NormalizeCredentials()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Config) Clone() Config {
|
||||
clone := Config{
|
||||
Keys: slices.Clone(c.Keys),
|
||||
APIKeys: slices.Clone(c.APIKeys),
|
||||
Accounts: slices.Clone(c.Accounts),
|
||||
Proxies: slices.Clone(c.Proxies),
|
||||
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
|
||||
@@ -147,9 +163,13 @@ func (c Config) Clone() Config {
|
||||
WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
|
||||
StripReferenceMarkers: cloneBoolPtr(c.Compat.StripReferenceMarkers),
|
||||
},
|
||||
Responses: c.Responses,
|
||||
Embeddings: c.Embeddings,
|
||||
AutoDelete: c.AutoDelete,
|
||||
Responses: c.Responses,
|
||||
Embeddings: c.Embeddings,
|
||||
AutoDelete: c.AutoDelete,
|
||||
HistorySplit: HistorySplitConfig{
|
||||
Enabled: cloneBoolPtr(c.HistorySplit.Enabled),
|
||||
TriggerAfterTurns: cloneIntPtr(c.HistorySplit.TriggerAfterTurns),
|
||||
},
|
||||
VercelSyncHash: c.VercelSyncHash,
|
||||
VercelSyncTime: c.VercelSyncTime,
|
||||
AdditionalFields: map[string]any{},
|
||||
@@ -179,6 +199,14 @@ func cloneBoolPtr(in *bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
func cloneIntPtr(in *int) *int {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
v := *in
|
||||
return &v
|
||||
}
|
||||
|
||||
func parseConfigString(raw string) (Config, error) {
|
||||
var cfg Config
|
||||
candidates := []string{raw}
|
||||
|
||||
@@ -8,24 +8,28 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Keys []string `json:"keys,omitempty"`
|
||||
Accounts []Account `json:"accounts,omitempty"`
|
||||
Proxies []Proxy `json:"proxies,omitempty"`
|
||||
ClaudeMapping map[string]string `json:"claude_mapping,omitempty"`
|
||||
ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"`
|
||||
ModelAliases map[string]string `json:"model_aliases,omitempty"`
|
||||
Admin AdminConfig `json:"admin,omitempty"`
|
||||
Runtime RuntimeConfig `json:"runtime,omitempty"`
|
||||
Compat CompatConfig `json:"compat,omitempty"`
|
||||
Responses ResponsesConfig `json:"responses,omitempty"`
|
||||
Embeddings EmbeddingsConfig `json:"embeddings,omitempty"`
|
||||
AutoDelete AutoDeleteConfig `json:"auto_delete"`
|
||||
VercelSyncHash string `json:"_vercel_sync_hash,omitempty"`
|
||||
VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"`
|
||||
AdditionalFields map[string]any `json:"-"`
|
||||
Keys []string `json:"keys,omitempty"`
|
||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
||||
Accounts []Account `json:"accounts,omitempty"`
|
||||
Proxies []Proxy `json:"proxies,omitempty"`
|
||||
ClaudeMapping map[string]string `json:"claude_mapping,omitempty"`
|
||||
ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"`
|
||||
ModelAliases map[string]string `json:"model_aliases,omitempty"`
|
||||
Admin AdminConfig `json:"admin,omitempty"`
|
||||
Runtime RuntimeConfig `json:"runtime,omitempty"`
|
||||
Compat CompatConfig `json:"compat,omitempty"`
|
||||
Responses ResponsesConfig `json:"responses,omitempty"`
|
||||
Embeddings EmbeddingsConfig `json:"embeddings,omitempty"`
|
||||
AutoDelete AutoDeleteConfig `json:"auto_delete"`
|
||||
HistorySplit HistorySplitConfig `json:"history_split"`
|
||||
VercelSyncHash string `json:"_vercel_sync_hash,omitempty"`
|
||||
VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"`
|
||||
AdditionalFields map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Remark string `json:"remark,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Mobile string `json:"mobile,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
@@ -33,6 +37,12 @@ type Account struct {
|
||||
ProxyID string `json:"proxy_id,omitempty"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Remark string `json:"remark,omitempty"`
|
||||
}
|
||||
|
||||
type Proxy struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -73,6 +83,25 @@ func (c *Config) ClearAccountTokens() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) NormalizeCredentials() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
normalizedAPIKeys := normalizeAPIKeys(c.APIKeys)
|
||||
if len(normalizedAPIKeys) > 0 {
|
||||
c.APIKeys = normalizedAPIKeys
|
||||
c.Keys = apiKeysToStrings(c.APIKeys)
|
||||
} else {
|
||||
c.Keys = normalizeKeys(c.Keys)
|
||||
c.APIKeys = apiKeysFromStrings(c.Keys, nil)
|
||||
}
|
||||
|
||||
for i := range c.Accounts {
|
||||
c.Accounts[i].Name = strings.TrimSpace(c.Accounts[i].Name)
|
||||
c.Accounts[i].Remark = strings.TrimSpace(c.Accounts[i].Remark)
|
||||
}
|
||||
}
|
||||
|
||||
// DropInvalidAccounts removes accounts that cannot be addressed by admin APIs
|
||||
// (no email and no normalizable mobile). This prevents legacy token-only
|
||||
// records from becoming orphaned empty entries after token stripping.
|
||||
@@ -120,3 +149,8 @@ type AutoDeleteConfig struct {
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Sessions bool `json:"sessions,omitempty"`
|
||||
}
|
||||
|
||||
type HistorySplitConfig struct {
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
TriggerAfterTurns *int `json:"trigger_after_turns,omitempty"`
|
||||
}
|
||||
|
||||
@@ -154,6 +154,10 @@ func TestConfigJSONRoundtrip(t *testing.T) {
|
||||
AutoDelete: AutoDeleteConfig{
|
||||
Mode: "single",
|
||||
},
|
||||
HistorySplit: HistorySplitConfig{
|
||||
Enabled: &trueVal,
|
||||
TriggerAfterTurns: func() *int { v := 2; return &v }(),
|
||||
},
|
||||
Runtime: RuntimeConfig{
|
||||
TokenRefreshIntervalHours: 12,
|
||||
},
|
||||
@@ -193,6 +197,12 @@ func TestConfigJSONRoundtrip(t *testing.T) {
|
||||
if decoded.AutoDelete.Mode != "single" {
|
||||
t.Fatalf("unexpected auto delete mode: %#v", decoded.AutoDelete.Mode)
|
||||
}
|
||||
if decoded.HistorySplit.Enabled == nil || !*decoded.HistorySplit.Enabled {
|
||||
t.Fatalf("unexpected history split enabled: %#v", decoded.HistorySplit.Enabled)
|
||||
}
|
||||
if decoded.HistorySplit.TriggerAfterTurns == nil || *decoded.HistorySplit.TriggerAfterTurns != 2 {
|
||||
t.Fatalf("unexpected history split trigger_after_turns: %#v", decoded.HistorySplit.TriggerAfterTurns)
|
||||
}
|
||||
if decoded.Compat.WideInputStrictOutput == nil || !*decoded.Compat.WideInputStrictOutput {
|
||||
t.Fatalf("unexpected compat wide_input_strict_output: %#v", decoded.Compat.WideInputStrictOutput)
|
||||
}
|
||||
@@ -249,6 +259,8 @@ func TestConfigUnmarshalJSONPreservesUnknownFields(t *testing.T) {
|
||||
|
||||
func TestConfigCloneIsDeepCopy(t *testing.T) {
|
||||
falseVal := false
|
||||
trueVal := true
|
||||
turns := 2
|
||||
cfg := Config{
|
||||
Keys: []string{"key1"},
|
||||
Accounts: []Account{{Email: "user@test.com", Token: "token"}},
|
||||
@@ -258,6 +270,10 @@ func TestConfigCloneIsDeepCopy(t *testing.T) {
|
||||
Compat: CompatConfig{
|
||||
StripReferenceMarkers: &falseVal,
|
||||
},
|
||||
HistorySplit: HistorySplitConfig{
|
||||
Enabled: &trueVal,
|
||||
TriggerAfterTurns: &turns,
|
||||
},
|
||||
AdditionalFields: map[string]any{"custom": "value"},
|
||||
}
|
||||
|
||||
@@ -270,6 +286,12 @@ func TestConfigCloneIsDeepCopy(t *testing.T) {
|
||||
if cfg.Compat.StripReferenceMarkers != nil {
|
||||
*cfg.Compat.StripReferenceMarkers = true
|
||||
}
|
||||
if cfg.HistorySplit.Enabled != nil {
|
||||
*cfg.HistorySplit.Enabled = false
|
||||
}
|
||||
if cfg.HistorySplit.TriggerAfterTurns != nil {
|
||||
*cfg.HistorySplit.TriggerAfterTurns = 5
|
||||
}
|
||||
|
||||
// Cloned should not be affected
|
||||
if cloned.Keys[0] != "key1" {
|
||||
@@ -284,6 +306,12 @@ func TestConfigCloneIsDeepCopy(t *testing.T) {
|
||||
if cloned.Compat.StripReferenceMarkers == nil || *cloned.Compat.StripReferenceMarkers {
|
||||
t.Fatalf("clone compat was affected: %#v", cloned.Compat.StripReferenceMarkers)
|
||||
}
|
||||
if cloned.HistorySplit.Enabled == nil || !*cloned.HistorySplit.Enabled {
|
||||
t.Fatalf("clone history split enabled was affected: %#v", cloned.HistorySplit.Enabled)
|
||||
}
|
||||
if cloned.HistorySplit.TriggerAfterTurns == nil || *cloned.HistorySplit.TriggerAfterTurns != 2 {
|
||||
t.Fatalf("clone history split trigger was affected: %#v", cloned.HistorySplit.TriggerAfterTurns)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigCloneNilMaps(t *testing.T) {
|
||||
@@ -529,6 +557,101 @@ func TestStoreUpdate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreUpdateReconcilesAPIKeyMutations(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
"api_keys":[{"key":"k1","name":"primary","remark":"prod"}],
|
||||
"accounts":[]
|
||||
}`)
|
||||
store := LoadStore()
|
||||
|
||||
if err := store.Update(func(cfg *Config) error {
|
||||
cfg.APIKeys = append(cfg.APIKeys, APIKey{Key: "k2", Name: "secondary", Remark: "staging"})
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("add api key failed: %v", err)
|
||||
}
|
||||
|
||||
snap := store.Snapshot()
|
||||
if len(snap.Keys) != 2 || snap.Keys[0] != "k1" || snap.Keys[1] != "k2" {
|
||||
t.Fatalf("unexpected keys after api key add: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys length after add: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("metadata for existing key was lost: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Name != "secondary" || snap.APIKeys[1].Remark != "staging" {
|
||||
t.Fatalf("metadata for new key was lost: %#v", snap.APIKeys[1])
|
||||
}
|
||||
|
||||
if err := store.Update(func(cfg *Config) error {
|
||||
cfg.APIKeys = append([]APIKey(nil), cfg.APIKeys[1:]...)
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("delete api key failed: %v", err)
|
||||
}
|
||||
|
||||
snap = store.Snapshot()
|
||||
if len(snap.Keys) != 1 || snap.Keys[0] != "k2" {
|
||||
t.Fatalf("unexpected keys after api key delete: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 1 || snap.APIKeys[0].Key != "k2" {
|
||||
t.Fatalf("unexpected api keys after delete: %#v", snap.APIKeys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreUpdateReconcilesLegacyKeyMutations(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
"api_keys":[{"key":"k1","name":"primary","remark":"prod"}],
|
||||
"accounts":[]
|
||||
}`)
|
||||
store := LoadStore()
|
||||
|
||||
if err := store.Update(func(cfg *Config) error {
|
||||
cfg.Keys = append(cfg.Keys, "k2")
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("legacy key update failed: %v", err)
|
||||
}
|
||||
|
||||
snap := store.Snapshot()
|
||||
if len(snap.Keys) != 2 || snap.Keys[0] != "k1" || snap.Keys[1] != "k2" {
|
||||
t.Fatalf("unexpected keys after legacy update: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after legacy update: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("metadata for preserved key was lost: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Key != "k2" || snap.APIKeys[1].Name != "" || snap.APIKeys[1].Remark != "" {
|
||||
t.Fatalf("new legacy key should stay metadata-free: %#v", snap.APIKeys[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCredentialsPrefersStructuredAPIKeys(t *testing.T) {
|
||||
cfg := Config{
|
||||
Keys: []string{"legacy-key"},
|
||||
APIKeys: []APIKey{
|
||||
{Key: "structured-key", Name: "primary", Remark: "prod"},
|
||||
},
|
||||
}
|
||||
cfg.NormalizeCredentials()
|
||||
|
||||
if len(cfg.Keys) != 1 || cfg.Keys[0] != "structured-key" {
|
||||
t.Fatalf("unexpected normalized keys: %#v", cfg.Keys)
|
||||
}
|
||||
if len(cfg.APIKeys) != 1 {
|
||||
t.Fatalf("unexpected normalized api keys: %#v", cfg.APIKeys)
|
||||
}
|
||||
if cfg.APIKeys[0].Key != "structured-key" || cfg.APIKeys[0].Name != "primary" || cfg.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("unexpected structured api key metadata: %#v", cfg.APIKeys[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreClaudeMapping(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`)
|
||||
store := LoadStore()
|
||||
|
||||
158
internal/config/credentials.go
Normal file
158
internal/config/credentials.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *Config) ReconcileCredentials(base Config) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
currKeys := normalizeKeys(c.Keys)
|
||||
currAPIKeys := normalizeAPIKeys(c.APIKeys)
|
||||
baseKeys := normalizeKeys(base.Keys)
|
||||
baseAPIKeys := normalizeAPIKeys(base.APIKeys)
|
||||
|
||||
keysChanged := !slices.Equal(currKeys, baseKeys)
|
||||
apiKeysChanged := !equalAPIKeys(currAPIKeys, baseAPIKeys)
|
||||
|
||||
if keysChanged && !apiKeysChanged {
|
||||
c.APIKeys = apiKeysFromStrings(currKeys, apiKeyMap(baseAPIKeys))
|
||||
} else {
|
||||
c.APIKeys = currAPIKeys
|
||||
}
|
||||
c.Keys = apiKeysToStrings(c.APIKeys)
|
||||
}
|
||||
|
||||
func normalizeKeys(keys []string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(keys))
|
||||
seen := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, key)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeAPIKeys(items []APIKey) []APIKey {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]APIKey, 0, len(items))
|
||||
seen := make(map[string]struct{}, len(items))
|
||||
for _, item := range items {
|
||||
key := strings.TrimSpace(item.Key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, APIKey{
|
||||
Key: key,
|
||||
Name: strings.TrimSpace(item.Name),
|
||||
Remark: strings.TrimSpace(item.Remark),
|
||||
})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func apiKeysFromStrings(keys []string, meta map[string]APIKey) []APIKey {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]APIKey, 0, len(keys))
|
||||
seen := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
if item, ok := meta[key]; ok {
|
||||
out = append(out, APIKey{
|
||||
Key: key,
|
||||
Name: strings.TrimSpace(item.Name),
|
||||
Remark: strings.TrimSpace(item.Remark),
|
||||
})
|
||||
continue
|
||||
}
|
||||
out = append(out, APIKey{Key: key})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func apiKeysToStrings(items []APIKey) []string {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
key := strings.TrimSpace(item.Key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func apiKeyMap(items []APIKey) map[string]APIKey {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]APIKey, len(items))
|
||||
for _, item := range items {
|
||||
key := strings.TrimSpace(item.Key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := out[key]; ok {
|
||||
continue
|
||||
}
|
||||
out[key] = APIKey{
|
||||
Key: key,
|
||||
Name: strings.TrimSpace(item.Name),
|
||||
Remark: strings.TrimSpace(item.Remark),
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func equalAPIKeys(a, b []APIKey) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
return slices.EqualFunc(a, b, func(x, y APIKey) bool {
|
||||
return strings.TrimSpace(x.Key) == strings.TrimSpace(y.Key) &&
|
||||
strings.TrimSpace(x.Name) == strings.TrimSpace(y.Name) &&
|
||||
strings.TrimSpace(x.Remark) == strings.TrimSpace(y.Remark)
|
||||
})
|
||||
}
|
||||
@@ -37,6 +37,10 @@ func RawStreamSampleRoot() string {
|
||||
return ResolvePath("DS2API_RAW_STREAM_SAMPLE_ROOT", "tests/raw_stream_samples")
|
||||
}
|
||||
|
||||
func ChatHistoryPath() string {
|
||||
return ResolvePath("DS2API_CHAT_HISTORY_PATH", "data/chat_history.json")
|
||||
}
|
||||
|
||||
func StaticAdminDir() string {
|
||||
return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin")
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ func LoadStoreWithError() (*Store, error) {
|
||||
|
||||
func loadStore() (*Store, error) {
|
||||
cfg, fromEnv, err := loadConfig()
|
||||
cfg.NormalizeCredentials()
|
||||
if validateErr := ValidateConfig(cfg); validateErr != nil {
|
||||
err = errors.Join(err, validateErr)
|
||||
}
|
||||
@@ -112,6 +113,7 @@ func loadConfigFromFile(path string) (Config, error) {
|
||||
if err := json.Unmarshal(content, &cfg); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
cfg.NormalizeCredentials()
|
||||
cfg.DropInvalidAccounts()
|
||||
if strings.Contains(string(content), `"test_status"`) && !IsVercel() {
|
||||
if b, err := json.MarshalIndent(cfg, "", " "); err == nil {
|
||||
@@ -207,6 +209,7 @@ func (s *Store) UpdateAccountToken(identifier, token string) error {
|
||||
func (s *Store) Replace(cfg Config) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cfg.NormalizeCredentials()
|
||||
s.cfg = cfg.Clone()
|
||||
s.rebuildIndexes()
|
||||
return s.saveLocked()
|
||||
@@ -215,10 +218,13 @@ func (s *Store) Replace(cfg Config) error {
|
||||
func (s *Store) Update(mutator func(*Config) error) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cfg := s.cfg.Clone()
|
||||
base := s.cfg.Clone()
|
||||
cfg := base.Clone()
|
||||
if err := mutator(&cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.ReconcileCredentials(base)
|
||||
cfg.NormalizeCredentials()
|
||||
s.cfg = cfg
|
||||
s.rebuildIndexes()
|
||||
return s.saveLocked()
|
||||
|
||||
@@ -174,3 +174,21 @@ func (s *Store) RuntimeTokenRefreshIntervalHours() int {
|
||||
func (s *Store) AutoDeleteSessions() bool {
|
||||
return s.AutoDeleteMode() != "none"
|
||||
}
|
||||
|
||||
func (s *Store) HistorySplitEnabled() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.HistorySplit.Enabled == nil {
|
||||
return true
|
||||
}
|
||||
return *s.cfg.HistorySplit.Enabled
|
||||
}
|
||||
|
||||
func (s *Store) HistorySplitTriggerAfterTurns() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.HistorySplit.TriggerAfterTurns == nil || *s.cfg.HistorySplit.TriggerAfterTurns <= 0 {
|
||||
return 1
|
||||
}
|
||||
return *s.cfg.HistorySplit.TriggerAfterTurns
|
||||
}
|
||||
|
||||
27
internal/config/store_accessors_test.go
Normal file
27
internal/config/store_accessors_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestStoreHistorySplitAccessors(t *testing.T) {
|
||||
store := &Store{cfg: Config{}}
|
||||
if !store.HistorySplitEnabled() {
|
||||
t.Fatal("expected history split enabled by default")
|
||||
}
|
||||
if got := store.HistorySplitTriggerAfterTurns(); got != 1 {
|
||||
t.Fatalf("default history split trigger_after_turns=%d want=1", got)
|
||||
}
|
||||
|
||||
enabled := false
|
||||
turns := 3
|
||||
store.cfg.HistorySplit = HistorySplitConfig{
|
||||
Enabled: &enabled,
|
||||
TriggerAfterTurns: &turns,
|
||||
}
|
||||
|
||||
if store.HistorySplitEnabled() {
|
||||
t.Fatal("expected history split disabled after override")
|
||||
}
|
||||
if got := store.HistorySplitTriggerAfterTurns(); got != 3 {
|
||||
t.Fatalf("history split trigger_after_turns=%d want=3", got)
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,9 @@ func ValidateConfig(c Config) error {
|
||||
if err := ValidateAutoDeleteConfig(c.AutoDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateHistorySplitConfig(c.HistorySplit); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateAccountProxyReferences(c.Accounts, c.Proxies); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -111,6 +114,15 @@ func ValidateAutoDeleteConfig(autoDelete AutoDeleteConfig) error {
|
||||
return ValidateAutoDeleteMode(autoDelete.Mode)
|
||||
}
|
||||
|
||||
func ValidateHistorySplitConfig(historySplit HistorySplitConfig) error {
|
||||
if historySplit.TriggerAfterTurns != nil {
|
||||
if err := ValidateIntRange("history_split.trigger_after_turns", *historySplit.TriggerAfterTurns, 1, 1000, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateIntRange(name string, value, min, max int, required bool) error {
|
||||
if value == 0 && !required {
|
||||
return nil
|
||||
|
||||
@@ -39,6 +39,13 @@ func TestValidateConfigRejectsInvalidValues(t *testing.T) {
|
||||
cfg: Config{AutoDelete: AutoDeleteConfig{Mode: "maybe"}},
|
||||
want: "auto_delete.mode",
|
||||
},
|
||||
{
|
||||
name: "history split",
|
||||
cfg: Config{HistorySplit: HistorySplitConfig{
|
||||
TriggerAfterTurns: intPtr(0),
|
||||
}},
|
||||
want: "history_split.trigger_after_turns",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -59,3 +66,5 @@ func TestValidateConfigAcceptsLegacyAutoDeleteSessions(t *testing.T) {
|
||||
t.Fatalf("expected legacy auto_delete.sessions config to remain valid, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func intPtr(v int) *int { return &v }
|
||||
|
||||
@@ -2,32 +2,6 @@ package claude
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildMessageResponseDetectsToolCallsFromThinkingFallback(t *testing.T) {
|
||||
resp := BuildMessageResponse(
|
||||
"msg_1",
|
||||
"claude-sonnet-4-5",
|
||||
[]any{map[string]any{"role": "user", "content": "hi"}},
|
||||
`{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`,
|
||||
"",
|
||||
[]string{"search"},
|
||||
)
|
||||
|
||||
if resp["stop_reason"] != "tool_use" {
|
||||
t.Fatalf("expected stop_reason=tool_use, got=%#v", resp["stop_reason"])
|
||||
}
|
||||
content, _ := resp["content"].([]map[string]any)
|
||||
if len(content) < 2 {
|
||||
t.Fatalf("expected thinking + tool_use content blocks, got=%#v", resp["content"])
|
||||
}
|
||||
last := content[len(content)-1]
|
||||
if last["type"] != "tool_use" {
|
||||
t.Fatalf("expected last content block tool_use, got=%#v", last["type"])
|
||||
}
|
||||
if last["name"] != "search" {
|
||||
t.Fatalf("expected tool name search, got=%#v", last["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessageResponseSkipsThinkingFallbackWhenFinalTextExists(t *testing.T) {
|
||||
resp := BuildMessageResponse(
|
||||
"msg_1",
|
||||
|
||||
@@ -1,75 +1,10 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"",
|
||||
`{"tool_calls":[{"name":"search","input":{"q":"golang"}}]}`,
|
||||
[]string{"search"},
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text to be hidden for tool calls, got %q", outputText)
|
||||
}
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
||||
}
|
||||
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected first output item type function_call, got %#v", first["type"])
|
||||
}
|
||||
if first["call_id"] == "" {
|
||||
t.Fatalf("expected function_call item to have call_id, got %#v", first)
|
||||
}
|
||||
if first["name"] != "search" {
|
||||
t.Fatalf("unexpected function name: %#v", first["name"])
|
||||
}
|
||||
argsRaw, _ := first["arguments"].(string)
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(argsRaw), &args); err != nil {
|
||||
t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err)
|
||||
}
|
||||
if args["q"] != "golang" {
|
||||
t.Fatalf("unexpected arguments: %#v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectPromotesMixedProseToolPayloadToFunctionCall(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"",
|
||||
`示例格式:{"tool_calls":[{"name":"search","input":{"q":"golang"}}]},但这条是普通回答。`,
|
||||
[]string{"search"},
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
|
||||
}
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected one function_call output item, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected function_call output type, got %#v", first["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
|
||||
@@ -18,6 +18,7 @@ const {
|
||||
normalizePreparedToolNames,
|
||||
boolDefaultTrue,
|
||||
filterIncrementalToolCallDeltasByAllowed,
|
||||
resetStreamToolCallState,
|
||||
} = require('./toolcall_policy');
|
||||
const {
|
||||
estimateTokens,
|
||||
@@ -115,6 +116,7 @@ module.exports.__test = {
|
||||
normalizePreparedToolNames,
|
||||
boolDefaultTrue,
|
||||
filterIncrementalToolCallDeltasByAllowed,
|
||||
resetStreamToolCallState,
|
||||
estimateTokens,
|
||||
buildUsage,
|
||||
filterLeakedContentFilterParts,
|
||||
|
||||
@@ -7,6 +7,53 @@ const {
|
||||
SKIP_EXACT_PATHS,
|
||||
} = require('../shared/deepseek-constants');
|
||||
|
||||
|
||||
|
||||
function stripThinkTags(text) {
|
||||
if (typeof text !== 'string' || !text) {
|
||||
return text;
|
||||
}
|
||||
return text.replace(/<\/?\s*think\s*>/gi, '');
|
||||
}
|
||||
|
||||
function splitThinkingParts(parts) {
|
||||
const out = [];
|
||||
let thinkingDone = false;
|
||||
for (const p of parts) {
|
||||
if (!p) continue;
|
||||
if (thinkingDone && p.type === 'thinking') {
|
||||
const cleaned = stripThinkTags(p.text);
|
||||
if (cleaned) {
|
||||
out.push({ text: cleaned, type: 'text' });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (p.type !== 'thinking') {
|
||||
const cleaned = stripThinkTags(p.text);
|
||||
if (cleaned) {
|
||||
out.push({ text: cleaned, type: p.type });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
const match = /<\/\s*think\s*>/i.exec(p.text);
|
||||
if (!match) {
|
||||
out.push(p);
|
||||
continue;
|
||||
}
|
||||
thinkingDone = true;
|
||||
const before = p.text.substring(0, match.index);
|
||||
let after = p.text.substring(match.index + match[0].length);
|
||||
if (before) {
|
||||
out.push({ text: before, type: 'thinking' });
|
||||
}
|
||||
after = stripThinkTags(after);
|
||||
if (after) {
|
||||
out.push({ text: after, type: 'text' });
|
||||
}
|
||||
}
|
||||
return { parts: out, transitioned: thinkingDone };
|
||||
}
|
||||
|
||||
function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenceMarkers = true) {
|
||||
if (!chunk || typeof chunk !== 'object') {
|
||||
return {
|
||||
@@ -147,7 +194,11 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
|
||||
|
||||
let partType = 'text';
|
||||
if (pathValue === 'response/thinking_content') {
|
||||
partType = 'thinking';
|
||||
if (newType === 'text') {
|
||||
partType = 'text';
|
||||
} else {
|
||||
partType = 'thinking';
|
||||
}
|
||||
} else if (pathValue === 'response/content') {
|
||||
partType = 'text';
|
||||
} else if (pathValue.includes('response/fragments') && pathValue.includes('/content')) {
|
||||
@@ -186,9 +237,16 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
|
||||
if (content) {
|
||||
parts.push({ text: content, type: partType });
|
||||
}
|
||||
|
||||
let resolvedParts = filterLeakedContentFilterParts(parts);
|
||||
const splitResult = splitThinkingParts(resolvedParts);
|
||||
if (splitResult.transitioned) {
|
||||
newType = 'text';
|
||||
}
|
||||
|
||||
return {
|
||||
parsed: true,
|
||||
parts: filterLeakedContentFilterParts(parts),
|
||||
parts: splitResult.parts,
|
||||
finished: false,
|
||||
contentFilter: false,
|
||||
errorMessage: '',
|
||||
@@ -213,9 +271,16 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
|
||||
};
|
||||
}
|
||||
parts.push(...extracted.parts);
|
||||
|
||||
let resolvedParts = filterLeakedContentFilterParts(parts);
|
||||
const splitResult = splitThinkingParts(resolvedParts);
|
||||
if (splitResult.transitioned) {
|
||||
newType = 'text';
|
||||
}
|
||||
|
||||
return {
|
||||
parsed: true,
|
||||
parts: filterLeakedContentFilterParts(parts),
|
||||
parts: splitResult.parts,
|
||||
finished: false,
|
||||
contentFilter: false,
|
||||
errorMessage: '',
|
||||
@@ -249,9 +314,16 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let resolvedParts = filterLeakedContentFilterParts(parts);
|
||||
const splitResult = splitThinkingParts(resolvedParts);
|
||||
if (splitResult.transitioned) {
|
||||
newType = 'text';
|
||||
}
|
||||
|
||||
return {
|
||||
parsed: true,
|
||||
parts: filterLeakedContentFilterParts(parts),
|
||||
parts: splitResult.parts,
|
||||
finished: false,
|
||||
contentFilter: false,
|
||||
errorMessage: '',
|
||||
@@ -546,4 +618,5 @@ module.exports = {
|
||||
isFragmentStatusPath,
|
||||
isCitation,
|
||||
stripReferenceMarkers: stripReferenceMarkersText,
|
||||
stripThinkTags,
|
||||
};
|
||||
|
||||
@@ -98,6 +98,15 @@ function filterIncrementalToolCallDeltasByAllowed(deltas, allowedNames, seenName
|
||||
return out;
|
||||
}
|
||||
|
||||
function resetStreamToolCallState(idStore, seenNames) {
|
||||
if (idStore instanceof Map) {
|
||||
idStore.clear();
|
||||
}
|
||||
if (seenNames instanceof Map) {
|
||||
seenNames.clear();
|
||||
}
|
||||
}
|
||||
|
||||
function ensureStreamToolCallID(idStore, index) {
|
||||
const key = Number.isInteger(index) ? index : 0;
|
||||
const existing = idStore.get(key);
|
||||
@@ -135,4 +144,5 @@ module.exports = {
|
||||
boolDefaultTrue,
|
||||
formatIncrementalToolCallDeltas,
|
||||
filterIncrementalToolCallDeltasByAllowed,
|
||||
resetStreamToolCallState,
|
||||
};
|
||||
|
||||
@@ -18,6 +18,7 @@ const {
|
||||
formatIncrementalToolCallDeltas,
|
||||
filterIncrementalToolCallDeltasByAllowed,
|
||||
boolDefaultTrue,
|
||||
resetStreamToolCallState,
|
||||
} = require('./toolcall_policy');
|
||||
const { createChatCompletionEmitter } = require('./stream_emitter');
|
||||
const {
|
||||
@@ -161,6 +162,7 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
||||
if (evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0) {
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
|
||||
resetStreamToolCallState(streamToolCallIDs, streamToolNames);
|
||||
continue;
|
||||
}
|
||||
if (evt.text) {
|
||||
@@ -283,6 +285,7 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
||||
if (evt.type === 'tool_calls') {
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
|
||||
resetStreamToolCallState(streamToolCallIDs, streamToolNames);
|
||||
continue;
|
||||
}
|
||||
if (evt.text) {
|
||||
|
||||
@@ -4,15 +4,10 @@ const {
|
||||
toStringSafe,
|
||||
} = require('./state');
|
||||
const {
|
||||
buildToolCallCandidates,
|
||||
parseToolCallsPayload,
|
||||
parseMarkupToolCalls,
|
||||
parseTextKVToolCalls,
|
||||
stripFencedCodeBlocks,
|
||||
} = require('./parse_payload');
|
||||
const { TOOL_SEGMENT_KEYWORDS } = require('./tool-keywords');
|
||||
|
||||
const TOOL_NAME_LOOSE_PATTERN = /[^a-z0-9]+/g;
|
||||
const TOOL_MARKUP_PREFIXES = ['<tool_call', '<function_call', '<invoke'];
|
||||
|
||||
function extractToolNames(tools) {
|
||||
@@ -51,47 +46,12 @@ function parseToolCallsDetailed(text, toolNames) {
|
||||
return result;
|
||||
}
|
||||
|
||||
const candidates = buildToolCallCandidates(normalized);
|
||||
for (const c of candidates) {
|
||||
if (!isLikelyJSONToolPayloadCandidate(c)) {
|
||||
continue;
|
||||
}
|
||||
const jsonParsed = parseToolCallsPayload(c);
|
||||
if (jsonParsed.length === 0) {
|
||||
continue;
|
||||
}
|
||||
result.sawToolCallSyntax = true;
|
||||
const filteredJSON = filterToolCallsDetailed(jsonParsed, toolNames);
|
||||
result.calls = filteredJSON.calls;
|
||||
result.rejectedToolNames = filteredJSON.rejectedToolNames;
|
||||
result.rejectedByPolicy = filteredJSON.rejectedToolNames.length > 0 && filteredJSON.calls.length === 0;
|
||||
// XML markup parsing only.
|
||||
const parsed = parseMarkupToolCalls(normalized);
|
||||
if (parsed.length === 0) {
|
||||
return result;
|
||||
}
|
||||
let parsed = [];
|
||||
for (const c of candidates) {
|
||||
parsed = parseMarkupToolCalls(c);
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseToolCallsPayload(c);
|
||||
}
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseTextKVToolCalls(c);
|
||||
}
|
||||
if (parsed.length > 0) {
|
||||
result.sawToolCallSyntax = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseMarkupToolCalls(normalized);
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseTextKVToolCalls(normalized);
|
||||
if (parsed.length === 0) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
result.sawToolCallSyntax = true;
|
||||
}
|
||||
|
||||
result.sawToolCallSyntax = true;
|
||||
const filtered = filterToolCallsDetailed(parsed, toolNames);
|
||||
result.calls = filtered.calls;
|
||||
result.rejectedToolNames = filtered.rejectedToolNames;
|
||||
@@ -113,43 +73,11 @@ function parseStandaloneToolCallsDetailed(text, toolNames) {
|
||||
if (shouldSkipToolCallParsingForCodeFenceExample(trimmed)) {
|
||||
return result;
|
||||
}
|
||||
const candidates = buildToolCallCandidates(trimmed);
|
||||
let parsed = [];
|
||||
for (const c of candidates) {
|
||||
if (!isLikelyJSONToolPayloadCandidate(c)) {
|
||||
continue;
|
||||
}
|
||||
parsed = parseToolCallsPayload(c);
|
||||
if (parsed.length === 0) {
|
||||
continue;
|
||||
}
|
||||
result.sawToolCallSyntax = true;
|
||||
const filteredJSON = filterToolCallsDetailed(parsed, toolNames);
|
||||
result.calls = filteredJSON.calls;
|
||||
result.rejectedToolNames = filteredJSON.rejectedToolNames;
|
||||
result.rejectedByPolicy = filteredJSON.rejectedToolNames.length > 0 && filteredJSON.calls.length === 0;
|
||||
return result;
|
||||
}
|
||||
for (const c of candidates) {
|
||||
parsed = parseMarkupToolCalls(c);
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseToolCallsPayload(c);
|
||||
}
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseTextKVToolCalls(c);
|
||||
}
|
||||
if (parsed.length > 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// XML markup parsing only.
|
||||
const parsed = parseMarkupToolCalls(trimmed);
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseMarkupToolCalls(trimmed);
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseTextKVToolCalls(trimmed);
|
||||
if (parsed.length === 0) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
result.sawToolCallSyntax = true;
|
||||
@@ -183,41 +111,9 @@ function filterToolCallsDetailed(parsed, toolNames) {
|
||||
return { calls, rejectedToolNames: [] };
|
||||
}
|
||||
|
||||
function resolveAllowedToolName(name, allowed, allowedCanonical) {
|
||||
const normalizedName = toStringSafe(name).trim();
|
||||
if (!normalizedName) {
|
||||
return '';
|
||||
}
|
||||
if (allowed.has(normalizedName)) {
|
||||
return normalizedName;
|
||||
}
|
||||
const lower = normalizedName.toLowerCase();
|
||||
if (allowedCanonical.has(lower)) {
|
||||
return allowedCanonical.get(lower);
|
||||
}
|
||||
const idx = lower.lastIndexOf('.');
|
||||
if (idx >= 0 && idx < lower.length - 1) {
|
||||
const tail = lower.slice(idx + 1);
|
||||
if (allowedCanonical.has(tail)) {
|
||||
return allowedCanonical.get(tail);
|
||||
}
|
||||
}
|
||||
const loose = lower.replace(TOOL_NAME_LOOSE_PATTERN, '');
|
||||
if (!loose) {
|
||||
return '';
|
||||
}
|
||||
for (const [candidateLower, canonical] of allowedCanonical.entries()) {
|
||||
if (candidateLower.replace(TOOL_NAME_LOOSE_PATTERN, '') === loose) {
|
||||
return canonical;
|
||||
}
|
||||
}
|
||||
return '';
|
||||
}
|
||||
|
||||
function looksLikeToolCallSyntax(text) {
|
||||
const lower = toStringSafe(text).toLowerCase();
|
||||
return TOOL_SEGMENT_KEYWORDS.some((kw) => lower.includes(kw))
|
||||
|| TOOL_MARKUP_PREFIXES.some((prefix) => lower.includes(prefix));
|
||||
return TOOL_MARKUP_PREFIXES.some((prefix) => lower.includes(prefix));
|
||||
}
|
||||
|
||||
function shouldSkipToolCallParsingForCodeFenceExample(text) {
|
||||
@@ -228,21 +124,6 @@ function shouldSkipToolCallParsingForCodeFenceExample(text) {
|
||||
return !looksLikeToolCallSyntax(stripped);
|
||||
}
|
||||
|
||||
function isLikelyJSONToolPayloadCandidate(text) {
|
||||
const trimmed = toStringSafe(text).trim();
|
||||
if (!trimmed) {
|
||||
return false;
|
||||
}
|
||||
if (!(trimmed.startsWith('{') || trimmed.startsWith('['))) {
|
||||
return false;
|
||||
}
|
||||
const lower = trimmed.toLowerCase();
|
||||
return lower.includes('tool_calls')
|
||||
|| lower.includes('"function"')
|
||||
|| lower.includes('functioncall')
|
||||
|| lower.includes('"tool_use"');
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
extractToolNames,
|
||||
parseToolCalls,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
'use strict';
|
||||
|
||||
const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
|
||||
const TOOL_CALL_MARKUP_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?(tool_call|function_call|invoke)\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
|
||||
const TOOL_CALL_MARKUP_SELFCLOSE_PATTERN = /<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)\/>/gi;
|
||||
const TOOL_CALL_MARKUP_KV_PATTERN = /<(?:[a-z0-9_:-]+:)?([a-z0-9_.-]+)\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
|
||||
@@ -20,14 +19,12 @@ const TOOL_CALL_MARKUP_ARGS_PATTERNS = [
|
||||
/<(?:[a-z0-9_:-]+:)?args\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?args>/i,
|
||||
/<(?:[a-z0-9_:-]+:)?params\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?params>/i,
|
||||
];
|
||||
const TEXT_KV_NAME_PATTERN = /function\.name:\s*([a-zA-Z0-9_.-]+)/gi;
|
||||
const CDATA_PATTERN = /^<!\[CDATA\[([\s\S]*?)]]>$/i;
|
||||
const HTML_ENTITIES_PATTERN = /&[a-z0-9#]+;/gi;
|
||||
|
||||
const {
|
||||
toStringSafe,
|
||||
} = require('./state');
|
||||
const {
|
||||
extractJSONObjectFrom,
|
||||
} = require('./jsonscan');
|
||||
|
||||
function stripFencedCodeBlocks(text) {
|
||||
const t = typeof text === 'string' ? text : '';
|
||||
@@ -37,138 +34,6 @@ function stripFencedCodeBlocks(text) {
|
||||
return t.replace(/```[\s\S]*?```/g, ' ');
|
||||
}
|
||||
|
||||
function buildToolCallCandidates(text) {
|
||||
const trimmed = toStringSafe(text);
|
||||
const candidates = [trimmed];
|
||||
|
||||
const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
|
||||
for (const block of fenced) {
|
||||
const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
|
||||
if (m && m[1]) {
|
||||
candidates.push(toStringSafe(m[1]));
|
||||
}
|
||||
}
|
||||
|
||||
for (const candidate of extractToolCallObjects(trimmed)) {
|
||||
candidates.push(toStringSafe(candidate));
|
||||
}
|
||||
|
||||
const first = trimmed.indexOf('{');
|
||||
const last = trimmed.lastIndexOf('}');
|
||||
if (first >= 0 && last > first) {
|
||||
candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
|
||||
}
|
||||
const firstArr = trimmed.indexOf('[');
|
||||
const lastArr = trimmed.lastIndexOf(']');
|
||||
if (firstArr >= 0 && lastArr > firstArr) {
|
||||
candidates.push(toStringSafe(trimmed.slice(firstArr, lastArr + 1)));
|
||||
}
|
||||
|
||||
const m = trimmed.match(TOOL_CALL_PATTERN);
|
||||
if (m && m[1]) {
|
||||
candidates.push(`{"tool_calls":[${m[1]}]}`);
|
||||
}
|
||||
|
||||
return [...new Set(candidates.filter(Boolean))];
|
||||
}
|
||||
|
||||
function extractToolCallObjects(text) {
|
||||
const raw = toStringSafe(text);
|
||||
if (!raw) {
|
||||
return [];
|
||||
}
|
||||
const lower = raw.toLowerCase();
|
||||
const out = [];
|
||||
let offset = 0;
|
||||
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
const idxToolCalls = lower.indexOf('tool_calls', offset);
|
||||
const idxFunction = lower.indexOf('"function"', offset);
|
||||
const idxFunctionCall = lower.indexOf('functioncall', offset);
|
||||
const idxToolUse = lower.indexOf('"tool_use"', offset);
|
||||
let idx = -1;
|
||||
let matched = '';
|
||||
if (idxToolCalls >= 0 && (idxFunction < 0 || idxToolCalls <= idxFunction)) {
|
||||
idx = idxToolCalls;
|
||||
matched = 'tool_calls';
|
||||
} else if (idxFunction >= 0) {
|
||||
idx = idxFunction;
|
||||
matched = '"function"';
|
||||
}
|
||||
if (idxFunctionCall >= 0 && (idx < 0 || idxFunctionCall < idx)) {
|
||||
idx = idxFunctionCall;
|
||||
matched = 'functioncall';
|
||||
}
|
||||
if (idxToolUse >= 0 && (idx < 0 || idxToolUse < idx)) {
|
||||
idx = idxToolUse;
|
||||
matched = '"tool_use"';
|
||||
}
|
||||
if (idx < 0) {
|
||||
break;
|
||||
}
|
||||
let start = raw.slice(0, idx).lastIndexOf('{');
|
||||
while (start >= 0) {
|
||||
const obj = extractJSONObjectFrom(raw, start);
|
||||
if (obj.ok) {
|
||||
out.push(raw.slice(start, obj.end).trim());
|
||||
// Ensure forward progress even when the matched keyword is outside
|
||||
// the extracted JSON object (e.g. closing XML wrapper tags containing
|
||||
// "tool_calls" after an earlier JSON arguments object).
|
||||
offset = Math.max(obj.end, idx + matched.length);
|
||||
idx = -1;
|
||||
break;
|
||||
}
|
||||
start = raw.slice(0, start).lastIndexOf('{');
|
||||
}
|
||||
if (idx >= 0) {
|
||||
offset = idx + matched.length;
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseToolCallsPayload(payload) {
|
||||
let decoded;
|
||||
try {
|
||||
decoded = JSON.parse(payload);
|
||||
} catch (_err) {
|
||||
return [];
|
||||
}
|
||||
|
||||
if (Array.isArray(decoded)) {
|
||||
return parseToolCallList(decoded);
|
||||
}
|
||||
if (!decoded || typeof decoded !== 'object') {
|
||||
return [];
|
||||
}
|
||||
if (decoded.tool_calls) {
|
||||
if (isLikelyChatMessageEnvelope(decoded)) {
|
||||
return [];
|
||||
}
|
||||
return parseToolCallList(decoded.tool_calls);
|
||||
}
|
||||
|
||||
const one = parseToolCallItem(decoded);
|
||||
return one ? [one] : [];
|
||||
}
|
||||
|
||||
function isLikelyChatMessageEnvelope(value) {
|
||||
if (!value || typeof value !== 'object' || Array.isArray(value)) {
|
||||
return false;
|
||||
}
|
||||
if (!Object.prototype.hasOwnProperty.call(value, 'tool_calls')) {
|
||||
return false;
|
||||
}
|
||||
const role = toStringSafe(value.role).trim().toLowerCase();
|
||||
if (role === 'assistant' || role === 'tool' || role === 'user' || role === 'system') {
|
||||
return true;
|
||||
}
|
||||
return Object.prototype.hasOwnProperty.call(value, 'tool_call_id')
|
||||
|| Object.prototype.hasOwnProperty.call(value, 'content');
|
||||
}
|
||||
|
||||
function parseMarkupToolCalls(text) {
|
||||
const raw = toStringSafe(text).trim();
|
||||
if (!raw) {
|
||||
@@ -190,51 +55,20 @@ function parseMarkupToolCalls(text) {
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseTextKVToolCalls(text) {
|
||||
const raw = toStringSafe(text);
|
||||
if (!raw) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
const matches = [...raw.matchAll(TEXT_KV_NAME_PATTERN)];
|
||||
if (matches.length === 0) {
|
||||
return out;
|
||||
}
|
||||
for (let i = 0; i < matches.length; i += 1) {
|
||||
const match = matches[i];
|
||||
const name = toStringSafe(match[1]).trim();
|
||||
if (!name) {
|
||||
continue;
|
||||
}
|
||||
const nameEnd = match.index + toStringSafe(match[0]).length;
|
||||
const searchEnd = i + 1 < matches.length ? matches[i + 1].index : raw.length;
|
||||
const searchArea = raw.slice(nameEnd, searchEnd);
|
||||
const argIdx = searchArea.indexOf('function.arguments:');
|
||||
if (argIdx < 0) {
|
||||
continue;
|
||||
}
|
||||
const argStart = nameEnd + argIdx + 'function.arguments:'.length;
|
||||
const bracePos = raw.slice(argStart, searchEnd).indexOf('{');
|
||||
if (bracePos < 0) {
|
||||
continue;
|
||||
}
|
||||
const objStart = argStart + bracePos;
|
||||
const obj = extractJSONObjectFrom(raw, objStart);
|
||||
if (!obj.ok) {
|
||||
continue;
|
||||
}
|
||||
out.push({
|
||||
name,
|
||||
input: parseToolCallInput(raw.slice(objStart, obj.end)),
|
||||
});
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseMarkupSingleToolCall(attrs, inner) {
|
||||
const embedded = parseToolCallsPayload(inner);
|
||||
if (embedded.length > 0) {
|
||||
return embedded[0];
|
||||
// Try inline JSON parse for the inner content.
|
||||
if (inner) {
|
||||
try {
|
||||
const decoded = JSON.parse(inner);
|
||||
if (decoded && typeof decoded === 'object' && !Array.isArray(decoded) && decoded.name) {
|
||||
return {
|
||||
name: toStringSafe(decoded.name),
|
||||
input: decoded.input && typeof decoded.input === 'object' && !Array.isArray(decoded.input) ? decoded.input : {},
|
||||
};
|
||||
}
|
||||
} catch (_err) {
|
||||
// Not JSON, continue with markup parsing.
|
||||
}
|
||||
}
|
||||
let name = '';
|
||||
const attrMatch = attrs.match(TOOL_CALL_MARKUP_ATTR_PATTERN);
|
||||
@@ -242,7 +76,7 @@ function parseMarkupSingleToolCall(attrs, inner) {
|
||||
name = toStringSafe(attrMatch[2]).trim();
|
||||
}
|
||||
if (!name) {
|
||||
name = stripTagText(findMarkupTagValue(inner, TOOL_CALL_MARKUP_NAME_PATTERNS));
|
||||
name = extractRawTagValue(findMarkupTagValue(inner, TOOL_CALL_MARKUP_NAME_PATTERNS));
|
||||
}
|
||||
if (!name) {
|
||||
return null;
|
||||
@@ -266,15 +100,21 @@ function parseMarkupInput(raw) {
|
||||
if (!s) {
|
||||
return {};
|
||||
}
|
||||
const parsed = parseToolCallInput(s);
|
||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed) && Object.keys(parsed).length > 0) {
|
||||
return parsed;
|
||||
}
|
||||
// Prioritize XML-style KV tags (e.g., <arg>val</arg>)
|
||||
const kv = parseMarkupKVObject(s);
|
||||
if (Object.keys(kv).length > 0) {
|
||||
return kv;
|
||||
}
|
||||
return { _raw: stripTagText(s) };
|
||||
|
||||
// Fallback to JSON parsing
|
||||
const parsed = parseToolCallInput(s);
|
||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
||||
if (Object.keys(parsed).length > 0) {
|
||||
return parsed;
|
||||
}
|
||||
}
|
||||
|
||||
return { _raw: extractRawTagValue(s) };
|
||||
}
|
||||
|
||||
function parseMarkupKVObject(text) {
|
||||
@@ -288,19 +128,65 @@ function parseMarkupKVObject(text) {
|
||||
if (!key) {
|
||||
continue;
|
||||
}
|
||||
const valueRaw = stripTagText(m[2]);
|
||||
if (!valueRaw) {
|
||||
const value = parseMarkupValue(m[2]);
|
||||
if (value === undefined || value === null) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
out[key] = JSON.parse(valueRaw);
|
||||
} catch (_err) {
|
||||
out[key] = valueRaw;
|
||||
}
|
||||
appendMarkupValue(out, key, value);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseMarkupValue(raw) {
|
||||
const s = toStringSafe(extractRawTagValue(raw)).trim();
|
||||
if (!s) {
|
||||
return '';
|
||||
}
|
||||
|
||||
if (s.includes('<') && s.includes('>')) {
|
||||
const nested = parseMarkupInput(s);
|
||||
if (nested && typeof nested === 'object' && !Array.isArray(nested)) {
|
||||
if (isOnlyRawValue(nested)) {
|
||||
return toStringSafe(nested._raw);
|
||||
}
|
||||
return nested;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
return JSON.parse(s);
|
||||
} catch (_err) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
function extractRawTagValue(inner) {
|
||||
const s = toStringSafe(inner).trim();
|
||||
if (!s) {
|
||||
return '';
|
||||
}
|
||||
|
||||
// 1. Check for CDATA
|
||||
const cdataMatch = s.match(CDATA_PATTERN);
|
||||
if (cdataMatch && cdataMatch[1] !== undefined) {
|
||||
return cdataMatch[1];
|
||||
}
|
||||
|
||||
// 2. Fallback to unescaping standard HTML entities
|
||||
// Note: we avoid broad tag stripping here to preserve user content (like < symbols in code)
|
||||
return unescapeHtml(inner);
|
||||
}
|
||||
|
||||
function unescapeHtml(safe) {
|
||||
if (!safe) return '';
|
||||
return safe.replace(/&/g, '&')
|
||||
.replace(/</g, '<')
|
||||
.replace(/>/g, '>')
|
||||
.replace(/"/g, '"')
|
||||
.replace(/'/g, "'")
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
|
||||
function stripTagText(text) {
|
||||
return toStringSafe(text).replace(/<[^>]+>/g, ' ').trim();
|
||||
}
|
||||
@@ -309,80 +195,13 @@ function findMarkupTagValue(text, patterns) {
|
||||
const source = toStringSafe(text);
|
||||
for (const p of patterns) {
|
||||
const m = source.match(p);
|
||||
if (m && m[1]) {
|
||||
if (m && m[1] !== undefined) {
|
||||
return toStringSafe(m[1]);
|
||||
}
|
||||
}
|
||||
return '';
|
||||
}
|
||||
|
||||
function parseToolCallList(v) {
|
||||
if (!Array.isArray(v)) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
for (const item of v) {
|
||||
if (!item || typeof item !== 'object') {
|
||||
continue;
|
||||
}
|
||||
const one = parseToolCallItem(item);
|
||||
if (one) {
|
||||
out.push(one);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseToolCallItem(m) {
|
||||
let name = toStringSafe(m.name);
|
||||
let inputRaw = m.input;
|
||||
let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
|
||||
const fnCall = m.functionCall && typeof m.functionCall === 'object' ? m.functionCall : null;
|
||||
if (fnCall) {
|
||||
if (!name) {
|
||||
name = toStringSafe(fnCall.name);
|
||||
}
|
||||
if (!hasInput && Object.prototype.hasOwnProperty.call(fnCall, 'args')) {
|
||||
inputRaw = fnCall.args;
|
||||
hasInput = true;
|
||||
}
|
||||
if (!hasInput && Object.prototype.hasOwnProperty.call(fnCall, 'arguments')) {
|
||||
inputRaw = fnCall.arguments;
|
||||
hasInput = true;
|
||||
}
|
||||
}
|
||||
const fn = m.function && typeof m.function === 'object' ? m.function : null;
|
||||
|
||||
if (fn) {
|
||||
if (!name) {
|
||||
name = toStringSafe(fn.name);
|
||||
}
|
||||
if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
|
||||
inputRaw = fn.arguments;
|
||||
hasInput = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!hasInput) {
|
||||
for (const k of ['arguments', 'args', 'parameters', 'params']) {
|
||||
if (Object.prototype.hasOwnProperty.call(m, k)) {
|
||||
inputRaw = m[k];
|
||||
hasInput = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!name) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
name,
|
||||
input: parseToolCallInput(inputRaw),
|
||||
};
|
||||
}
|
||||
|
||||
function parseToolCallInput(v) {
|
||||
if (v == null) {
|
||||
return {};
|
||||
@@ -416,10 +235,28 @@ function parseToolCallInput(v) {
|
||||
return {};
|
||||
}
|
||||
|
||||
function appendMarkupValue(out, key, value) {
|
||||
if (Object.prototype.hasOwnProperty.call(out, key)) {
|
||||
const current = out[key];
|
||||
if (Array.isArray(current)) {
|
||||
current.push(value);
|
||||
return;
|
||||
}
|
||||
out[key] = [current, value];
|
||||
return;
|
||||
}
|
||||
out[key] = value;
|
||||
}
|
||||
|
||||
function isOnlyRawValue(obj) {
|
||||
if (!obj || typeof obj !== 'object' || Array.isArray(obj)) {
|
||||
return false;
|
||||
}
|
||||
const keys = Object.keys(obj);
|
||||
return keys.length === 1 && keys[0] === '_raw';
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
stripFencedCodeBlocks,
|
||||
buildToolCallCandidates,
|
||||
parseToolCallsPayload,
|
||||
parseMarkupToolCalls,
|
||||
parseTextKVToolCalls,
|
||||
};
|
||||
|
||||
@@ -42,8 +42,8 @@ function consumeXMLToolCapture(captured, toolNames, trimWrappingJSONFence) {
|
||||
suffix: trimmedFence.suffix,
|
||||
};
|
||||
}
|
||||
// XML tool syntax but failed to parse — consume to avoid leak.
|
||||
return { ready: true, prefix: prefixPart, calls: [], suffix: suffixPart };
|
||||
// If this block failed to become a tool call, pass it through as text.
|
||||
return { ready: true, prefix: prefixPart + xmlBlock, calls: [], suffix: suffixPart };
|
||||
}
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
@@ -79,22 +79,8 @@ function findPartialXMLToolTagStart(s) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
function looksLikeXMLToolTagFragment(s) {
|
||||
const trimmed = (s || '').trim();
|
||||
if (!trimmed) return false;
|
||||
const lower = trimmed.toLowerCase();
|
||||
const fragments = [
|
||||
'tool_calls>', 'tool_call>', '/tool_calls>', '/tool_call>',
|
||||
'function_calls>', 'function_call>', '/function_calls>', '/function_call>',
|
||||
'invoke>', '/invoke>', 'tool_use>', '/tool_use>',
|
||||
'tool_name>', '/tool_name>', 'parameters>', '/parameters>',
|
||||
];
|
||||
return fragments.some(f => lower.includes(f));
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
consumeXMLToolCapture,
|
||||
hasOpenXMLToolTag,
|
||||
findPartialXMLToolTagStart,
|
||||
looksLikeXMLToolTagFragment,
|
||||
};
|
||||
|
||||
@@ -4,18 +4,14 @@ const {
|
||||
noteText,
|
||||
insideCodeFenceWithState,
|
||||
} = require('./state');
|
||||
const { parseStandaloneToolCallsDetailed } = require('./parse');
|
||||
const { extractJSONObjectFrom, trimWrappingJSONFence } = require('./jsonscan');
|
||||
const { trimWrappingJSONFence } = require('./jsonscan');
|
||||
const {
|
||||
TOOL_SEGMENT_KEYWORDS,
|
||||
XML_TOOL_SEGMENT_TAGS,
|
||||
earliestKeywordIndex,
|
||||
} = require('./tool-keywords');
|
||||
const {
|
||||
consumeXMLToolCapture: consumeXMLToolCaptureImpl,
|
||||
hasOpenXMLToolTag,
|
||||
findPartialXMLToolTagStart,
|
||||
looksLikeXMLToolTagFragment,
|
||||
} = require('./sieve-xml');
|
||||
function processToolSieveChunk(state, chunk, toolNames) {
|
||||
if (!state) {
|
||||
@@ -80,7 +76,7 @@ function processToolSieveChunk(state, chunk, toolNames) {
|
||||
resetIncrementalToolState(state);
|
||||
continue;
|
||||
}
|
||||
const [safe, hold] = splitSafeContentForToolDetection(pending);
|
||||
const [safe, hold] = splitSafeContentForToolDetection(state, pending);
|
||||
if (!safe) {
|
||||
break;
|
||||
}
|
||||
@@ -117,54 +113,38 @@ function flushToolSieve(state, toolNames) {
|
||||
}
|
||||
} else if (state.capture) {
|
||||
const content = state.capture;
|
||||
if (!hasOpenXMLToolTag(content) && !looksLikeXMLToolTagFragment(content)) {
|
||||
noteText(state, content);
|
||||
events.push({ type: 'text', text: content });
|
||||
}
|
||||
noteText(state, content);
|
||||
events.push({ type: 'text', text: content });
|
||||
}
|
||||
state.capture = '';
|
||||
state.capturing = false;
|
||||
resetIncrementalToolState(state);
|
||||
}
|
||||
if (state.pending) {
|
||||
if (!hasOpenXMLToolTag(state.pending) && !looksLikeXMLToolTagFragment(state.pending)) {
|
||||
noteText(state, state.pending);
|
||||
events.push({ type: 'text', text: state.pending });
|
||||
}
|
||||
noteText(state, state.pending);
|
||||
events.push({ type: 'text', text: state.pending });
|
||||
state.pending = '';
|
||||
}
|
||||
return events;
|
||||
}
|
||||
|
||||
function splitSafeContentForToolDetection(s) {
|
||||
function splitSafeContentForToolDetection(state, s) {
|
||||
const text = s || '';
|
||||
if (!text) {
|
||||
return ['', ''];
|
||||
}
|
||||
const suspiciousStart = findSuspiciousPrefixStart(text);
|
||||
if (suspiciousStart < 0) {
|
||||
return [text, ''];
|
||||
}
|
||||
if (suspiciousStart > 0) {
|
||||
return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)];
|
||||
}
|
||||
return ['', text];
|
||||
}
|
||||
|
||||
function findSuspiciousPrefixStart(s) {
|
||||
let start = -1;
|
||||
for (const needle of ['{', '[', '```']) {
|
||||
const idx = s.lastIndexOf(needle);
|
||||
if (idx > start) {
|
||||
start = idx;
|
||||
// Only hold back partial XML tool tags.
|
||||
const xmlIdx = findPartialXMLToolTagStart(text);
|
||||
if (xmlIdx >= 0) {
|
||||
if (insideCodeFenceWithState(state, text.slice(0, xmlIdx))) {
|
||||
return [text, ''];
|
||||
}
|
||||
if (xmlIdx > 0) {
|
||||
return [text.slice(0, xmlIdx), text.slice(xmlIdx)];
|
||||
}
|
||||
return ['', text];
|
||||
}
|
||||
// Also check for partial XML tool tag at end of string.
|
||||
const xmlIdx = findPartialXMLToolTagStart(s);
|
||||
if (xmlIdx >= 0 && xmlIdx > start) {
|
||||
start = xmlIdx;
|
||||
}
|
||||
return start;
|
||||
return [text, ''];
|
||||
}
|
||||
|
||||
function findToolSegmentStart(state, s) {
|
||||
@@ -174,39 +154,23 @@ function findToolSegmentStart(state, s) {
|
||||
const lower = s.toLowerCase();
|
||||
let offset = 0;
|
||||
while (true) {
|
||||
// Check JSON keywords.
|
||||
let { index: bestKeyIdx, keyword: matchedKeyword } = earliestKeywordIndex(lower, TOOL_SEGMENT_KEYWORDS, offset);
|
||||
// Also check XML tool tags.
|
||||
// Only check XML tool tags.
|
||||
let bestIdx = -1;
|
||||
let matchedTag = '';
|
||||
for (const tag of XML_TOOL_SEGMENT_TAGS) {
|
||||
const idx = lower.indexOf(tag, offset);
|
||||
if (idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx)) {
|
||||
bestKeyIdx = idx;
|
||||
matchedKeyword = tag;
|
||||
if (idx >= 0 && (bestIdx < 0 || idx < bestIdx)) {
|
||||
bestIdx = idx;
|
||||
matchedTag = tag;
|
||||
}
|
||||
}
|
||||
if (bestKeyIdx < 0) {
|
||||
if (bestIdx < 0) {
|
||||
return -1;
|
||||
}
|
||||
// For XML tags, the '<' is itself the segment start.
|
||||
if (s[bestKeyIdx] === '<') {
|
||||
if (!insideCodeFenceWithState(state, s.slice(0, bestKeyIdx))) {
|
||||
return bestKeyIdx;
|
||||
}
|
||||
offset = bestKeyIdx + matchedKeyword.length;
|
||||
continue;
|
||||
if (!insideCodeFenceWithState(state, s.slice(0, bestIdx))) {
|
||||
return bestIdx;
|
||||
}
|
||||
const keyIdx = bestKeyIdx;
|
||||
const start = s.slice(0, keyIdx).lastIndexOf('{');
|
||||
let candidateStart = start >= 0 ? start : keyIdx;
|
||||
// If the keyword matched inside an XML tag (e.g. "tool_calls" in "<tool_calls>"),
|
||||
// back up past the '<' to capture the full tag.
|
||||
if (candidateStart > 0 && s[candidateStart - 1] === '<') {
|
||||
candidateStart--;
|
||||
}
|
||||
if (!insideCodeFenceWithState(state, s.slice(0, candidateStart))) {
|
||||
return candidateStart;
|
||||
}
|
||||
offset = keyIdx + matchedKeyword.length;
|
||||
offset = bestIdx + matchedTag.length;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,7 +180,7 @@ function consumeToolCapture(state, toolNames) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
|
||||
// Try XML tool call extraction first.
|
||||
// XML-only tool call extraction.
|
||||
const xmlResult = consumeXMLToolCaptureImpl(captured, toolNames, trimWrappingJSONFence);
|
||||
if (xmlResult.ready) {
|
||||
return xmlResult;
|
||||
@@ -226,50 +190,12 @@ function consumeToolCapture(state, toolNames) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
|
||||
const lower = captured.toLowerCase();
|
||||
const { index: keyIdx } = earliestKeywordIndex(lower, TOOL_SEGMENT_KEYWORDS);
|
||||
if (keyIdx < 0) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
const start = captured.slice(0, keyIdx).lastIndexOf('{');
|
||||
const actualStart = start >= 0 ? start : keyIdx;
|
||||
const obj = extractJSONObjectFrom(captured, actualStart);
|
||||
if (!obj.ok) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
const prefixPart = captured.slice(0, actualStart);
|
||||
const suffixPart = captured.slice(obj.end);
|
||||
if (insideCodeFenceWithState(state, prefixPart)) {
|
||||
return {
|
||||
ready: true,
|
||||
prefix: captured,
|
||||
calls: [],
|
||||
suffix: '',
|
||||
};
|
||||
}
|
||||
const parsed = parseStandaloneToolCallsDetailed(captured.slice(actualStart, obj.end), toolNames);
|
||||
if (!Array.isArray(parsed.calls) || parsed.calls.length === 0) {
|
||||
if (parsed.sawToolCallSyntax && parsed.rejectedByPolicy) {
|
||||
return {
|
||||
ready: true,
|
||||
prefix: prefixPart,
|
||||
calls: [],
|
||||
suffix: suffixPart,
|
||||
};
|
||||
}
|
||||
return {
|
||||
ready: true,
|
||||
prefix: captured,
|
||||
calls: [],
|
||||
suffix: '',
|
||||
};
|
||||
}
|
||||
const trimmedFence = trimWrappingJSONFence(prefixPart, suffixPart);
|
||||
// No XML tool tags detected — release captured content as text.
|
||||
return {
|
||||
ready: true,
|
||||
prefix: trimmedFence.prefix,
|
||||
calls: parsed.calls,
|
||||
suffix: trimmedFence.suffix,
|
||||
prefix: captured,
|
||||
calls: [],
|
||||
suffix: '',
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
'use strict';
|
||||
|
||||
// Keep in sync with Go toolSieveContextTailLimit.
|
||||
const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 2048;
|
||||
|
||||
function createToolSieveState() {
|
||||
return {
|
||||
pending: '',
|
||||
capture: '',
|
||||
capturing: false,
|
||||
recentTextTail: '',
|
||||
codeFenceStack: [],
|
||||
codeFencePendingTicks: 0,
|
||||
codeFenceLineStart: true,
|
||||
@@ -39,20 +35,6 @@ function noteText(state, text) {
|
||||
return;
|
||||
}
|
||||
updateCodeFenceState(state, text);
|
||||
state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT);
|
||||
}
|
||||
|
||||
function appendTail(prev, next, max) {
|
||||
const left = typeof prev === 'string' ? prev : '';
|
||||
const right = typeof next === 'string' ? next : '';
|
||||
if (!Number.isFinite(max) || max <= 0) {
|
||||
return '';
|
||||
}
|
||||
const combined = left + right;
|
||||
if (combined.length <= max) {
|
||||
return combined;
|
||||
}
|
||||
return combined.slice(combined.length - max);
|
||||
}
|
||||
|
||||
function looksLikeToolExampleContext(text) {
|
||||
@@ -171,11 +153,9 @@ function toStringSafe(v) {
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
TOOL_SIEVE_CONTEXT_TAIL_LIMIT,
|
||||
createToolSieveState,
|
||||
resetIncrementalToolState,
|
||||
noteText,
|
||||
appendTail,
|
||||
looksLikeToolExampleContext,
|
||||
insideCodeFence,
|
||||
insideCodeFenceWithState,
|
||||
|
||||
@@ -1,15 +1,7 @@
|
||||
'use strict';
|
||||
|
||||
const TOOL_SEGMENT_KEYWORDS = [
|
||||
'tool_calls',
|
||||
'"function"',
|
||||
'function.name:',
|
||||
'functioncall',
|
||||
'"tool_use"',
|
||||
];
|
||||
|
||||
const XML_TOOL_SEGMENT_TAGS = [
|
||||
'<tool_calls>', '<tool_calls\n', '<tool_call>', '<tool_call\n',
|
||||
'<tool_calls>', '<tool_calls\n', '<tool_calls ', '<tool_call>', '<tool_call\n', '<tool_call ',
|
||||
'<invoke ', '<invoke>', '<function_call', '<function_calls', '<tool_use>',
|
||||
];
|
||||
|
||||
@@ -21,26 +13,9 @@ const XML_TOOL_CLOSING_TAGS = [
|
||||
'</tool_calls>', '</tool_call>', '</invoke>', '</function_call>', '</function_calls>', '</tool_use>',
|
||||
];
|
||||
|
||||
function earliestKeywordIndex(text, keywords = TOOL_SEGMENT_KEYWORDS, offset = 0) {
|
||||
if (!text) {
|
||||
return { index: -1, keyword: '' };
|
||||
}
|
||||
let index = -1;
|
||||
let keyword = '';
|
||||
for (const kw of keywords) {
|
||||
const candidate = text.indexOf(kw, offset);
|
||||
if (candidate >= 0 && (index < 0 || candidate < index)) {
|
||||
index = candidate;
|
||||
keyword = kw;
|
||||
}
|
||||
}
|
||||
return { index, keyword };
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
TOOL_SEGMENT_KEYWORDS,
|
||||
XML_TOOL_SEGMENT_TAGS,
|
||||
XML_TOOL_OPENING_TAGS,
|
||||
XML_TOOL_CLOSING_TAGS,
|
||||
earliestKeywordIndex,
|
||||
};
|
||||
|
||||
|
||||
@@ -18,8 +18,6 @@ const (
|
||||
endSentenceMarker = "<|end▁of▁sentence|>"
|
||||
endToolResultsMarker = "<|end▁of▁toolresults|>"
|
||||
endInstructionsMarker = "<|end▁of▁instructions|>"
|
||||
openThinkMarker = "<think>"
|
||||
closeThinkMarker = "</think>"
|
||||
)
|
||||
|
||||
func MessagesPrepare(messages []map[string]any) string {
|
||||
@@ -32,6 +30,11 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool
|
||||
Text string
|
||||
}
|
||||
processed := make([]block, 0, len(messages))
|
||||
if thinkingEnabled {
|
||||
if instruction := buildConversationContinuityInstructions(thinkingEnabled); strings.TrimSpace(instruction) != "" {
|
||||
processed = append(processed, block{Role: "system", Text: instruction})
|
||||
}
|
||||
}
|
||||
for _, m := range messages {
|
||||
role, _ := m["role"].(string)
|
||||
text := NormalizeContent(m["content"])
|
||||
@@ -55,7 +58,7 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool
|
||||
lastRole = m.Role
|
||||
switch m.Role {
|
||||
case "assistant":
|
||||
parts = append(parts, formatRoleBlock(assistantMarker, closeThinkMarker+m.Text, endSentenceMarker))
|
||||
parts = append(parts, formatRoleBlock(assistantMarker, m.Text, endSentenceMarker))
|
||||
case "tool":
|
||||
if strings.TrimSpace(m.Text) != "" {
|
||||
parts = append(parts, formatRoleBlock(toolMarker, m.Text, endToolResultsMarker))
|
||||
@@ -65,7 +68,7 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool
|
||||
parts = append(parts, formatRoleBlock(systemMarker, text, endInstructionsMarker))
|
||||
}
|
||||
case "user":
|
||||
parts = append(parts, formatRoleBlock(userMarker, m.Text, endSentenceMarker))
|
||||
parts = append(parts, formatRoleBlock(userMarker, m.Text, ""))
|
||||
default:
|
||||
if strings.TrimSpace(m.Text) != "" {
|
||||
parts = append(parts, m.Text)
|
||||
@@ -73,25 +76,34 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool
|
||||
}
|
||||
}
|
||||
if lastRole != "assistant" {
|
||||
thinkPrefix := closeThinkMarker
|
||||
if thinkingEnabled {
|
||||
thinkPrefix = openThinkMarker
|
||||
}
|
||||
parts = append(parts, assistantMarker+thinkPrefix)
|
||||
parts = append(parts, assistantMarker)
|
||||
}
|
||||
out := strings.Join(parts, "\n\n")
|
||||
out := strings.Join(parts, "")
|
||||
return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`)
|
||||
}
|
||||
|
||||
// DeepSeek-style turn suffixes stay attached to the same block as the role content.
|
||||
// formatRoleBlock produces a single concatenated block: marker + text + endMarker.
|
||||
// No whitespace is inserted between marker and text so role boundaries stay
|
||||
// compact and predictable for downstream parsers.
|
||||
func formatRoleBlock(marker, text, endMarker string) string {
|
||||
out := marker + "\n" + text
|
||||
out := marker + text
|
||||
if strings.TrimSpace(endMarker) != "" {
|
||||
out += endMarker
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildConversationContinuityInstructions(thinkingEnabled bool) string {
|
||||
lines := []string{
|
||||
"Continue the conversation from the full prior context and the latest tool results.",
|
||||
"Treat earlier messages as binding context; answer the user's current request as a continuation, not a restart.",
|
||||
}
|
||||
if thinkingEnabled {
|
||||
lines = append(lines, "Keep reasoning internal. Do not leave the final user-facing answer only in reasoning; always provide the answer in visible assistant content.")
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func NormalizeContent(v any) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
|
||||
@@ -35,15 +35,18 @@ func TestMessagesPrepareUsesTurnSuffixes(t *testing.T) {
|
||||
if !strings.HasPrefix(got, "<|begin▁of▁sentence|>") {
|
||||
t.Fatalf("expected begin-of-sentence marker, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "<|System|>\nSystem rule<|end▁of▁instructions|>") {
|
||||
if !strings.Contains(got, "<|System|>System rule<|end▁of▁instructions|>") {
|
||||
t.Fatalf("expected system instructions suffix, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "<|User|>\nQuestion<|end▁of▁sentence|>") {
|
||||
t.Fatalf("expected user sentence suffix, got %q", got)
|
||||
if !strings.Contains(got, "<|User|>Question") {
|
||||
t.Fatalf("expected user question, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "<|Assistant|>\n</think>Answer<|end▁of▁sentence|>") {
|
||||
if !strings.Contains(got, "<|Assistant|>Answer<|end▁of▁sentence|>") {
|
||||
t.Fatalf("expected assistant sentence suffix, got %q", got)
|
||||
}
|
||||
if strings.Contains(got, "<think>") || strings.Contains(got, "</think>") {
|
||||
t.Fatalf("did not expect think tags in prompt, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
|
||||
@@ -55,10 +58,23 @@ func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPrepareWithThinkingEndsWithOpenThink(t *testing.T) {
|
||||
func TestMessagesPrepareWithThinkingAddsContinuityContract(t *testing.T) {
|
||||
messages := []map[string]any{{"role": "user", "content": "Question"}}
|
||||
got := MessagesPrepareWithThinking(messages, true)
|
||||
if !strings.HasSuffix(got, "<|Assistant|><think>") {
|
||||
t.Fatalf("expected thinking suffix, got %q", got)
|
||||
gotThinking := MessagesPrepareWithThinking(messages, true)
|
||||
gotPlain := MessagesPrepareWithThinking(messages, false)
|
||||
if gotThinking == gotPlain {
|
||||
t.Fatalf("expected thinking-enabled prompt to include extra continuity instructions")
|
||||
}
|
||||
if !strings.HasSuffix(gotThinking, "<|Assistant|>") {
|
||||
t.Fatalf("expected assistant suffix, got %q", gotThinking)
|
||||
}
|
||||
if !strings.Contains(gotThinking, "Continue the conversation from the full prior context") {
|
||||
t.Fatalf("expected continuity instruction in thinking prompt, got %q", gotThinking)
|
||||
}
|
||||
if !strings.Contains(gotThinking, "final user-facing answer only in reasoning") {
|
||||
t.Fatalf("expected visible-answer instruction in thinking prompt, got %q", gotThinking)
|
||||
}
|
||||
if strings.Contains(gotPlain, "Continue the conversation from the full prior context") {
|
||||
t.Fatalf("did not expect thinking-only instruction in plain prompt, got %q", gotPlain)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,9 @@ package prompt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -11,6 +14,8 @@ var promptXMLTextEscaper = strings.NewReplacer(
|
||||
">", ">",
|
||||
)
|
||||
|
||||
var promptXMLNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_.:-]*$`)
|
||||
|
||||
// FormatToolCallsForPrompt renders a tool_calls slice into the canonical
|
||||
// prompt-visible history block used across adapters.
|
||||
func FormatToolCallsForPrompt(raw any) string {
|
||||
@@ -87,12 +92,160 @@ func formatToolCallForPrompt(call map[string]any) string {
|
||||
}
|
||||
}
|
||||
|
||||
parameters := formatToolCallParametersForPrompt(argsRaw)
|
||||
|
||||
return " <tool_call>\n" +
|
||||
" <tool_name>" + escapeXMLText(name) + "</tool_name>\n" +
|
||||
" <parameters>" + escapeXMLText(StringifyToolCallArguments(argsRaw)) + "</parameters>\n" +
|
||||
parameters + "\n" +
|
||||
" </tool_call>"
|
||||
}
|
||||
|
||||
func formatToolCallParametersForPrompt(raw any) string {
|
||||
value := normalizePromptToolCallValue(raw)
|
||||
body, ok := renderPromptToolXMLBody(value, " ")
|
||||
if ok {
|
||||
if strings.TrimSpace(body) == "" {
|
||||
return " <parameters></parameters>"
|
||||
}
|
||||
return " <parameters>\n" + body + "\n </parameters>"
|
||||
}
|
||||
|
||||
fallback := StringifyToolCallArguments(raw)
|
||||
if strings.TrimSpace(fallback) == "" {
|
||||
fallback = "{}"
|
||||
}
|
||||
return " <parameters><content>" + renderPromptXMLText(fallback) + "</content></parameters>"
|
||||
}
|
||||
|
||||
func normalizePromptToolCallValue(raw any) any {
|
||||
switch x := raw.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case string:
|
||||
s := strings.TrimSpace(x)
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(s), &parsed); err == nil {
|
||||
return parsed
|
||||
}
|
||||
return x
|
||||
default:
|
||||
return x
|
||||
}
|
||||
}
|
||||
|
||||
func renderPromptToolXMLBody(value any, indent string) (string, bool) {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return "", true
|
||||
case map[string]any:
|
||||
return renderPromptToolXMLMap(v, indent)
|
||||
case []any:
|
||||
return renderPromptToolXMLArray(v, indent)
|
||||
case string:
|
||||
return indent + "<content>" + renderPromptXMLText(v) + "</content>", true
|
||||
case bool, float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return indent + "<value>" + escapeXMLText(fmt.Sprint(v)) + "</value>", true
|
||||
default:
|
||||
return indent + "<value>" + renderPromptXMLText(fmt.Sprint(v)) + "</value>", true
|
||||
}
|
||||
}
|
||||
|
||||
func renderPromptToolXMLMap(m map[string]any, indent string) (string, bool) {
|
||||
if len(m) == 0 {
|
||||
return "", true
|
||||
}
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
if !isValidPromptXMLName(k) {
|
||||
return "", false
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
lines := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
rendered, ok := renderPromptToolXMLNode(key, m[key], indent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
lines = append(lines, rendered)
|
||||
}
|
||||
return strings.Join(lines, "\n"), true
|
||||
}
|
||||
|
||||
func renderPromptToolXMLArray(items []any, indent string) (string, bool) {
|
||||
if len(items) == 0 {
|
||||
return "", true
|
||||
}
|
||||
lines := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
rendered, ok := renderPromptToolXMLNode("item", item, indent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
lines = append(lines, rendered)
|
||||
}
|
||||
return strings.Join(lines, "\n"), true
|
||||
}
|
||||
|
||||
func renderPromptToolXMLNode(name string, value any, indent string) (string, bool) {
|
||||
if !isValidPromptXMLName(name) {
|
||||
return "", false
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return indent + "<" + name + "></" + name + ">", true
|
||||
case map[string]any:
|
||||
inner, ok := renderPromptToolXMLMap(v, indent+" ")
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if strings.TrimSpace(inner) == "" {
|
||||
return indent + "<" + name + "></" + name + ">", true
|
||||
}
|
||||
return indent + "<" + name + ">\n" + inner + "\n" + indent + "</" + name + ">", true
|
||||
case []any:
|
||||
if len(v) == 0 {
|
||||
return indent + "<" + name + "></" + name + ">", true
|
||||
}
|
||||
lines := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
rendered, ok := renderPromptToolXMLNode(name, item, indent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
lines = append(lines, rendered)
|
||||
}
|
||||
return strings.Join(lines, "\n"), true
|
||||
case string:
|
||||
return indent + "<" + name + ">" + renderPromptXMLText(v) + "</" + name + ">", true
|
||||
case bool, float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return indent + "<" + name + ">" + escapeXMLText(fmt.Sprint(v)) + "</" + name + ">", true
|
||||
default:
|
||||
return indent + "<" + name + ">" + renderPromptXMLText(fmt.Sprint(v)) + "</" + name + ">", true
|
||||
}
|
||||
}
|
||||
|
||||
// renderPromptXMLText emits CDATA for every string so prompt-visible tool
|
||||
// history stays uniform and does not drift back toward ad-hoc escaping.
|
||||
func renderPromptXMLText(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(text, "]]>") {
|
||||
return "<![CDATA[" + strings.ReplaceAll(text, "]]>", "]]]]><![CDATA[>") + "]]>"
|
||||
}
|
||||
return "<![CDATA[" + text + "]]>"
|
||||
}
|
||||
|
||||
func isValidPromptXMLName(name string) bool {
|
||||
return promptXMLNamePattern.MatchString(strings.TrimSpace(name))
|
||||
}
|
||||
|
||||
func normalizeToolArgumentString(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestFormatToolCallsForPromptXML(t *testing.T) {
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty formatted tool calls")
|
||||
}
|
||||
if got != "<tool_calls>\n <tool_call>\n <tool_name>search_web</tool_name>\n <parameters>{\"query\":\"latest\"}</parameters>\n </tool_call>\n</tool_calls>" {
|
||||
if got != "<tool_calls>\n <tool_call>\n <tool_name>search_web</tool_name>\n <parameters>\n <query><![CDATA[latest]]></query>\n </parameters>\n </tool_call>\n</tool_calls>" {
|
||||
t.Fatalf("unexpected formatted tool call XML: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -34,8 +34,24 @@ func TestFormatToolCallsForPromptEscapesXMLEntities(t *testing.T) {
|
||||
"arguments": `{"q":"a < b && c > d"}`,
|
||||
},
|
||||
})
|
||||
want := "<tool_calls>\n <tool_call>\n <tool_name>search<&></tool_name>\n <parameters>{\"q\":\"a < b && c > d\"}</parameters>\n </tool_call>\n</tool_calls>"
|
||||
want := "<tool_calls>\n <tool_call>\n <tool_name>search<&></tool_name>\n <parameters>\n <q><![CDATA[a < b && c > d]]></q>\n </parameters>\n </tool_call>\n</tool_calls>"
|
||||
if got != want {
|
||||
t.Fatalf("unexpected escaped tool call XML: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatToolCallsForPromptUsesCDATAForMultilineContent(t *testing.T) {
|
||||
got := FormatToolCallsForPrompt([]any{
|
||||
map[string]any{
|
||||
"name": "write_file",
|
||||
"arguments": map[string]any{
|
||||
"path": "script.sh",
|
||||
"content": "#!/bin/bash\nprintf \"hello\"\n",
|
||||
},
|
||||
},
|
||||
})
|
||||
want := "<tool_calls>\n <tool_call>\n <tool_name>write_file</tool_name>\n <parameters>\n <content><![CDATA[#!/bin/bash\nprintf \"hello\"\n]]></content>\n <path><![CDATA[script.sh]]></path>\n </parameters>\n </tool_call>\n</tool_calls>"
|
||||
if got != want {
|
||||
t.Fatalf("unexpected multiline cdata tool call XML: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -17,6 +20,7 @@ import (
|
||||
"ds2api/internal/adapter/openai"
|
||||
"ds2api/internal/admin"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/webui"
|
||||
@@ -46,17 +50,21 @@ func NewApp() (*App, error) {
|
||||
} else {
|
||||
config.Logger.Info("[PoW] pure Go solver ready")
|
||||
}
|
||||
chatHistoryStore := chathistory.New(config.ChatHistoryPath())
|
||||
if err := chatHistoryStore.Err(); err != nil {
|
||||
config.Logger.Warn("[chat_history] unavailable", "path", chatHistoryStore.Path(), "error", err)
|
||||
}
|
||||
|
||||
openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient}
|
||||
openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient, ChatHistory: chatHistoryStore}
|
||||
claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient, OpenAI: openaiHandler}
|
||||
geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient, OpenAI: openaiHandler}
|
||||
adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient, OpenAI: openaiHandler}
|
||||
adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient, OpenAI: openaiHandler, ChatHistory: chatHistoryStore}
|
||||
webuiHandler := webui.NewHandler()
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(filteredLogger())
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(cors)
|
||||
r.Use(timeout(0))
|
||||
@@ -99,11 +107,44 @@ func timeout(d time.Duration) func(http.Handler) http.Handler {
|
||||
return middleware.Timeout(d)
|
||||
}
|
||||
|
||||
func filteredLogger() func(http.Handler) http.Handler {
|
||||
color := !isWindowsRuntime()
|
||||
base := &middleware.DefaultLogFormatter{
|
||||
Logger: log.New(os.Stdout, "", log.LstdFlags),
|
||||
NoColor: !color,
|
||||
}
|
||||
return middleware.RequestLogger(&filteredLogFormatter{base: base})
|
||||
}
|
||||
|
||||
func isWindowsRuntime() bool {
|
||||
return runtime.GOOS == "windows"
|
||||
}
|
||||
|
||||
type filteredLogFormatter struct {
|
||||
base *middleware.DefaultLogFormatter
|
||||
}
|
||||
|
||||
func (f *filteredLogFormatter) NewLogEntry(r *http.Request) middleware.LogEntry {
|
||||
if r != nil && r.Method == http.MethodGet {
|
||||
path := strings.TrimSpace(r.URL.Path)
|
||||
if path == "/admin/chat-history" || strings.HasPrefix(path, "/admin/chat-history/") {
|
||||
return noopLogEntry{}
|
||||
}
|
||||
}
|
||||
return f.base.NewLogEntry(r)
|
||||
}
|
||||
|
||||
type noopLogEntry struct{}
|
||||
|
||||
func (noopLogEntry) Write(_ int, _ int, _ http.Header, _ time.Duration, _ interface{}) {}
|
||||
|
||||
func (noopLogEntry) Panic(_ interface{}, _ []byte) {}
|
||||
|
||||
func cors(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Ds2-Source, X-Vercel-Protection-Bypass")
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
|
||||
168
internal/sse/citation_links.go
Normal file
168
internal/sse/citation_links.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package sse
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type citationLinkCollector struct {
|
||||
ordered []string
|
||||
explicitRaw map[int]string
|
||||
hasZeroIdx bool
|
||||
}
|
||||
|
||||
func newCitationLinkCollector() *citationLinkCollector {
|
||||
return &citationLinkCollector{
|
||||
explicitRaw: map[int]string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *citationLinkCollector) ingestChunk(chunk map[string]any) {
|
||||
if c == nil || len(chunk) == 0 {
|
||||
return
|
||||
}
|
||||
c.walkValue(chunk)
|
||||
}
|
||||
|
||||
func (c *citationLinkCollector) build() map[int]string {
|
||||
out := make(map[int]string, len(c.explicitRaw)+len(c.ordered))
|
||||
for idx, u := range c.buildNormalizedExplicit() {
|
||||
out[idx] = u
|
||||
}
|
||||
for i, u := range c.ordered {
|
||||
idx := i + 1
|
||||
if _, exists := out[idx]; !exists {
|
||||
out[idx] = u
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *citationLinkCollector) buildNormalizedExplicit() map[int]string {
|
||||
out := make(map[int]string, len(c.explicitRaw))
|
||||
|
||||
// Default behavior keeps positive indices as-is (one-based payloads).
|
||||
for idx, u := range c.explicitRaw {
|
||||
if idx <= 0 || strings.TrimSpace(u) == "" {
|
||||
continue
|
||||
}
|
||||
out[idx] = u
|
||||
}
|
||||
|
||||
if !c.hasZeroIdx {
|
||||
return out
|
||||
}
|
||||
|
||||
// If zero index appears, upstream may be using zero-based indices.
|
||||
// Add shifted candidates and resolve conflicts using ordered appearance,
|
||||
// which matches visible citation marker order in response text.
|
||||
for rawIdx, u := range c.explicitRaw {
|
||||
if rawIdx < 0 || strings.TrimSpace(u) == "" {
|
||||
continue
|
||||
}
|
||||
normalized := rawIdx + 1
|
||||
existing, exists := out[normalized]
|
||||
if !exists {
|
||||
out[normalized] = u
|
||||
continue
|
||||
}
|
||||
if c.preferURLForIndex(normalized, existing, u) == u {
|
||||
out[normalized] = u
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *citationLinkCollector) preferURLForIndex(idx int, current, candidate string) string {
|
||||
if idx <= 0 || idx > len(c.ordered) {
|
||||
return current
|
||||
}
|
||||
expected := c.ordered[idx-1]
|
||||
switch {
|
||||
case strings.TrimSpace(expected) == "":
|
||||
return current
|
||||
case candidate == expected && current != expected:
|
||||
return candidate
|
||||
default:
|
||||
return current
|
||||
}
|
||||
}
|
||||
|
||||
func (c *citationLinkCollector) walkValue(v any) {
|
||||
switch x := v.(type) {
|
||||
case []any:
|
||||
for _, item := range x {
|
||||
c.walkValue(item)
|
||||
}
|
||||
case map[string]any:
|
||||
c.captureURLAndIndex(x)
|
||||
for _, vv := range x {
|
||||
c.walkValue(vv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *citationLinkCollector) captureURLAndIndex(m map[string]any) {
|
||||
url := strings.TrimSpace(asString(m["url"]))
|
||||
if !isWebURL(url) {
|
||||
return
|
||||
}
|
||||
c.addOrdered(url)
|
||||
|
||||
idx, hasIdx := citationIndexFromAny(m["cite_index"])
|
||||
if !hasIdx {
|
||||
return
|
||||
}
|
||||
if idx < 0 {
|
||||
return
|
||||
}
|
||||
if idx == 0 {
|
||||
c.hasZeroIdx = true
|
||||
}
|
||||
if existing, ok := c.explicitRaw[idx]; ok && strings.TrimSpace(existing) != "" {
|
||||
return
|
||||
}
|
||||
c.explicitRaw[idx] = url
|
||||
}
|
||||
|
||||
func (c *citationLinkCollector) addOrdered(url string) {
|
||||
c.ordered = append(c.ordered, url)
|
||||
}
|
||||
|
||||
func citationIndexFromAny(v any) (int, bool) {
|
||||
switch x := v.(type) {
|
||||
case int:
|
||||
return x, true
|
||||
case int32:
|
||||
return int(x), true
|
||||
case int64:
|
||||
return int(x), true
|
||||
case float32:
|
||||
return int(x), true
|
||||
case float64:
|
||||
return int(x), true
|
||||
case string:
|
||||
s := strings.TrimSpace(x)
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
n, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return n, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func isWebURL(v string) bool {
|
||||
v = strings.ToLower(strings.TrimSpace(v))
|
||||
return strings.HasPrefix(v, "http://") || strings.HasPrefix(v, "https://")
|
||||
}
|
||||
|
||||
func asString(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
@@ -13,6 +13,7 @@ type CollectResult struct {
|
||||
Text string
|
||||
Thinking string
|
||||
ContentFilter bool
|
||||
CitationLinks map[int]string
|
||||
}
|
||||
|
||||
// CollectStream fully consumes a DeepSeek SSE response and separates
|
||||
@@ -28,11 +29,23 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
text := strings.Builder{}
|
||||
thinking := strings.Builder{}
|
||||
contentFilter := false
|
||||
stopped := false
|
||||
collector := newCitationLinkCollector()
|
||||
currentType := "text"
|
||||
if thinkingEnabled {
|
||||
currentType = "thinking"
|
||||
}
|
||||
_ = deepseek.ScanSSELines(resp, func(line []byte) bool {
|
||||
chunk, done, parsed := ParseDeepSeekSSELine(line)
|
||||
if parsed && !done {
|
||||
collector.ingestChunk(chunk)
|
||||
}
|
||||
if done {
|
||||
return false
|
||||
}
|
||||
if stopped {
|
||||
return true
|
||||
}
|
||||
result := ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
|
||||
currentType = result.NextType
|
||||
if !result.Parsed {
|
||||
@@ -42,7 +55,11 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
if result.ContentFilter {
|
||||
contentFilter = true
|
||||
}
|
||||
return false
|
||||
// Keep scanning to collect late-arriving citation metadata lines
|
||||
// that can appear after response/status=FINISHED, but stop as soon
|
||||
// as [DONE] arrives.
|
||||
stopped = true
|
||||
return true
|
||||
}
|
||||
for _, p := range result.Parts {
|
||||
if p.Type == "thinking" {
|
||||
@@ -59,5 +76,6 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
Text: text.String(),
|
||||
Thinking: thinking.String(),
|
||||
ContentFilter: contentFilter,
|
||||
CitationLinks: collector.build(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ─── CollectStream edge cases ────────────────────────────────────────
|
||||
@@ -115,6 +116,94 @@ func TestCollectStreamWithCitation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamExtractsCitationLinks(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/fragments/-1/results\",\"v\":[{\"url\":\"https://example.com/a\",\"cite_index\":0},{\"url\":\"https://example.com/b\",\"cite_index\":1}]}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"结论[citation:1][citation:2]\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
result := CollectStream(resp, false, false)
|
||||
|
||||
if got := result.CitationLinks[1]; got != "https://example.com/a" {
|
||||
t.Fatalf("expected citation 1 link, got %q", got)
|
||||
}
|
||||
if got := result.CitationLinks[2]; got != "https://example.com/b" {
|
||||
t.Fatalf("expected citation 2 link, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamExtractsCitationLinksForSequentialZeroBasedIndices(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/fragments/-1/results\",\"v\":[{\"url\":\"https://example.com/a\",\"cite_index\":0},{\"url\":\"https://example.com/b\",\"cite_index\":1},{\"url\":\"https://example.com/c\",\"cite_index\":2}]}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"结论[citation:1][citation:2][citation:3]\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
result := CollectStream(resp, false, false)
|
||||
|
||||
if got := result.CitationLinks[1]; got != "https://example.com/a" {
|
||||
t.Fatalf("expected citation 1 link, got %q", got)
|
||||
}
|
||||
if got := result.CitationLinks[2]; got != "https://example.com/b" {
|
||||
t.Fatalf("expected citation 2 link, got %q", got)
|
||||
}
|
||||
if got := result.CitationLinks[3]; got != "https://example.com/c" {
|
||||
t.Fatalf("expected citation 3 link, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamExtractsCitationLinksForOneBasedIndices(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/fragments/-1/results\",\"v\":[{\"url\":\"https://example.com/a\",\"cite_index\":1},{\"url\":\"https://example.com/b\",\"cite_index\":2}]}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"结论[citation:1][citation:2]\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
result := CollectStream(resp, false, false)
|
||||
|
||||
if got := result.CitationLinks[1]; got != "https://example.com/a" {
|
||||
t.Fatalf("expected citation 1 link, got %q", got)
|
||||
}
|
||||
if got := result.CitationLinks[2]; got != "https://example.com/b" {
|
||||
t.Fatalf("expected citation 2 link, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamExtractsCitationLinksWithRepeatedURLsAndNilIndices(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/fragments/-1/results\",\"v\":[{\"url\":\"https://example.com/a\",\"cite_index\":null},{\"url\":\"https://example.com/a\",\"cite_index\":null},{\"url\":\"https://example.com/b\",\"cite_index\":null}]}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"结论[citation:1][citation:2][citation:3]\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
result := CollectStream(resp, false, false)
|
||||
|
||||
if got := result.CitationLinks[1]; got != "https://example.com/a" {
|
||||
t.Fatalf("expected citation 1 link, got %q", got)
|
||||
}
|
||||
if got := result.CitationLinks[2]; got != "https://example.com/a" {
|
||||
t.Fatalf("expected citation 2 link, got %q", got)
|
||||
}
|
||||
if got := result.CitationLinks[3]; got != "https://example.com/b" {
|
||||
t.Fatalf("expected citation 3 link, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamCollectsCitationLinksAfterFinished(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"结论[citation:1]\"}\n" +
|
||||
"data: {\"p\":\"response/status\",\"v\":\"FINISHED\"}\n" +
|
||||
"data: {\"p\":\"response/fragments/-1/results\",\"v\":[{\"url\":\"https://example.com/a\",\"cite_index\":1}]}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"should-not-append\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
|
||||
result := CollectStream(resp, false, false)
|
||||
if result.Text != "结论[citation:1]" {
|
||||
t.Fatalf("expected text to freeze after finished, got %q", result.Text)
|
||||
}
|
||||
if got := result.CitationLinks[1]; got != "https://example.com/a" {
|
||||
t.Fatalf("expected citation 1 link, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamMultipleThinkingChunks(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/thinking_content\",\"v\":\"part1\"}\n" +
|
||||
@@ -139,6 +228,39 @@ func TestCollectStreamStatusFinished(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamStopsOnDoneAfterFinished(t *testing.T) {
|
||||
pr, pw := io.Pipe()
|
||||
defer func() { _ = pw.Close() }()
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: pr,
|
||||
}
|
||||
|
||||
resultCh := make(chan CollectResult, 1)
|
||||
go func() {
|
||||
resultCh <- CollectStream(resp, false, false)
|
||||
}()
|
||||
|
||||
_, _ = io.WriteString(pw, "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n")
|
||||
_, _ = io.WriteString(pw, "data: {\"p\":\"response/status\",\"v\":\"FINISHED\"}\n")
|
||||
_, _ = io.WriteString(pw, "data: {\"p\":\"response/fragments/-1/results\",\"v\":[{\"url\":\"https://example.com/a\",\"cite_index\":1}]}\n")
|
||||
_, _ = io.WriteString(pw, "data: [DONE]\n")
|
||||
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
if result.Text != "Hello" {
|
||||
t.Fatalf("expected text to freeze at FINISHED, got %q", result.Text)
|
||||
}
|
||||
if got := result.CitationLinks[1]; got != "https://example.com/a" {
|
||||
t.Fatalf("expected citation metadata after FINISHED, got %q", got)
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("CollectStream did not stop on [DONE] after FINISHED")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamStopsOnContentFilterStatus(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"safe\"}\n" +
|
||||
|
||||
@@ -3,6 +3,7 @@ package sse
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/deepseek"
|
||||
@@ -93,6 +94,11 @@ func ParseSSEChunkForContent(chunk map[string]any, thinkingEnabled bool, current
|
||||
if finished {
|
||||
return nil, true, newType
|
||||
}
|
||||
var transitioned bool
|
||||
parts, transitioned = splitThinkingParts(parts)
|
||||
if transitioned {
|
||||
newType = "text"
|
||||
}
|
||||
return parts, false, newType
|
||||
}
|
||||
|
||||
@@ -166,6 +172,9 @@ func updateTypeFromNestedResponse(path string, v any, newType *string) {
|
||||
func resolvePartType(path string, thinkingEnabled bool, newType string) string {
|
||||
switch {
|
||||
case path == "response/thinking_content":
|
||||
if newType == "text" {
|
||||
return "text"
|
||||
}
|
||||
return "thinking"
|
||||
case path == "response/content":
|
||||
return "text"
|
||||
@@ -244,6 +253,63 @@ func appendContentPart(parts *[]ContentPart, content, kind string) {
|
||||
*parts = append(*parts, ContentPart{Text: content, Type: kind})
|
||||
}
|
||||
|
||||
var thinkClosePattern = regexp.MustCompile(`(?i)</\s*think\s*>`)
|
||||
var thinkOpenPattern = regexp.MustCompile(`(?i)<\s*think\s*>`)
|
||||
|
||||
// splitThinkingParts detects </think> inside thinking content and
|
||||
// auto-transitions everything after it to text. This handles the
|
||||
// DeepSeek API bug where the upstream SSE keeps sending
|
||||
// reasoning_content even though the model has finished thinking.
|
||||
func splitThinkingParts(parts []ContentPart) ([]ContentPart, bool) {
|
||||
var out []ContentPart
|
||||
thinkingDone := false
|
||||
for _, p := range parts {
|
||||
if thinkingDone && p.Type == "thinking" {
|
||||
// Already transitioned — treat remaining thinking as text.
|
||||
cleaned := stripThinkTags(p.Text)
|
||||
if cleaned != "" {
|
||||
out = append(out, ContentPart{Text: cleaned, Type: "text"})
|
||||
}
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" {
|
||||
cleaned := stripThinkTags(p.Text)
|
||||
if cleaned != "" {
|
||||
out = append(out, ContentPart{Text: cleaned, Type: p.Type})
|
||||
}
|
||||
continue
|
||||
}
|
||||
loc := thinkClosePattern.FindStringIndex(p.Text)
|
||||
if loc == nil {
|
||||
out = append(out, p)
|
||||
continue
|
||||
}
|
||||
// Split at </think>: before is still thinking, after becomes text.
|
||||
thinkingDone = true
|
||||
before := p.Text[:loc[0]]
|
||||
after := p.Text[loc[1]:]
|
||||
if before != "" {
|
||||
out = append(out, ContentPart{Text: before, Type: "thinking"})
|
||||
}
|
||||
after = stripThinkTags(after)
|
||||
if after != "" {
|
||||
out = append(out, ContentPart{Text: after, Type: "text"})
|
||||
}
|
||||
}
|
||||
if !thinkingDone {
|
||||
// Return 'out' instead of 'parts' because text parts might have been cleaned via stripThinkTags
|
||||
return out, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
// stripThinkTags removes any remaining <think> or </think> tags from text.
|
||||
func stripThinkTags(s string) string {
|
||||
s = thinkClosePattern.ReplaceAllString(s, "")
|
||||
s = thinkOpenPattern.ReplaceAllString(s, "")
|
||||
return s
|
||||
}
|
||||
|
||||
func isStatusPath(path string) bool {
|
||||
return path == "response/status" || path == "status"
|
||||
}
|
||||
|
||||
@@ -87,3 +87,79 @@ func TestParseSSEChunkForContentAfterAppendUsesUpdatedType(t *testing.T) {
|
||||
t.Fatalf("unexpected parts: %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentAutoTransitionsThinkClose(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"p": "response/thinking_content",
|
||||
"v": "deep thoughts</think>actual answer",
|
||||
}
|
||||
parts, _, _ := ParseSSEChunkForContent(chunk, true, "thinking")
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected 2 parts from split, got %d: %#v", len(parts), parts)
|
||||
}
|
||||
if parts[0].Type != "thinking" || parts[0].Text != "deep thoughts" {
|
||||
t.Fatalf("first part should be thinking: %#v", parts[0])
|
||||
}
|
||||
if parts[1].Type != "text" || parts[1].Text != "actual answer" {
|
||||
t.Fatalf("second part should be text: %#v", parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentStripsLeakedThinkTags(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"p": "response/thinking_content",
|
||||
"v": "<think>more thoughts</think> answer",
|
||||
}
|
||||
parts, _, _ := ParseSSEChunkForContent(chunk, true, "thinking")
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected 2 parts, got %d: %#v", len(parts), parts)
|
||||
}
|
||||
if parts[0].Type != "thinking" || parts[0].Text != "<think>more thoughts" {
|
||||
// note: the open tag is before the split, so it remains in the thinking part.
|
||||
// that's fine, the output sanitization handles the final string.
|
||||
t.Fatalf("first part mismatch: %#v", parts[0])
|
||||
}
|
||||
if parts[1].Type != "text" || parts[1].Text != " answer" {
|
||||
t.Fatalf("second part mismatch: %#v", parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentAutoTransitionsState(t *testing.T) {
|
||||
chunk1 := map[string]any{
|
||||
"p": "response/thinking_content",
|
||||
"v": "end of thought</think>start of text",
|
||||
}
|
||||
parts1, _, nextType1 := ParseSSEChunkForContent(chunk1, true, "thinking")
|
||||
if len(parts1) != 2 || parts1[1].Type != "text" {
|
||||
t.Fatalf("expected split parts, got %#v", parts1)
|
||||
}
|
||||
if nextType1 != "text" {
|
||||
t.Fatalf("expected nextType to transition to text, got %q", nextType1)
|
||||
}
|
||||
|
||||
chunk2 := map[string]any{
|
||||
"p": "response/thinking_content",
|
||||
"v": "more actual text sent to thinking path",
|
||||
}
|
||||
parts2, _, nextType2 := ParseSSEChunkForContent(chunk2, true, nextType1)
|
||||
if len(parts2) != 1 || parts2[0].Type != "text" {
|
||||
t.Fatalf("expected subsequent parts to be text, got %#v", parts2)
|
||||
}
|
||||
if nextType2 != "text" {
|
||||
t.Fatalf("expected nextType2 to remain text, got %q", nextType2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentStripsLeakedThinkTagsFromText(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"p": "response/content", // This makes the part type "text"
|
||||
"v": "normal text <think>leaked</think> end",
|
||||
}
|
||||
parts, _, _ := ParseSSEChunkForContent(chunk, true, "text")
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 part, got %d: %#v", len(parts), parts)
|
||||
}
|
||||
if parts[0].Type != "text" || parts[0].Text != "normal text leaked end" {
|
||||
t.Fatalf("expected leaked think tag to be stripped, got %#v", parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
81
internal/toolcall/regression_test.go
Normal file
81
internal/toolcall/regression_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package toolcall
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRegression_RobustXMLAndCDATA(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
expected []ParsedToolCall
|
||||
}{
|
||||
{
|
||||
name: "Standard JSON parameters (Regression)",
|
||||
text: `<tool_call><tool_name>foo</tool_name><parameters>{"a": 1}</parameters></tool_call>`,
|
||||
expected: []ParsedToolCall{{Name: "foo", Input: map[string]any{"a": float64(1)}}},
|
||||
},
|
||||
{
|
||||
name: "XML tags parameters (Regression)",
|
||||
text: `<tool_call><tool_name>foo</tool_name><parameters><arg1>hello</arg1></parameters></tool_call>`,
|
||||
expected: []ParsedToolCall{{Name: "foo", Input: map[string]any{"arg1": "hello"}}},
|
||||
},
|
||||
{
|
||||
name: "CDATA parameters (New Feature)",
|
||||
text: `<tool_call><tool_name>write_file</tool_name><parameters><content><![CDATA[line 1
|
||||
line 2 with <tags> and & symbols]]></content></parameters></tool_call>`,
|
||||
expected: []ParsedToolCall{{
|
||||
Name: "write_file",
|
||||
Input: map[string]any{"content": "line 1\nline 2 with <tags> and & symbols"},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Nested XML with repeated parameters (New Feature)",
|
||||
text: `<tool_call><tool_name>write_file</tool_name><parameters><path>script.sh</path><content><![CDATA[#!/bin/bash
|
||||
echo "hello"
|
||||
]]></content><item>first</item><item>second</item></parameters></tool_call>`,
|
||||
expected: []ParsedToolCall{{
|
||||
Name: "write_file",
|
||||
Input: map[string]any{
|
||||
"path": "script.sh",
|
||||
"content": "#!/bin/bash\necho \"hello\"\n",
|
||||
"item": []any{"first", "second"},
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Dirty XML with unescaped symbols (Robustness Improvement)",
|
||||
text: `<tool_call><tool_name>bash</tool_name><parameters><command>echo "hello" > out.txt && cat out.txt</command></parameters></tool_call>`,
|
||||
expected: []ParsedToolCall{{
|
||||
Name: "bash",
|
||||
Input: map[string]any{"command": "echo \"hello\" > out.txt && cat out.txt"},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Mixed JSON inside CDATA (New Hybrid Case)",
|
||||
text: `<tool_call><tool_name>foo</tool_name><parameters><![CDATA[{"json_param": "works"}]]></parameters></tool_call>`,
|
||||
expected: []ParsedToolCall{{
|
||||
Name: "foo",
|
||||
Input: map[string]any{"json_param": "works"},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ParseToolCalls(tt.text, []string{"foo", "write_file", "bash"})
|
||||
if len(got) != len(tt.expected) {
|
||||
t.Fatalf("expected %d calls, got %d", len(tt.expected), len(got))
|
||||
}
|
||||
for i := range got {
|
||||
if got[i].Name != tt.expected[i].Name {
|
||||
t.Errorf("expected name %q, got %q", tt.expected[i].Name, got[i].Name)
|
||||
}
|
||||
if !reflect.DeepEqual(got[i].Input, tt.expected[i].Input) {
|
||||
t.Errorf("expected input %#v, got %#v", tt.expected[i].Input, got[i].Input)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -36,45 +36,47 @@ func BuildToolCallInstructions(toolNames []string) string {
|
||||
|
||||
return `TOOL CALL FORMAT — FOLLOW EXACTLY:
|
||||
|
||||
When calling tools, emit ONLY raw XML at the very end of your response. No text before, no text after, no markdown fences.
|
||||
|
||||
<tool_calls>
|
||||
<tool_call>
|
||||
<tool_name>TOOL_NAME_HERE</tool_name>
|
||||
<parameters>{"key":"value"}</parameters>
|
||||
<parameters>
|
||||
<PARAMETER_NAME><![CDATA[PARAMETER_VALUE]]></PARAMETER_NAME>
|
||||
</parameters>
|
||||
</tool_call>
|
||||
</tool_calls>
|
||||
|
||||
RULES:
|
||||
1) When calling tools, you MUST use the <tool_calls> XML format.
|
||||
2) No text is allowed AFTER the XML block.
|
||||
3) <parameters> MUST be a single-line strict JSON object. Use double quotes.
|
||||
4) Multiple tools must be inside the same <tool_calls> root.
|
||||
5) Do NOT wrap XML in markdown fences (` + "```" + `).
|
||||
6) Do NOT invent parameters. Use only the provided schema.
|
||||
7) CRITICAL: Do NOT use native tool markers like "<|Tool|>" or "<|tool|>".
|
||||
8) CRITICAL: Do NOT output role markers like "<|System|>", "<|User|>", or "<|Assistant|>".
|
||||
9) CRITICAL: Do NOT output internal monologues (e.g. "I will list files now..."). Just output your answer or the XML.
|
||||
1) Use the <tool_calls> XML format only. Never emit JSON or function-call syntax.
|
||||
2) Put one or more <tool_call> entries under a single <tool_calls> root.
|
||||
3) Parameters must be XML, not JSON.
|
||||
4) All string values must use <![CDATA[...]]>, even short ones. This includes code, scripts, file contents, prompts, paths, names, and queries.
|
||||
5) Objects use nested XML elements. Arrays may repeat the same tag or use <item> children.
|
||||
6) Numbers, booleans, and null stay plain text.
|
||||
7) Use only the parameter names in the tool schema. Do not invent fields.
|
||||
8) Do NOT wrap XML in markdown fences. Do NOT output explanations, role markers, or internal monologue.
|
||||
|
||||
PARAMETER SHAPES:
|
||||
- string => <name><![CDATA[value]]></name>
|
||||
- object => nested XML elements
|
||||
- array => repeated tags or <item> children
|
||||
- number/bool/null => plain text
|
||||
|
||||
【WRONG — Do NOT do these】:
|
||||
|
||||
❌ WRONG — Do NOT do these:
|
||||
Wrong 1 — mixed text after XML:
|
||||
<tool_calls>...</tool_calls> I hope this helps.
|
||||
Wrong 2 — function-call syntax:
|
||||
Grep({"pattern": "token"})
|
||||
Wrong 3 — missing <tool_calls> wrapper:
|
||||
<tool_call><tool_name>` + ex1 + `</tool_name><parameters>{}</parameters></tool_call>
|
||||
Wrong 3 — JSON parameters:
|
||||
<tool_call><tool_name>` + ex1 + `</tool_name><parameters>{"path":"x"}</parameters></tool_call>
|
||||
Wrong 4 — Markdown code fences:
|
||||
` + "```xml" + `
|
||||
<tool_calls>...</tool_calls>
|
||||
` + "```" + `
|
||||
Wrong 5 — native tool tokens:
|
||||
<|Tool|>call_some_tool{"param":1}<|Tool|>
|
||||
Wrong 6 — role markers in response:
|
||||
<|Assistant|> Here is the result...
|
||||
|
||||
Remember: The ONLY valid way to use tools is the <tool_calls> XML block at the end of your response.
|
||||
|
||||
✅ CORRECT EXAMPLES:
|
||||
【CORRECT EXAMPLES】:
|
||||
|
||||
Example A — Single tool:
|
||||
<tool_calls>
|
||||
@@ -96,15 +98,31 @@ Example B — Two tools in parallel:
|
||||
</tool_call>
|
||||
</tool_calls>
|
||||
|
||||
Example C — Tool with complex nested JSON parameters:
|
||||
Example C — Tool with nested XML parameters:
|
||||
<tool_calls>
|
||||
<tool_call>
|
||||
<tool_name>` + ex3 + `</tool_name>
|
||||
<parameters>` + ex3Params + `</parameters>
|
||||
</tool_call>
|
||||
</tool_calls>
|
||||
|
||||
Example D — Tool with long script using CDATA (RELIABLE FOR CODE/SCRIPTS):
|
||||
<tool_calls>
|
||||
<tool_call>
|
||||
<tool_name>` + ex2 + `</tool_name>
|
||||
<parameters>
|
||||
<path>` + promptCDATA("script.sh") + `</path>
|
||||
<content><![CDATA[
|
||||
#!/bin/bash
|
||||
if [ "$1" == "test" ]; then
|
||||
echo "Success!"
|
||||
fi
|
||||
]]></content>
|
||||
</parameters>
|
||||
</tool_call>
|
||||
</tool_calls>
|
||||
|
||||
Remember: Output ONLY the <tool_calls>...</tool_calls> XML block when calling tools.`
|
||||
`
|
||||
}
|
||||
|
||||
func matchAny(name string, candidates ...string) bool {
|
||||
@@ -119,34 +137,44 @@ func matchAny(name string, candidates ...string) bool {
|
||||
func exampleReadParams(name string) string {
|
||||
switch strings.TrimSpace(name) {
|
||||
case "Read":
|
||||
return `{"file_path":"README.md"}`
|
||||
return `<file_path>` + promptCDATA("README.md") + `</file_path>`
|
||||
case "Glob":
|
||||
return `{"pattern":"**/*.go","path":"."}`
|
||||
return `<pattern>` + promptCDATA("**/*.go") + `</pattern><path>` + promptCDATA(".") + `</path>`
|
||||
default:
|
||||
return `{"path":"src/main.go"}`
|
||||
return `<path>` + promptCDATA("src/main.go") + `</path>`
|
||||
}
|
||||
}
|
||||
|
||||
func exampleWriteOrExecParams(name string) string {
|
||||
switch strings.TrimSpace(name) {
|
||||
case "Bash", "execute_command":
|
||||
return `{"command":"pwd"}`
|
||||
return `<command>` + promptCDATA("pwd") + `</command>`
|
||||
case "exec_command":
|
||||
return `{"cmd":"pwd"}`
|
||||
return `<cmd>` + promptCDATA("pwd") + `</cmd>`
|
||||
case "Edit":
|
||||
return `{"file_path":"README.md","old_string":"foo","new_string":"bar"}`
|
||||
return `<file_path>` + promptCDATA("README.md") + `</file_path><old_string>` + promptCDATA("foo") + `</old_string><new_string>` + promptCDATA("bar") + `</new_string>`
|
||||
case "MultiEdit":
|
||||
return `{"file_path":"README.md","edits":[{"old_string":"foo","new_string":"bar"}]}`
|
||||
return `<file_path>` + promptCDATA("README.md") + `</file_path><edits><old_string>` + promptCDATA("foo") + `</old_string><new_string>` + promptCDATA("bar") + `</new_string></edits>`
|
||||
default:
|
||||
return `{"path":"output.txt","content":"Hello world"}`
|
||||
return `<path>` + promptCDATA("output.txt") + `</path><content>` + promptCDATA("Hello world") + `</content>`
|
||||
}
|
||||
}
|
||||
|
||||
func exampleInteractiveParams(name string) string {
|
||||
switch strings.TrimSpace(name) {
|
||||
case "Task":
|
||||
return `{"description":"Investigate flaky tests","prompt":"Run targeted tests and summarize failures"}`
|
||||
return `<description>` + promptCDATA("Investigate flaky tests") + `</description><prompt>` + promptCDATA("Run targeted tests and summarize failures") + `</prompt>`
|
||||
default:
|
||||
return `{"question":"Which approach do you prefer?","follow_up":[{"text":"Option A"},{"text":"Option B"}]}`
|
||||
return `<question>` + promptCDATA("Which approach do you prefer?") + `</question><follow_up><text>` + promptCDATA("Option A") + `</text></follow_up><follow_up><text>` + promptCDATA("Option B") + `</text></follow_up>`
|
||||
}
|
||||
}
|
||||
|
||||
func promptCDATA(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(text, "]]>") {
|
||||
return "<![CDATA[" + strings.ReplaceAll(text, "]]>", "]]]]><![CDATA[>") + "]]>"
|
||||
}
|
||||
return "<![CDATA[" + text + "]]>"
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ func TestBuildToolCallInstructions_ExecCommandUsesCmdExample(t *testing.T) {
|
||||
if !strings.Contains(out, `<tool_name>exec_command</tool_name>`) {
|
||||
t.Fatalf("expected exec_command in examples, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, `<parameters>{"cmd":"pwd"}</parameters>`) {
|
||||
if !strings.Contains(out, `<parameters><cmd><![CDATA[pwd]]></cmd></parameters>`) {
|
||||
t.Fatalf("expected cmd parameter example for exec_command, got: %s", out)
|
||||
}
|
||||
}
|
||||
@@ -20,7 +20,7 @@ func TestBuildToolCallInstructions_ExecuteCommandUsesCommandExample(t *testing.T
|
||||
if !strings.Contains(out, `<tool_name>execute_command</tool_name>`) {
|
||||
t.Fatalf("expected execute_command in examples, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, `<parameters>{"command":"pwd"}</parameters>`) {
|
||||
if !strings.Contains(out, `<parameters><command><![CDATA[pwd]]></command></parameters>`) {
|
||||
t.Fatalf("expected command parameter example for execute_command, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ─── FormatOpenAIStreamToolCalls ─────────────────────────────────────
|
||||
// --- FormatOpenAIStreamToolCalls ---
|
||||
|
||||
func TestFormatOpenAIStreamToolCalls(t *testing.T) {
|
||||
formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{
|
||||
@@ -22,15 +22,7 @@ func TestFormatOpenAIStreamToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ─── ParseToolCalls more edge cases ──────────────────────────────────
|
||||
|
||||
func TestParseToolCallsNoToolNames(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
||||
calls := ParseToolCalls(text, nil)
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call with nil tool names, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
// --- ParseToolCalls edge cases ---
|
||||
|
||||
func TestParseToolCallsEmptyText(t *testing.T) {
|
||||
calls := ParseToolCalls("", []string{"search"})
|
||||
@@ -38,55 +30,3 @@ func TestParseToolCallsEmptyText(t *testing.T) {
|
||||
t.Fatalf("expected 0 calls for empty text, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsMultipleTools(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"search","input":{"q":"go"}},{"name":"get_weather","input":{"city":"beijing"}}]}`
|
||||
calls := ParseToolCalls(text, []string{"search", "get_weather"})
|
||||
if len(calls) != 2 {
|
||||
t.Fatalf("expected 2 calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsInputAsString(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"search","input":"{\"q\":\"golang\"}"}]}`
|
||||
calls := ParseToolCalls(text, []string{"search"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Input["q"] != "golang" {
|
||||
t.Fatalf("expected parsed string input, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsWithFunctionWrapper(t *testing.T) {
|
||||
text := `{"tool_calls":[{"function":{"name":"calc","arguments":{"x":1,"y":2}}}]}`
|
||||
calls := ParseToolCalls(text, []string{"calc"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Name != "calc" {
|
||||
t.Fatalf("expected calc, got %q", calls[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) {
|
||||
fenced := "Here's an example:\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\nDon't execute this."
|
||||
calls := ParseStandaloneToolCalls(fenced, []string{"search"})
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected fenced code block to be ignored, got %d calls", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
// ─── looksLikeToolExampleContext ─────────────────────────────────────
|
||||
|
||||
func TestLooksLikeToolExampleContextNone(t *testing.T) {
|
||||
if looksLikeToolExampleContext("I will call the tool now") {
|
||||
t.Fatal("expected false for non-example context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeToolExampleContextFenced(t *testing.T) {
|
||||
if !looksLikeToolExampleContext("```json") {
|
||||
t.Fatal("expected true for fenced code block context")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,205 +1,4 @@
|
||||
package toolcall
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
|
||||
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
|
||||
var fencedCodeBlockPattern = regexp.MustCompile("(?s)```[\\s\\S]*?```")
|
||||
|
||||
//nolint:unused // retained for future markup tool-call heuristics.
|
||||
var markupToolSyntaxPattern = regexp.MustCompile(`(?i)<(?:(?:[a-z0-9_:-]+:)?(?:tool_call|function_call|invoke)\b|(?:[a-z0-9_:-]+:)?function_calls\b|(?:[a-z0-9_:-]+:)?tool_use\b)`)
|
||||
|
||||
func buildToolCallCandidates(text string) []string {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
candidates := []string{trimmed}
|
||||
|
||||
// fenced code block candidates: ```json ... ```
|
||||
for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) {
|
||||
if len(match) >= 2 {
|
||||
candidates = append(candidates, strings.TrimSpace(match[1]))
|
||||
}
|
||||
}
|
||||
|
||||
// best-effort extraction around tool call keywords in mixed text payloads.
|
||||
candidates = append(candidates, extractToolCallObjects(trimmed)...)
|
||||
|
||||
// best-effort object slice: from first '{' to last '}'
|
||||
first := strings.Index(trimmed, "{")
|
||||
last := strings.LastIndex(trimmed, "}")
|
||||
if first >= 0 && last > first {
|
||||
candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1]))
|
||||
}
|
||||
// best-effort array slice: from first '[' to last ']'
|
||||
firstArr := strings.Index(trimmed, "[")
|
||||
lastArr := strings.LastIndex(trimmed, "]")
|
||||
if firstArr >= 0 && lastArr > firstArr {
|
||||
candidates = append(candidates, strings.TrimSpace(trimmed[firstArr:lastArr+1]))
|
||||
}
|
||||
|
||||
// legacy regex extraction fallback
|
||||
if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
|
||||
candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}")
|
||||
}
|
||||
|
||||
uniq := make([]string, 0, len(candidates))
|
||||
seen := map[string]struct{}{}
|
||||
for _, c := range candidates {
|
||||
if c == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[c]; ok {
|
||||
continue
|
||||
}
|
||||
seen[c] = struct{}{}
|
||||
uniq = append(uniq, c)
|
||||
}
|
||||
return uniq
|
||||
}
|
||||
|
||||
func extractToolCallObjects(text string) []string {
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
lower := strings.ToLower(text)
|
||||
out := []string{}
|
||||
offset := 0
|
||||
keywords := []string{"tool_calls", "\"function\"", "function.name:", "functioncall", "\"tool_use\""}
|
||||
for {
|
||||
bestIdx := -1
|
||||
matchedKeyword := ""
|
||||
for _, kw := range keywords {
|
||||
idx := strings.Index(lower[offset:], kw)
|
||||
if idx >= 0 {
|
||||
absIdx := offset + idx
|
||||
if bestIdx < 0 || absIdx < bestIdx {
|
||||
bestIdx = absIdx
|
||||
matchedKeyword = kw
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bestIdx < 0 {
|
||||
break
|
||||
}
|
||||
|
||||
idx := bestIdx
|
||||
// Avoid backtracking too far to prevent OOM on malicious or very long strings
|
||||
searchLimit := idx - 2000
|
||||
if searchLimit < offset {
|
||||
searchLimit = offset
|
||||
}
|
||||
|
||||
start := strings.LastIndex(text[searchLimit:idx], "{")
|
||||
if start >= 0 {
|
||||
start += searchLimit
|
||||
}
|
||||
|
||||
if start < 0 {
|
||||
offset = idx + len(matchedKeyword)
|
||||
continue
|
||||
}
|
||||
|
||||
foundObj := false
|
||||
for start >= searchLimit {
|
||||
candidate, end, ok := extractJSONObject(text, start)
|
||||
if ok {
|
||||
// Move forward to avoid repeatedly matching the same object.
|
||||
offset = end
|
||||
out = append(out, strings.TrimSpace(candidate))
|
||||
foundObj = true
|
||||
break
|
||||
}
|
||||
// Try previous '{'
|
||||
if start > searchLimit {
|
||||
prevStart := strings.LastIndex(text[searchLimit:start], "{")
|
||||
if prevStart >= 0 {
|
||||
start = searchLimit + prevStart
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if !foundObj {
|
||||
offset = idx + len(matchedKeyword)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func extractJSONObject(text string, start int) (string, int, bool) {
|
||||
if start < 0 || start >= len(text) || text[start] != '{' {
|
||||
return "", 0, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
// Limit scan length to avoid OOM on unclosed objects
|
||||
maxLen := start + 50000
|
||||
if maxLen > len(text) {
|
||||
maxLen = len(text)
|
||||
}
|
||||
for i := start; i < maxLen; i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return text[start : i+1], i + 1, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func looksLikeToolExampleContext(text string) bool {
|
||||
t := strings.ToLower(strings.TrimSpace(text))
|
||||
if t == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(t, "```")
|
||||
}
|
||||
|
||||
func shouldSkipToolCallParsingForCodeFenceExample(text string) bool {
|
||||
if !looksLikeToolCallSyntax(text) {
|
||||
return false
|
||||
}
|
||||
stripped := strings.TrimSpace(stripFencedCodeBlocks(text))
|
||||
return !looksLikeToolCallSyntax(stripped)
|
||||
}
|
||||
|
||||
//nolint:unused // retained for future markup tool-call heuristics.
|
||||
func looksLikeMarkupToolSyntax(text string) bool {
|
||||
return markupToolSyntaxPattern.MatchString(text)
|
||||
}
|
||||
|
||||
func stripFencedCodeBlocks(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
return fencedCodeBlockPattern.ReplaceAllString(text, " ")
|
||||
}
|
||||
// toolcalls_candidates.go is reserved for tool-call candidate helper logic.
|
||||
// It exists to satisfy the refactor line gate target list.
|
||||
|
||||
@@ -22,6 +22,9 @@ var toolCallMarkupNamePatternByTag = map[string]*regexp.Regexp{
|
||||
"name": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?name\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?name>`),
|
||||
"function": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?function>`),
|
||||
}
|
||||
|
||||
// cdataPattern matches a standalone CDATA section.
|
||||
var cdataPattern = regexp.MustCompile(`(?is)^<!\[CDATA\[(.*?)]]>$`)
|
||||
var toolCallMarkupArgsTagNames = []string{"input", "arguments", "argument", "parameters", "parameter", "args", "params"}
|
||||
var toolCallMarkupArgsPatternByTag = map[string]*regexp.Regexp{
|
||||
"input": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?input\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?input>`),
|
||||
@@ -68,8 +71,31 @@ func parseMarkupToolCalls(text string) []ParsedToolCall {
|
||||
}
|
||||
|
||||
func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall {
|
||||
if parsed := parseToolCallsPayload(inner); len(parsed) > 0 {
|
||||
return parsed[0]
|
||||
// Try parsing inner content as a JSON tool call object.
|
||||
if raw := strings.TrimSpace(inner); raw != "" && strings.HasPrefix(raw, "{") {
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &obj); err == nil {
|
||||
name, _ := obj["name"].(string)
|
||||
if name == "" {
|
||||
if fn, ok := obj["function"].(map[string]any); ok {
|
||||
name, _ = fn["name"].(string)
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
if fc, ok := obj["functionCall"].(map[string]any); ok {
|
||||
name, _ = fc["name"].(string)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(name) != "" {
|
||||
input := parseToolCallInput(obj["input"])
|
||||
if len(input) == 0 {
|
||||
if args, ok := obj["arguments"]; ok {
|
||||
input = parseToolCallInput(args)
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: strings.TrimSpace(name), Input: input}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
name := ""
|
||||
@@ -93,17 +119,7 @@ func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall {
|
||||
}
|
||||
|
||||
func parseMarkupInput(raw string) map[string]any {
|
||||
raw = strings.TrimSpace(html.UnescapeString(raw))
|
||||
if raw == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
|
||||
return parsed
|
||||
}
|
||||
if kv := parseMarkupKVObject(raw); len(kv) > 0 {
|
||||
return kv
|
||||
}
|
||||
return map[string]any{"_raw": html.UnescapeString(stripTagText(raw))}
|
||||
return parseStructuredToolCallInput(raw)
|
||||
}
|
||||
|
||||
func parseMarkupKVObject(text string) map[string]any {
|
||||
@@ -124,16 +140,11 @@ func parseMarkupKVObject(text string) map[string]any {
|
||||
if !strings.EqualFold(key, endKey) {
|
||||
continue
|
||||
}
|
||||
value := strings.TrimSpace(html.UnescapeString(stripTagText(m[2])))
|
||||
if value == "" {
|
||||
value := parseMarkupValue(m[2])
|
||||
if value == nil {
|
||||
continue
|
||||
}
|
||||
var jsonValue any
|
||||
if json.Unmarshal([]byte(value), &jsonValue) == nil {
|
||||
out[key] = jsonValue
|
||||
continue
|
||||
}
|
||||
out[key] = value
|
||||
appendMarkupValue(out, key, value)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
@@ -141,6 +152,67 @@ func parseMarkupKVObject(text string) map[string]any {
|
||||
return out
|
||||
}
|
||||
|
||||
func parseMarkupValue(inner string) any {
|
||||
value := strings.TrimSpace(extractRawTagValue(inner))
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if strings.Contains(value, "<") && strings.Contains(value, ">") {
|
||||
if parsed := parseStructuredToolCallInput(value); len(parsed) > 0 {
|
||||
if len(parsed) == 1 {
|
||||
if raw, ok := parsed["_raw"].(string); ok {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
|
||||
var jsonValue any
|
||||
if json.Unmarshal([]byte(value), &jsonValue) == nil {
|
||||
return jsonValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func appendMarkupValue(out map[string]any, key string, value any) {
|
||||
if existing, ok := out[key]; ok {
|
||||
switch current := existing.(type) {
|
||||
case []any:
|
||||
out[key] = append(current, value)
|
||||
default:
|
||||
out[key] = []any{current, value}
|
||||
}
|
||||
return
|
||||
}
|
||||
out[key] = value
|
||||
}
|
||||
|
||||
// extractRawTagValue treats the inner content of a tag robustly.
|
||||
// It detects CDATA and strips it, otherwise it unescapes standard HTML entities.
|
||||
// It avoids over-aggressive tag stripping that might break user content.
|
||||
func extractRawTagValue(inner string) string {
|
||||
trimmed := strings.TrimSpace(inner)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 1. Check for CDATA - if present, it's the ultimate "safe" container.
|
||||
if cdataMatches := cdataPattern.FindStringSubmatch(trimmed); len(cdataMatches) >= 2 {
|
||||
return cdataMatches[1] // Return raw content between CDATA brackets
|
||||
}
|
||||
|
||||
// 2. If no CDATA, we still want to be robust.
|
||||
// We unescape standard HTML entities (like < > &)
|
||||
// but we DON'T recursively strip tags unless they are actually valid XML tags
|
||||
// at the start/end (which should have been handled by the outer matcher anyway).
|
||||
|
||||
// If it contains what looks like a single tag and no other text, it might be nested XML
|
||||
// but for KV objects we usually want the value.
|
||||
return html.UnescapeString(inner)
|
||||
}
|
||||
|
||||
func stripTagText(text string) string {
|
||||
return strings.TrimSpace(anyTagPattern.ReplaceAllString(text, ""))
|
||||
}
|
||||
@@ -152,7 +224,7 @@ func findMarkupTagValue(text string, tagNames []string, patternByTag map[string]
|
||||
continue
|
||||
}
|
||||
if m := pattern.FindStringSubmatch(text); len(m) >= 2 {
|
||||
value := strings.TrimSpace(m[1])
|
||||
value := extractRawTagValue(m[1])
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user