mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-02 07:25:26 +08:00
Compare commits
76 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d2c6445cfc | ||
|
|
b6fba47bcf | ||
|
|
e8d1aee7ad | ||
|
|
5cf56e7628 | ||
|
|
c291d333c4 | ||
|
|
2788e20f05 | ||
|
|
f178000d69 | ||
|
|
e840743295 | ||
|
|
77484bf813 | ||
|
|
f14969eca5 | ||
|
|
fe8a6bd3cd | ||
|
|
797ab77873 | ||
|
|
8f09e3b381 | ||
|
|
3a79b07d33 | ||
|
|
df13f35f43 | ||
|
|
4422f989be | ||
|
|
6052a8d1e2 | ||
|
|
f125c7ab83 | ||
|
|
8ff923cd77 | ||
|
|
e9a544cc53 | ||
|
|
d848d24a82 | ||
|
|
0a2fc42dad | ||
|
|
e615f1710f | ||
|
|
8f01aa224c | ||
|
|
31e64ff31d | ||
|
|
5984802df4 | ||
|
|
e0ed4ba238 | ||
|
|
ae37654893 | ||
|
|
aa7f821151 | ||
|
|
f7426f9f04 | ||
|
|
787e034174 | ||
|
|
d73f7b8b73 | ||
|
|
b8d844e2f6 | ||
|
|
2ba8b143d0 | ||
|
|
70603a5a90 | ||
|
|
fa51aafdc5 | ||
|
|
10d681ffe7 | ||
|
|
f313d0068f | ||
|
|
12256ceb24 | ||
|
|
2c08375b49 | ||
|
|
fa38934114 | ||
|
|
69eb71159d | ||
|
|
0e7f5cdc86 | ||
|
|
5b7cdaa729 | ||
|
|
08f32c4c40 | ||
|
|
69b7bc0c1a | ||
|
|
0f2b5fee23 | ||
|
|
26d195f2a6 | ||
|
|
790a8ca980 | ||
|
|
a1ce954ad5 | ||
|
|
6688e0ba35 | ||
|
|
c945f49fc4 | ||
|
|
0c644d1f4d | ||
|
|
146d59e7bf | ||
|
|
daf3307b88 | ||
|
|
67501cf4d2 | ||
|
|
25234af301 | ||
|
|
2aee80d0d3 | ||
|
|
ab9f3cc417 | ||
|
|
c92ed8d3c3 | ||
|
|
d78789a66e | ||
|
|
acb110865f | ||
|
|
ffca8be597 | ||
|
|
7ef6a7d11f | ||
|
|
d53a2ea7d2 | ||
|
|
daa636e040 | ||
|
|
aa41bae044 | ||
|
|
2027c7cd77 | ||
|
|
0591128601 | ||
|
|
caafdedb00 | ||
|
|
0a23c77ff7 | ||
|
|
d759804c33 | ||
|
|
433a3a877d | ||
|
|
792e295512 | ||
|
|
d053d9ad04 | ||
|
|
04e025c5e1 |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -62,3 +62,8 @@ CLAUDE.local.md
|
||||
|
||||
# Local tool bootstrap cache
|
||||
.tmp/
|
||||
|
||||
# Chat history
|
||||
data/
|
||||
.codex
|
||||
.roomodes
|
||||
|
||||
133
API.en.md
133
API.en.md
@@ -31,13 +31,13 @@ Docs: [Overview](README.en.md) / [Architecture](docs/ARCHITECTURE.en.md) / [Depl
|
||||
| Base URL | `http://localhost:5001` or your deployment domain |
|
||||
| Default Content-Type | `application/json` |
|
||||
| Health probes | `GET /healthz`, `GET /readyz` |
|
||||
| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
|
||||
| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Ds2-Source`, `X-Vercel-Protection-Bypass`) |
|
||||
|
||||
### 3.0 Adapter-Layer Notes
|
||||
|
||||
- OpenAI / Claude / Gemini protocols are now mounted on one shared `chi` router tree assembled in `internal/server/router.go`.
|
||||
- Adapter responsibilities are streamlined to: **request normalization → DeepSeek invocation → protocol-shaped rendering**, reducing legacy split-logic paths.
|
||||
- Tool-calling semantics are aligned between Go and Node runtime: structured parsing first (JSON/XML/invoke/markup), plus stream-time anti-leak filtering.
|
||||
- Tool-calling semantics are aligned between Go and Node runtime: parsing is now centered on XML/Markup-family tool syntax (`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml variants), plus stream-time anti-leak filtering.
|
||||
- `Admin API` separates static config from runtime policy: `/admin/config*` for configuration state, `/admin/settings*` for runtime behavior.
|
||||
|
||||
---
|
||||
@@ -108,6 +108,7 @@ Gemini-compatible clients can also send `x-goog-api-key`, `?key=`, or `?api_key=
|
||||
| POST | `/v1/responses` | Business | OpenAI Responses API (stream/non-stream) |
|
||||
| GET | `/v1/responses/{response_id}` | Business | Query stored response (in-memory TTL) |
|
||||
| POST | `/v1/embeddings` | Business | OpenAI Embeddings API |
|
||||
| POST | `/v1/files` | Business | OpenAI Files upload (multipart/form-data) |
|
||||
| GET | `/anthropic/v1/models` | None | Claude model list |
|
||||
| POST | `/anthropic/v1/messages` | Business | Claude messages |
|
||||
| POST | `/anthropic/v1/messages/count_tokens` | Business | Claude token counting |
|
||||
@@ -129,11 +130,19 @@ Gemini-compatible clients can also send `x-goog-api-key`, `?key=`, or `?api_key=
|
||||
| POST | `/admin/settings/password` | Admin | Update admin password and invalidate old JWTs |
|
||||
| POST | `/admin/config/import` | Admin | Import config (merge/replace) |
|
||||
| GET | `/admin/config/export` | Admin | Export full config (`config`/`json`/`base64`) |
|
||||
| POST | `/admin/keys` | Admin | Add API key |
|
||||
| POST | `/admin/keys` | Admin | Add API key (optional `name`/`remark`) |
|
||||
| PUT | `/admin/keys/{key}` | Admin | Update API key metadata |
|
||||
| DELETE | `/admin/keys/{key}` | Admin | Delete API key |
|
||||
| GET | `/admin/proxies` | Admin | List proxies |
|
||||
| POST | `/admin/proxies` | Admin | Add proxy |
|
||||
| PUT | `/admin/proxies/{proxyID}` | Admin | Update proxy (empty password keeps old secret) |
|
||||
| DELETE | `/admin/proxies/{proxyID}` | Admin | Delete proxy (auto-unbind referenced accounts) |
|
||||
| POST | `/admin/proxies/test` | Admin | Test proxy connectivity |
|
||||
| GET | `/admin/accounts` | Admin | Paginated account list |
|
||||
| POST | `/admin/accounts` | Admin | Add account |
|
||||
| PUT | `/admin/accounts/{identifier}` | Admin | Update account name/remark |
|
||||
| DELETE | `/admin/accounts/{identifier}` | Admin | Delete account |
|
||||
| PUT | `/admin/accounts/{identifier}/proxy` | Admin | Bind/unbind proxy for an account |
|
||||
| GET | `/admin/queue/status` | Admin | Account queue status |
|
||||
| POST | `/admin/accounts/test` | Admin | Test one account |
|
||||
| POST | `/admin/accounts/test-all` | Admin | Test all accounts |
|
||||
@@ -149,6 +158,10 @@ Gemini-compatible clients can also send `x-goog-api-key`, `?key=`, or `?api_key=
|
||||
| GET | `/admin/export` | Admin | Export config JSON/Base64 |
|
||||
| GET | `/admin/dev/captures` | Admin | Read local packet-capture entries |
|
||||
| DELETE | `/admin/dev/captures` | Admin | Clear local packet-capture entries |
|
||||
| GET | `/admin/chat-history` | Admin | Read server-side conversation history |
|
||||
| DELETE | `/admin/chat-history` | Admin | Clear server-side conversation history |
|
||||
| DELETE | `/admin/chat-history/{id}` | Admin | Delete one server-side conversation entry |
|
||||
| PUT | `/admin/chat-history/settings` | Admin | Update conversation history retention limit |
|
||||
| GET | `/admin/version` | Admin | Check current version and latest Release |
|
||||
|
||||
---
|
||||
@@ -173,7 +186,7 @@ Gemini-compatible clients can also send `x-goog-api-key`, `?key=`, or `?api_key=
|
||||
|
||||
### `GET /v1/models`
|
||||
|
||||
No auth required. Returns supported models.
|
||||
No auth required. Returns the currently supported DeepSeek native model list.
|
||||
|
||||
**Response**:
|
||||
|
||||
@@ -184,11 +197,21 @@ No auth required. Returns supported models.
|
||||
{"id": "deepseek-chat", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-reasoner", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-chat-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-reasoner-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []}
|
||||
{"id": "deepseek-reasoner-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-expert-chat", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-expert-reasoner", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-expert-chat-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-expert-reasoner-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-vision-chat", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-vision-reasoner", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-vision-chat-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-vision-reasoner-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
> Note: `/v1/models` returns normalized DeepSeek native model IDs. Common aliases are accepted only as request input and are not expanded as separate items in this endpoint.
|
||||
|
||||
### Model Alias Resolution
|
||||
|
||||
For `chat` / `responses` / `embeddings`, DS2API follows a wide-input/strict-output policy:
|
||||
@@ -198,6 +221,13 @@ For `chat` / `responses` / `embeddings`, DS2API follows a wide-input/strict-outp
|
||||
3. If still unmatched, fall back by known family heuristics (`o*`, `gpt-*`, `claude-*`, etc.).
|
||||
4. If still unmatched, return `invalid_request_error`.
|
||||
|
||||
Current built-in default aliases (excerpt):
|
||||
|
||||
- OpenAI: `gpt-4o`, `gpt-4.1`, `gpt-4.1-mini`, `gpt-4.1-nano`, `gpt-5`, `gpt-5-mini`, `gpt-5-codex`
|
||||
- OpenAI reasoning: `o1`, `o1-mini`, `o3`, `o3-mini`
|
||||
- Claude: `claude-sonnet-4-5`, `claude-haiku-4-5`, `claude-opus-4-6` (plus compatibility aliases `claude-3-5-sonnet` / `claude-3-5-haiku` / `claude-3-opus`)
|
||||
- Gemini: `gemini-2.5-pro`, `gemini-2.5-flash`
|
||||
|
||||
### `POST /v1/chat/completions`
|
||||
|
||||
**Headers**:
|
||||
@@ -211,7 +241,7 @@ Content-Type: application/json
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | DeepSeek native models + common aliases (`gpt-4o`, `gpt-5-codex`, `o3`, `claude-sonnet-4-5`, etc.) |
|
||||
| `model` | string | ✅ | DeepSeek native models + common aliases (`gpt-5`, `gpt-5-mini`, `gpt-5-codex`, `o3`, `claude-opus-4-6`, `gemini-2.5-pro`, `gemini-2.5-flash`, etc.) |
|
||||
| `messages` | array | ✅ | OpenAI-style messages |
|
||||
| `stream` | boolean | ❌ | Default `false` |
|
||||
| `tools` | array | ❌ | Function calling schema |
|
||||
@@ -302,7 +332,12 @@ When `tools` is present, DS2API performs anti-leak handling:
|
||||
}
|
||||
```
|
||||
|
||||
**Stream**: Once high-confidence toolcall features are matched, DS2API emits `delta.tool_calls` immediately (without waiting for full JSON closure), then keeps sending argument deltas; confirmed raw tool JSON is never forwarded as `delta.content`.
|
||||
**Stream**: Once high-confidence toolcall features are matched, DS2API emits `delta.tool_calls` immediately (without waiting for full argument closure), then keeps sending argument deltas; confirmed tool-call fragments are not forwarded as `delta.content`.
|
||||
|
||||
Additional notes:
|
||||
|
||||
- The parser currently follows XML/Markup-family tool payloads (`<tool_call>`, `<function_call>`, `<invoke>`, `tool_use`, antml variants). Standalone JSON `tool_calls` payloads are not treated as executable tool calls by default.
|
||||
- `tool_calls` shown inside fenced markdown code blocks (for example, ```json ... ```) are treated as examples, not executable calls.
|
||||
|
||||
---
|
||||
|
||||
@@ -381,6 +416,21 @@ Business auth required. Returns OpenAI-compatible embeddings shape.
|
||||
|
||||
> Requires `embeddings.provider`. Current supported values: `mock` / `deterministic` / `builtin`. If missing/unsupported, returns standard error shape with HTTP 501.
|
||||
|
||||
### `POST /v1/files`
|
||||
|
||||
Business auth required. OpenAI Files-compatible upload endpoint; currently only `multipart/form-data` is supported.
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `file` | file | ✅ | Binary payload |
|
||||
| `purpose` | string | ❌ | Forwarded purpose field |
|
||||
|
||||
Constraints and behavior:
|
||||
|
||||
- `Content-Type` must be `multipart/form-data` (otherwise `400`).
|
||||
- Total request size limit is `100 MiB` (over-limit returns `413`).
|
||||
- Success returns an OpenAI `file` object (`id/object/bytes/filename/purpose/status`, etc.) and includes `account_id` for source-account tracing.
|
||||
|
||||
---
|
||||
|
||||
## Claude-Compatible API
|
||||
@@ -408,7 +458,7 @@ No auth required.
|
||||
}
|
||||
```
|
||||
|
||||
> Note: the example is partial; the real response includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases.
|
||||
> Note: the example is partial; besides the current primary aliases, the real response also includes Claude 4.x snapshots plus historical 3.x / 2.x / 1.x IDs and common aliases.
|
||||
|
||||
### `POST /anthropic/v1/messages`
|
||||
|
||||
@@ -599,11 +649,15 @@ Returns Vercel preconfiguration status.
|
||||
|
||||
### `GET /admin/config`
|
||||
|
||||
Returns sanitized config.
|
||||
Returns sanitized config, including both `keys` and `api_keys`.
|
||||
|
||||
```json
|
||||
{
|
||||
"keys": ["k1", "k2"],
|
||||
"api_keys": [
|
||||
{"key": "k1", "name": "Primary", "remark": "Production"},
|
||||
{"key": "k2", "name": "Backup", "remark": "Load test"}
|
||||
],
|
||||
"env_backed": false,
|
||||
"env_source_present": true,
|
||||
"env_writeback_enabled": true,
|
||||
@@ -627,13 +681,18 @@ Returns sanitized config.
|
||||
|
||||
### `POST /admin/config`
|
||||
|
||||
Only updates `keys`, `accounts`, and `claude_mapping`.
|
||||
Only updates `keys`, `api_keys`, `accounts`, and `claude_mapping`.
|
||||
If both `api_keys` and `keys` are sent, the structured `api_keys` entries win so `name` / `remark` metadata is preserved; `keys` remains a legacy fallback.
|
||||
|
||||
**Request**:
|
||||
|
||||
```json
|
||||
{
|
||||
"keys": ["k1", "k2"],
|
||||
"api_keys": [
|
||||
{"key": "k1", "name": "Primary", "remark": "Production"},
|
||||
{"key": "k2", "name": "Backup", "remark": "Load test"}
|
||||
],
|
||||
"accounts": [
|
||||
{"email": "user@example.com", "password": "pwd", "token": ""}
|
||||
],
|
||||
@@ -693,7 +752,7 @@ Imports full config with:
|
||||
|
||||
The request can send config directly, or wrapped as `{"config": {...}, "mode":"merge"}`.
|
||||
Query params `?mode=merge` / `?mode=replace` are also supported.
|
||||
Import accepts `keys`, `accounts`, `claude_mapping` / `claude_model_mapping`, `model_aliases`, `admin`, `runtime`, `responses`, `embeddings`, and `auto_delete`; legacy `toolcall` fields are ignored.
|
||||
Import accepts `keys`, `api_keys`, `accounts`, `claude_mapping` / `claude_model_mapping`, `model_aliases`, `admin`, `runtime`, `responses`, `embeddings`, and `auto_delete`; legacy `toolcall` fields are ignored.
|
||||
|
||||
> `compat` fields are managed via `/admin/settings` or the config file; this import endpoint does not update `compat`.
|
||||
|
||||
@@ -704,7 +763,17 @@ Exports full config in three forms: `config`, `json`, and `base64`.
|
||||
### `POST /admin/keys`
|
||||
|
||||
```json
|
||||
{"key": "new-api-key"}
|
||||
{"key": "new-api-key", "name": "Primary", "remark": "Production"}
|
||||
```
|
||||
|
||||
**Response**: `{"success": true, "total_keys": 3}`
|
||||
|
||||
### `PUT /admin/keys/{key}`
|
||||
|
||||
Updates the `name` / `remark` of the specified API key. The path `key` is read-only and cannot be changed.
|
||||
|
||||
```json
|
||||
{"name": "Backup", "remark": "Load test"}
|
||||
```
|
||||
|
||||
**Response**: `{"success": true, "total_keys": 3}`
|
||||
@@ -713,6 +782,26 @@ Exports full config in three forms: `config`, `json`, and `base64`.
|
||||
|
||||
**Response**: `{"success": true, "total_keys": 2}`
|
||||
|
||||
### `GET /admin/proxies`
|
||||
|
||||
Lists proxy configs (password is never returned; use `has_password` as a marker).
|
||||
|
||||
### `POST /admin/proxies`
|
||||
|
||||
Adds a proxy. Request accepts `id` (optional; auto-generated when omitted), `name`, `type` (`http` / `socks5`), `host`, `port`, `username`, `password`.
|
||||
|
||||
### `PUT /admin/proxies/{proxyID}`
|
||||
|
||||
Updates a proxy. If `password` is an empty string, the existing secret is preserved.
|
||||
|
||||
### `DELETE /admin/proxies/{proxyID}`
|
||||
|
||||
Deletes a proxy and automatically clears `proxy_id` on all accounts that reference it.
|
||||
|
||||
### `POST /admin/proxies/test`
|
||||
|
||||
Tests proxy connectivity: provide `proxy_id` to test a saved proxy; omit it to run a one-off test using proxy fields in the request body.
|
||||
|
||||
### `GET /admin/accounts`
|
||||
|
||||
**Query params**:
|
||||
@@ -720,7 +809,7 @@ Exports full config in three forms: `config`, `json`, and `base64`.
|
||||
| Param | Default | Range |
|
||||
| --- | --- | --- |
|
||||
| `page` | `1` | ≥ 1 |
|
||||
| `page_size` | `10` | 1–100 |
|
||||
| `page_size` | `10` | 1–5000 |
|
||||
| `q` | empty | Filter by identifier / email / mobile |
|
||||
|
||||
**Response**:
|
||||
@@ -755,12 +844,30 @@ Returned items also include `test_status`, usually `ok` or `failed`.
|
||||
|
||||
**Response**: `{"success": true, "total_accounts": 6}`
|
||||
|
||||
### `PUT /admin/accounts/{identifier}`
|
||||
|
||||
Updates the `name` / `remark` of the specified account. The path `identifier` can be email or mobile and cannot be changed.
|
||||
|
||||
```json
|
||||
{"name": "Primary account", "remark": "Shared with the team"}
|
||||
```
|
||||
|
||||
**Response**: `{"success": true, "total_accounts": 6}`
|
||||
|
||||
### `DELETE /admin/accounts/{identifier}`
|
||||
|
||||
`identifier` can be email, mobile, or the synthetic id for token-only accounts (`token:<hash>`).
|
||||
|
||||
**Response**: `{"success": true, "total_accounts": 5}`
|
||||
|
||||
### `PUT /admin/accounts/{identifier}/proxy`
|
||||
|
||||
Updates proxy binding for a specific account.
|
||||
|
||||
- Request body: `{"proxy_id":"..."}`.
|
||||
- Use empty `proxy_id` to unbind proxy.
|
||||
- `identifier` supports email / mobile / token-only synthetic id.
|
||||
|
||||
### `GET /admin/queue/status`
|
||||
|
||||
```json
|
||||
|
||||
135
API.md
135
API.md
@@ -31,13 +31,13 @@
|
||||
| Base URL | `http://localhost:5001` 或你的部署域名 |
|
||||
| 默认 Content-Type | `application/json` |
|
||||
| 健康检查 | `GET /healthz`、`GET /readyz` |
|
||||
| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
|
||||
| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Ds2-Source`, `X-Vercel-Protection-Bypass`) |
|
||||
|
||||
### 3.0 接口适配层说明
|
||||
|
||||
- OpenAI / Claude / Gemini 三套协议已统一挂在同一 `chi` 路由树上,由 `internal/server/router.go` 负责装配。
|
||||
- 适配器层职责收敛为:**请求归一化 → DeepSeek 调用 → 协议形态渲染**,减少历史版本中“同能力多处实现”的分叉。
|
||||
- Tool Calling 的解析策略在 Go 与 Node Runtime 间保持一致:优先结构化解析(JSON/XML/invoke/markup),并在流式场景执行防泄漏筛分。
|
||||
- Tool Calling 的解析策略在 Go 与 Node Runtime 间保持一致:当前以 XML/Markup 家族解析为主(含 `<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml 变体),并在流式场景执行防泄漏筛分。
|
||||
- `Admin API` 将配置与运行时策略分开:`/admin/config*` 管静态配置,`/admin/settings*` 管运行时行为。
|
||||
|
||||
---
|
||||
@@ -108,6 +108,7 @@ Gemini 兼容客户端还可以使用 `x-goog-api-key`、`?key=` 或 `?api_key=`
|
||||
| POST | `/v1/responses` | 业务 | OpenAI Responses 接口(流式/非流式) |
|
||||
| GET | `/v1/responses/{response_id}` | 业务 | 查询已生成 response(内存 TTL) |
|
||||
| POST | `/v1/embeddings` | 业务 | OpenAI Embeddings 接口 |
|
||||
| POST | `/v1/files` | 业务 | OpenAI Files 上传(multipart/form-data) |
|
||||
| GET | `/anthropic/v1/models` | 无 | Claude 模型列表 |
|
||||
| POST | `/anthropic/v1/messages` | 业务 | Claude 消息接口 |
|
||||
| POST | `/anthropic/v1/messages/count_tokens` | 业务 | Claude token 计数 |
|
||||
@@ -129,11 +130,19 @@ Gemini 兼容客户端还可以使用 `x-goog-api-key`、`?key=` 或 `?api_key=`
|
||||
| POST | `/admin/settings/password` | Admin | 更新 Admin 密码并使旧 JWT 失效 |
|
||||
| POST | `/admin/config/import` | Admin | 导入配置(merge/replace) |
|
||||
| GET | `/admin/config/export` | Admin | 导出完整配置(含 `config`/`json`/`base64`) |
|
||||
| POST | `/admin/keys` | Admin | 添加 API key |
|
||||
| POST | `/admin/keys` | Admin | 添加 API key(可附 name/remark) |
|
||||
| PUT | `/admin/keys/{key}` | Admin | 更新 API key 备注信息 |
|
||||
| DELETE | `/admin/keys/{key}` | Admin | 删除 API key |
|
||||
| GET | `/admin/proxies` | Admin | 代理列表 |
|
||||
| POST | `/admin/proxies` | Admin | 添加代理 |
|
||||
| PUT | `/admin/proxies/{proxyID}` | Admin | 更新代理(留空 password 表示保留原密码) |
|
||||
| DELETE | `/admin/proxies/{proxyID}` | Admin | 删除代理(自动解绑引用该代理的账号) |
|
||||
| POST | `/admin/proxies/test` | Admin | 测试代理连通性 |
|
||||
| GET | `/admin/accounts` | Admin | 分页账号列表 |
|
||||
| POST | `/admin/accounts` | Admin | 添加账号 |
|
||||
| PUT | `/admin/accounts/{identifier}` | Admin | 更新账号 name/remark |
|
||||
| DELETE | `/admin/accounts/{identifier}` | Admin | 删除账号 |
|
||||
| PUT | `/admin/accounts/{identifier}/proxy` | Admin | 为账号绑定/解绑代理 |
|
||||
| GET | `/admin/queue/status` | Admin | 账号队列状态 |
|
||||
| POST | `/admin/accounts/test` | Admin | 测试单个账号 |
|
||||
| POST | `/admin/accounts/test-all` | Admin | 测试全部账号 |
|
||||
@@ -149,6 +158,10 @@ Gemini 兼容客户端还可以使用 `x-goog-api-key`、`?key=` 或 `?api_key=`
|
||||
| GET | `/admin/export` | Admin | 导出配置 JSON/Base64 |
|
||||
| GET | `/admin/dev/captures` | Admin | 查看本地抓包记录 |
|
||||
| DELETE | `/admin/dev/captures` | Admin | 清空本地抓包记录 |
|
||||
| GET | `/admin/chat-history` | Admin | 查看服务器端对话记录 |
|
||||
| DELETE | `/admin/chat-history` | Admin | 清空服务器端对话记录 |
|
||||
| DELETE | `/admin/chat-history/{id}` | Admin | 删除单条服务器端对话记录 |
|
||||
| PUT | `/admin/chat-history/settings` | Admin | 更新对话记录保留条数 |
|
||||
| GET | `/admin/version` | Admin | 查询当前版本与最新 Release |
|
||||
|
||||
---
|
||||
@@ -173,7 +186,7 @@ Gemini 兼容客户端还可以使用 `x-goog-api-key`、`?key=` 或 `?api_key=`
|
||||
|
||||
### `GET /v1/models`
|
||||
|
||||
无需鉴权。返回当前支持的模型列表。
|
||||
无需鉴权。返回当前支持的 DeepSeek 原生模型列表。
|
||||
|
||||
**响应示例**:
|
||||
|
||||
@@ -184,11 +197,21 @@ Gemini 兼容客户端还可以使用 `x-goog-api-key`、`?key=` 或 `?api_key=`
|
||||
{"id": "deepseek-chat", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-reasoner", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-chat-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-reasoner-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []}
|
||||
{"id": "deepseek-reasoner-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-expert-chat", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-expert-reasoner", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-expert-chat-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-expert-reasoner-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-vision-chat", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-vision-reasoner", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-vision-chat-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []},
|
||||
{"id": "deepseek-vision-reasoner-search", "object": "model", "created": 1677610602, "owned_by": "deepseek", "permission": []}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
> 说明:`/v1/models` 返回的是规范化后的 DeepSeek 原生模型 ID;常见 alias 仅用于请求入参解析,不会在该接口中单独展开返回。
|
||||
|
||||
### 模型 alias 解析策略
|
||||
|
||||
对 `chat` / `responses` / `embeddings` 的 `model` 字段采用“宽进严出”:
|
||||
@@ -198,6 +221,13 @@ Gemini 兼容客户端还可以使用 `x-goog-api-key`、`?key=` 或 `?api_key=`
|
||||
3. 未命中时按模型家族规则回退(如 `o*`、`gpt-*`、`claude-*`)。
|
||||
4. 仍未命中则返回 `invalid_request_error`。
|
||||
|
||||
当前内置默认 alias(节选):
|
||||
|
||||
- OpenAI:`gpt-4o`、`gpt-4.1`、`gpt-4.1-mini`、`gpt-4.1-nano`、`gpt-5`、`gpt-5-mini`、`gpt-5-codex`
|
||||
- OpenAI Reasoning:`o1`、`o1-mini`、`o3`、`o3-mini`
|
||||
- Claude:`claude-sonnet-4-5`、`claude-haiku-4-5`、`claude-opus-4-6`(及 `claude-3-5-sonnet` / `claude-3-5-haiku` / `claude-3-opus` 兼容别名)
|
||||
- Gemini:`gemini-2.5-pro`、`gemini-2.5-flash`
|
||||
|
||||
### `POST /v1/chat/completions`
|
||||
|
||||
**请求头**:
|
||||
@@ -211,7 +241,7 @@ Content-Type: application/json
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | 支持 DeepSeek 原生模型 + 常见 alias(如 `gpt-4o`、`gpt-5-codex`、`o3`、`claude-sonnet-4-5`) |
|
||||
| `model` | string | ✅ | 支持 DeepSeek 原生模型 + 常见 alias(如 `gpt-5`、`gpt-5-mini`、`gpt-5-codex`、`o3`、`claude-opus-4-6`、`gemini-2.5-pro`、`gemini-2.5-flash` 等) |
|
||||
| `messages` | array | ✅ | OpenAI 风格消息数组 |
|
||||
| `stream` | boolean | ❌ | 默认 `false` |
|
||||
| `tools` | array | ❌ | Function Calling 定义 |
|
||||
@@ -302,12 +332,12 @@ data: [DONE]
|
||||
}
|
||||
```
|
||||
|
||||
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。
|
||||
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整工具参数闭合),并持续发送 arguments 增量;已确认的工具调用片段不会回流到 `delta.content`。
|
||||
|
||||
补充说明:
|
||||
|
||||
- **非代码块上下文**下,工具负载即使与普通文本混合,也会按特征识别并产出可执行 tool call(前后普通文本仍可透传)。
|
||||
- 解析器以 XML/Markup 为最高优先级,并兼容 JSON、ANTML、text-kv 等格式输入;最终按客户端协议转译为对应 tool call 结构(OpenAI/Claude/Gemini)。
|
||||
- 解析器当前走 XML/Markup 家族(包含 `<tool_call>`、`<function_call>`、`<invoke>`、`tool_use`、antml 风格);纯 JSON `tool_calls` 片段默认不会直接作为可执行调用解析。
|
||||
- Markdown fenced code block(例如 ```json ... ```)中的 `tool_calls` 仅视为示例文本,不会被执行。
|
||||
|
||||
---
|
||||
@@ -387,6 +417,21 @@ data: [DONE]
|
||||
|
||||
> 需配置 `embeddings.provider`。当前支持:`mock` / `deterministic` / `builtin`。未配置或不支持时返回标准错误结构(HTTP 501)。
|
||||
|
||||
### `POST /v1/files`
|
||||
|
||||
需要业务鉴权。兼容 OpenAI Files 上传接口,当前仅支持 `multipart/form-data`。
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
| --- | --- | --- | --- |
|
||||
| `file` | file | ✅ | 上传文件二进制 |
|
||||
| `purpose` | string | ❌ | 透传到上游用途字段 |
|
||||
|
||||
约束与行为:
|
||||
|
||||
- 请求必须为 `multipart/form-data`,否则返回 `400`。
|
||||
- 请求体总大小上限 `100 MiB`(超限返回 `413`)。
|
||||
- 成功返回 OpenAI `file` 对象(`id/object/bytes/filename/purpose/status` 等字段),并附带 `account_id` 便于定位来源账号。
|
||||
|
||||
---
|
||||
|
||||
## Claude 兼容接口
|
||||
@@ -414,7 +459,7 @@ data: [DONE]
|
||||
}
|
||||
```
|
||||
|
||||
> 说明:示例仅展示部分模型;实际返回包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名。
|
||||
> 说明:示例仅展示部分模型;实际返回除当前主别名外,还包含 Claude 4.x snapshots,以及 3.x / 2.x / 1.x 历史模型 ID 与常见别名。
|
||||
|
||||
### `POST /anthropic/v1/messages`
|
||||
|
||||
@@ -605,11 +650,15 @@ data: {"type":"message_stop"}
|
||||
|
||||
### `GET /admin/config`
|
||||
|
||||
返回脱敏后的配置。
|
||||
返回脱敏后的配置,包含 `keys` 与 `api_keys`。
|
||||
|
||||
```json
|
||||
{
|
||||
"keys": ["k1", "k2"],
|
||||
"api_keys": [
|
||||
{"key": "k1", "name": "主 Key", "remark": "生产流量"},
|
||||
{"key": "k2", "name": "备用 Key", "remark": "压测"}
|
||||
],
|
||||
"env_backed": false,
|
||||
"env_source_present": true,
|
||||
"env_writeback_enabled": true,
|
||||
@@ -633,13 +682,18 @@ data: {"type":"message_stop"}
|
||||
|
||||
### `POST /admin/config`
|
||||
|
||||
只更新 `keys`、`accounts`、`claude_mapping`。
|
||||
只更新 `keys`、`api_keys`、`accounts`、`claude_mapping`。
|
||||
如果同时发送 `api_keys` 与 `keys`,优先保留 `api_keys` 中的结构化 `name` / `remark`;`keys` 仅作为旧格式兼容回退。
|
||||
|
||||
**请求**:
|
||||
|
||||
```json
|
||||
{
|
||||
"keys": ["k1", "k2"],
|
||||
"api_keys": [
|
||||
{"key": "k1", "name": "主 Key", "remark": "生产流量"},
|
||||
{"key": "k2", "name": "备用 Key", "remark": "压测"}
|
||||
],
|
||||
"accounts": [
|
||||
{"email": "user@example.com", "password": "pwd", "token": ""}
|
||||
],
|
||||
@@ -699,7 +753,7 @@ data: {"type":"message_stop"}
|
||||
|
||||
请求可直接传配置对象,或使用 `{"config": {...}, "mode":"merge"}` 包裹格式。
|
||||
也支持在查询参数里传 `?mode=merge` / `?mode=replace`。
|
||||
导入时会接受 `keys`、`accounts`、`claude_mapping` / `claude_model_mapping`、`model_aliases`、`admin`、`runtime`、`responses`、`embeddings`、`auto_delete` 等字段;`toolcall` 相关字段会被忽略。
|
||||
导入时会接受 `keys`、`api_keys`、`accounts`、`claude_mapping` / `claude_model_mapping`、`model_aliases`、`admin`、`runtime`、`responses`、`embeddings`、`auto_delete` 等字段;`toolcall` 相关字段会被忽略。
|
||||
|
||||
> `compat` 相关字段请通过 `/admin/settings` 或配置文件管理;该导入接口不会更新 `compat`。
|
||||
|
||||
@@ -707,10 +761,25 @@ data: {"type":"message_stop"}
|
||||
|
||||
导出完整配置,返回 `config`、`json`、`base64` 三种格式。
|
||||
|
||||
响应示例:
|
||||
|
||||
|
||||
> 注:`_vercel_sync_hash` 和 `_vercel_sync_time` 为内部同步元数据字段,用于 Vercel 配置漂移检测。
|
||||
|
||||
### `POST /admin/keys`
|
||||
|
||||
```json
|
||||
{"key": "new-api-key"}
|
||||
{"key": "new-api-key", "name": "主 Key", "remark": "生产流量"}
|
||||
```
|
||||
|
||||
**响应**:`{"success": true, "total_keys": 3}`
|
||||
|
||||
### `PUT /admin/keys/{key}`
|
||||
|
||||
更新指定 API key 的 `name` / `remark`,路径参数中的 `key` 为只读标识,不可修改。
|
||||
|
||||
```json
|
||||
{"name": "备用 Key", "remark": "压测"}
|
||||
```
|
||||
|
||||
**响应**:`{"success": true, "total_keys": 3}`
|
||||
@@ -719,6 +788,26 @@ data: {"type":"message_stop"}
|
||||
|
||||
**响应**:`{"success": true, "total_keys": 2}`
|
||||
|
||||
### `GET /admin/proxies`
|
||||
|
||||
列出代理配置(密码不回传,仅返回 `has_password` 标记)。
|
||||
|
||||
### `POST /admin/proxies`
|
||||
|
||||
新增代理。请求体支持 `id`(可选,未传则自动生成)、`name`、`type`(`http` / `socks5`)、`host`、`port`、`username`、`password`。
|
||||
|
||||
### `PUT /admin/proxies/{proxyID}`
|
||||
|
||||
更新指定代理。若请求中 `password` 为空字符串,则保留原密码。
|
||||
|
||||
### `DELETE /admin/proxies/{proxyID}`
|
||||
|
||||
删除代理,并自动清空所有引用该代理账号的 `proxy_id`。
|
||||
|
||||
### `POST /admin/proxies/test`
|
||||
|
||||
测试代理连通性:传 `proxy_id` 时测试已保存代理;不传时按请求体代理字段做临时连通性测试。
|
||||
|
||||
### `GET /admin/accounts`
|
||||
|
||||
**查询参数**:
|
||||
@@ -726,7 +815,7 @@ data: {"type":"message_stop"}
|
||||
| 参数 | 默认 | 范围 |
|
||||
| --- | --- | --- |
|
||||
| `page` | `1` | ≥ 1 |
|
||||
| `page_size` | `10` | 1–100 |
|
||||
| `page_size` | `10` | 1–5000 |
|
||||
| `q` | 空 | 按 identifier / email / mobile 过滤 |
|
||||
|
||||
**响应**:
|
||||
@@ -759,12 +848,30 @@ data: {"type":"message_stop"}
|
||||
|
||||
**响应**:`{"success": true, "total_accounts": 6}`
|
||||
|
||||
### `PUT /admin/accounts/{identifier}`
|
||||
|
||||
更新指定账号的 `name` / `remark`。路径参数中的 `identifier` 可以是 email 或 mobile,且不可修改。
|
||||
|
||||
```json
|
||||
{"name": "主账号", "remark": "团队共享"}
|
||||
```
|
||||
|
||||
**响应**:`{"success": true, "total_accounts": 6}`
|
||||
|
||||
### `DELETE /admin/accounts/{identifier}`
|
||||
|
||||
`identifier` 可为 email、mobile,或 token-only 账号的合成标识(`token:<hash>`)。
|
||||
|
||||
**响应**:`{"success": true, "total_accounts": 5}`
|
||||
|
||||
### `PUT /admin/accounts/{identifier}/proxy`
|
||||
|
||||
更新指定账号绑定代理。
|
||||
|
||||
- 请求体:`{"proxy_id":"..."}`;
|
||||
- `proxy_id` 传空字符串时表示解绑代理;
|
||||
- `identifier` 支持 email / mobile / token-only 合成标识。
|
||||
|
||||
### `GET /admin/queue/status`
|
||||
|
||||
```json
|
||||
|
||||
143
LICENSE
143
LICENSE
@@ -1,5 +1,5 @@
|
||||
GNU GENERAL PUBLIC LICENSE
|
||||
Version 3, 29 June 2007
|
||||
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||
Version 3, 19 November 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
@@ -7,17 +7,15 @@
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU General Public License is a free, copyleft license for
|
||||
software and other kinds of works.
|
||||
The GNU Affero General Public License is a free, copyleft license for
|
||||
software and other kinds of works, specifically designed to ensure
|
||||
cooperation with the community in the case of network server software.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
the GNU General Public License is intended to guarantee your freedom to
|
||||
our General Public Licenses are intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users. We, the Free Software Foundation, use the
|
||||
GNU General Public License for most of our software; it applies also to
|
||||
any other work released this way by its authors. You can apply it to
|
||||
your programs, too.
|
||||
software for all its users.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
@@ -26,44 +24,34 @@ them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
To protect your rights, we need to prevent others from denying you
|
||||
these rights or asking you to surrender the rights. Therefore, you have
|
||||
certain responsibilities if you distribute copies of the software, or if
|
||||
you modify it: responsibilities to respect the freedom of others.
|
||||
Developers that use our General Public Licenses protect your rights
|
||||
with two steps: (1) assert copyright on the software, and (2) offer
|
||||
you this License which gives you legal permission to copy, distribute
|
||||
and/or modify the software.
|
||||
|
||||
For example, if you distribute copies of such a program, whether
|
||||
gratis or for a fee, you must pass on to the recipients the same
|
||||
freedoms that you received. You must make sure that they, too, receive
|
||||
or can get the source code. And you must show them these terms so they
|
||||
know their rights.
|
||||
A secondary benefit of defending all users' freedom is that
|
||||
improvements made in alternate versions of the program, if they
|
||||
receive widespread use, become available for other developers to
|
||||
incorporate. Many developers of free software are heartened and
|
||||
encouraged by the resulting cooperation. However, in the case of
|
||||
software used on network servers, this result may fail to come about.
|
||||
The GNU General Public License permits making a modified version and
|
||||
letting the public access it on a server without ever releasing its
|
||||
source code to the public.
|
||||
|
||||
Developers that use the GNU GPL protect your rights with two steps:
|
||||
(1) assert copyright on the software, and (2) offer you this License
|
||||
giving you legal permission to copy, distribute and/or modify it.
|
||||
The GNU Affero General Public License is designed specifically to
|
||||
ensure that, in such cases, the modified source code becomes available
|
||||
to the community. It requires the operator of a network server to
|
||||
provide the source code of the modified version running there to the
|
||||
users of that server. Therefore, public use of a modified version, on
|
||||
a publicly accessible server, gives the public access to the source
|
||||
code of the modified version.
|
||||
|
||||
For the developers' and authors' protection, the GPL clearly explains
|
||||
that there is no warranty for this free software. For both users' and
|
||||
authors' sake, the GPL requires that modified versions be marked as
|
||||
changed, so that their problems will not be attributed erroneously to
|
||||
authors of previous versions.
|
||||
|
||||
Some devices are designed to deny users access to install or run
|
||||
modified versions of the software inside them, although the manufacturer
|
||||
can do so. This is fundamentally incompatible with the aim of
|
||||
protecting users' freedom to change the software. The systematic
|
||||
pattern of such abuse occurs in the area of products for individuals to
|
||||
use, which is precisely where it is most unacceptable. Therefore, we
|
||||
have designed this version of the GPL to prohibit the practice for those
|
||||
products. If such problems arise substantially in other domains, we
|
||||
stand ready to extend this provision to those domains in future versions
|
||||
of the GPL, as needed to protect the freedom of users.
|
||||
|
||||
Finally, every program is threatened constantly by software patents.
|
||||
States should not allow patents to restrict development and use of
|
||||
software on general-purpose computers, but in those that do, we wish to
|
||||
avoid the special danger that patents applied to a free program could
|
||||
make it effectively proprietary. To prevent this, the GPL assures that
|
||||
patents cannot be used to render the program non-free.
|
||||
An older license, called the Affero General Public License and
|
||||
published by Affero, was designed to accomplish similar goals. This is
|
||||
a different license, not a version of the Affero GPL, but Affero has
|
||||
released a new version of the Affero GPL which permits relicensing under
|
||||
this license.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
@@ -72,7 +60,7 @@ modification follow.
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU General Public License.
|
||||
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
@@ -549,35 +537,45 @@ to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Use with the GNU Affero General Public License.
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU Affero General Public License into a single
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the special requirements of the GNU Affero General Public License,
|
||||
section 13, concerning interaction through a network will apply to the
|
||||
combination as such.
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU General Public License from time to time. Such new versions will
|
||||
be similar in spirit to the present version, but may differ in detail to
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU General
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU General Public License, you may choose any version ever published
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU General Public License can be used, that proxy's
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
@@ -635,40 +633,29 @@ the "copyright" line and a pointer to where the full notice is found.
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
it under the terms of the GNU Affero General Public License as published
|
||||
by the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If the program does terminal interaction, make it output a short
|
||||
notice like this when it starts in an interactive mode:
|
||||
|
||||
<program> Copyright (C) <year> <name of author>
|
||||
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
||||
This is free software, and you are welcome to redistribute it
|
||||
under certain conditions; type `show c' for details.
|
||||
|
||||
The hypothetical commands `show w' and `show c' should show the appropriate
|
||||
parts of the General Public License. Of course, your program's commands
|
||||
might be different; for a GUI interface, you would use an "about box".
|
||||
If your software can interact with users remotely through a computer
|
||||
network, you should also make sure that it provides a way for users to
|
||||
get its source. For example, if your program is a web application, its
|
||||
interface could display a "Source" link that leads users to an archive
|
||||
of the code. There are many ways you could offer source, and different
|
||||
solutions will be better for different programs; see section 13 for the
|
||||
specific requirements.
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU GPL, see
|
||||
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
|
||||
The GNU General Public License does not permit incorporating your program
|
||||
into proprietary programs. If your program is a subroutine library, you
|
||||
may consider it more useful to permit linking proprietary applications with
|
||||
the library. If this is what you want to do, use the GNU Lesser General
|
||||
Public License instead of this License. But first, please read
|
||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
||||
|
||||
155
README.MD
155
README.MD
@@ -18,6 +18,8 @@
|
||||
|
||||
文档入口:[文档导航](docs/README.md) / [架构说明](docs/ARCHITECTURE.md) / [接口文档](API.md)
|
||||
|
||||
【感谢Linux.do社区及GitHub社区各位开发者对项目的支持与贡献】
|
||||
|
||||
> **重要免责声明**
|
||||
>
|
||||
> 本仓库仅供学习、研究、个人实验和内部验证使用,不提供任何形式的商业授权、适用性保证或结果保证。
|
||||
@@ -80,29 +82,19 @@ flowchart LR
|
||||
- **前端**:React 管理台(`webui/`),运行时托管静态构建产物
|
||||
- **部署**:本地运行、Docker、Vercel Serverless、Linux systemd
|
||||
|
||||
### 3.X 底层架构调整(相较旧版本)
|
||||
|
||||
- **统一路由内核**:所有协议入口统一汇聚到 `internal/server/router.go`,并在同一路由树中注册 OpenAI / Claude / Gemini / Admin / WebUI 路由,避免多入口行为漂移。
|
||||
- **统一执行链路**:Claude / Gemini 入口先经 `internal/translatorcliproxy` 做协议转换,再进入 `openai.ChatCompletions` 统一处理工具调用与流式语义,最后再转换回原协议响应。
|
||||
- **适配器分层更清晰**:`internal/adapter/{claude,gemini}` 负责入口/出口协议封装,`internal/adapter/openai` 负责核心执行,DeepSeek 侧调用只保留在 OpenAI 内核中。
|
||||
- **Tool Calling 双运行时对齐**:Go 侧(`internal/toolcall`)与 Vercel Node 侧(`internal/js/helpers/stream-tool-sieve`)保持一致的解析/防泄漏语义,覆盖 JSON / XML / invoke / text-kv 多风格输入。
|
||||
- **配置与运行时设置解耦**:静态配置(`config`)与运行时策略(`settings`)通过 Admin API 分离管理,支持热更新和密码轮换失效旧 JWT。
|
||||
- **流式能力升级**:`/v1/responses` 与 `/v1/chat/completions` 共享更一致的工具调用增量输出策略,降低不同 SDK 下的行为差异。
|
||||
- **可观测与可运维增强**:`/healthz`、`/readyz`、`/admin/version`、`/admin/dev/captures` 形成排障闭环,便于发布后验证。
|
||||
|
||||
## 核心能力
|
||||
|
||||
| 能力 | 说明 |
|
||||
| --- | --- |
|
||||
| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings` |
|
||||
| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings`、`POST /v1/files` |
|
||||
| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens`(及快捷路径 `/v1/messages`、`/messages`) |
|
||||
| Gemini 兼容 | `POST /v1beta/models/{model}:generateContent`、`POST /v1beta/models/{model}:streamGenerateContent`(及 `/v1/models/{model}:*` 路径) |
|
||||
| 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 |
|
||||
| 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 |
|
||||
| DeepSeek PoW | 纯 Go 高性能实现(DeepSeekHashV1),毫秒级响应 |
|
||||
| Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 |
|
||||
| Admin API | 配置管理、运行时设置热更新、账号测试 / 批量测试、会话清理、导入导出、Vercel 同步、版本检查 |
|
||||
| WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) |
|
||||
| Admin API | 配置管理、运行时设置热更新、代理管理、账号测试 / 批量测试、会话清理、导入导出、Vercel 同步、版本检查 |
|
||||
| WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式,支持查看服务器端对话记录) |
|
||||
| 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) |
|
||||
|
||||
## 平台兼容矩阵
|
||||
@@ -118,33 +110,42 @@ flowchart LR
|
||||
|
||||
## 模型支持
|
||||
|
||||
### OpenAI 接口
|
||||
### OpenAI 接口(`GET /v1/models`)
|
||||
|
||||
| 模型 | thinking | search |
|
||||
| --- | --- | --- |
|
||||
| `deepseek-chat` | ❌ | ❌ |
|
||||
| `deepseek-reasoner` | ✅ | ❌ |
|
||||
| `deepseek-chat-search` | ❌ | ✅ |
|
||||
| `deepseek-reasoner-search` | ✅ | ✅ |
|
||||
| 模型类型 | 模型 ID | thinking | search |
|
||||
| --- | --- | --- | --- |
|
||||
| default | `deepseek-chat` | ❌ | ❌ |
|
||||
| default | `deepseek-reasoner` | ✅ | ❌ |
|
||||
| default | `deepseek-chat-search` | ❌ | ✅ |
|
||||
| default | `deepseek-reasoner-search` | ✅ | ✅ |
|
||||
| expert | `deepseek-expert-chat` | ❌ | ❌ |
|
||||
| expert | `deepseek-expert-reasoner` | ✅ | ❌ |
|
||||
| expert | `deepseek-expert-chat-search` | ❌ | ✅ |
|
||||
| expert | `deepseek-expert-reasoner-search` | ✅ | ✅ |
|
||||
| vision | `deepseek-vision-chat` | ❌ | ❌ |
|
||||
| vision | `deepseek-vision-reasoner` | ✅ | ❌ |
|
||||
| vision | `deepseek-vision-chat-search` | ❌ | ✅ |
|
||||
| vision | `deepseek-vision-reasoner-search` | ✅ | ✅ |
|
||||
|
||||
### Claude 接口
|
||||
除原生模型外,也支持常见 alias 输入(如 `gpt-5`、`gpt-5-mini`、`gpt-5-codex`、`gpt-4.1`、`o3`、`claude-opus-4-6`、`claude-sonnet-4-5`、`gemini-2.5-pro`、`gemini-2.5-flash` 等),但 `/v1/models` 返回的是规范化后的 DeepSeek 原生模型 ID。
|
||||
|
||||
| 模型 | 默认映射 |
|
||||
### Claude 接口(`GET /anthropic/v1/models`)
|
||||
|
||||
| 当前常用模型 | 默认映射 |
|
||||
| --- | --- |
|
||||
| `claude-sonnet-4-5` | `deepseek-chat` |
|
||||
| `claude-haiku-4-5`(兼容 `claude-3-5-haiku-latest`) | `deepseek-chat` |
|
||||
| `claude-opus-4-6` | `deepseek-reasoner` |
|
||||
|
||||
可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。
|
||||
另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。
|
||||
|
||||
`/anthropic/v1/models` 除上述当前主别名外,还会返回 Claude 4.x snapshots,以及 3.x / 2.x / 1.x 历史模型 ID 与常见 alias,便于旧客户端直接兼容。
|
||||
|
||||
#### Claude Code 接入避坑(实测)
|
||||
|
||||
- `ANTHROPIC_BASE_URL` 推荐直接指向 DS2API 根地址(例如 `http://127.0.0.1:5001`),Claude Code 会请求 `/v1/messages?beta=true`。
|
||||
- `ANTHROPIC_API_KEY` 需要与 `config.json` 中 `keys` 一致;建议同时保留常规 key 与 `sk-ant-*` 形态 key,兼容不同客户端校验习惯。
|
||||
- 若系统设置了代理,建议对 DS2API 地址配置 `NO_PROXY=127.0.0.1,localhost,<你的主机IP>`,避免本地回环请求被代理拦截。
|
||||
- 如遇“工具调用输出成文本、未执行”问题,请升级到包含 Claude 工具调用多格式解析(JSON/XML/ANTML/invoke)的版本。
|
||||
- 如遇“工具调用输出成文本、未执行”问题,请优先检查模型输出是否为受支持的 XML/Markup 工具块(例如 `<tool_call>` / `<function_call>` / `<invoke>` / `tool_use`),而不是纯 JSON `tool_calls` 片段。
|
||||
|
||||
### Gemini 接口
|
||||
|
||||
@@ -152,6 +153,15 @@ Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 Deep
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 部署方式优先级建议
|
||||
|
||||
推荐按以下顺序选择部署方式:
|
||||
|
||||
1. **下载 Release 构建包运行**:最省事,产物已编译完成,最适合大多数用户。
|
||||
2. **Docker / GHCR 镜像部署**:适合需要容器化、编排或云环境部署。
|
||||
3. **Vercel 部署**:适合已有 Vercel 环境且接受其平台约束的场景。
|
||||
4. **本地源码运行 / 自行编译**:适合开发、调试或需要自行修改代码的场景。
|
||||
|
||||
### 通用第一步(所有部署方式)
|
||||
|
||||
把 `config.json` 作为唯一配置源(推荐做法):
|
||||
@@ -165,29 +175,21 @@ cp config.example.json config.json
|
||||
- 本地运行:直接读取 `config.json`
|
||||
- Docker / Vercel:由 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量,也可以直接写原始 JSON
|
||||
|
||||
### 方式一:本地运行
|
||||
WebUI 管理台里的“全量配置模板”也直接复用同一份 `config.example.json`,所以更新示例文件后,前端模板会自动保持一致。
|
||||
|
||||
**前置要求**:Go 1.26+,Node.js `20.19+` 或 `22.12+`(仅在需要构建 WebUI 时)
|
||||
### 方式一:下载 Release 构建包
|
||||
|
||||
每次发布 Release 时,GitHub Actions 会自动构建多平台二进制包:
|
||||
|
||||
```bash
|
||||
# 1. 克隆仓库
|
||||
git clone https://github.com/CJackHwang/ds2api.git
|
||||
cd ds2api
|
||||
|
||||
# 2. 配置
|
||||
# 下载对应平台的压缩包后
|
||||
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||
cd ds2api_<tag>_linux_amd64
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json,填入你的 DeepSeek 账号信息和 API key
|
||||
|
||||
# 3. 启动
|
||||
go run ./cmd/ds2api
|
||||
# 编辑 config.json
|
||||
./ds2api
|
||||
```
|
||||
|
||||
默认本地访问地址:`http://127.0.0.1:5001`
|
||||
|
||||
服务实际绑定:`0.0.0.0:5001`,因此同一局域网设备通常也可以通过你的内网 IP 访问。
|
||||
|
||||
> **WebUI 自动构建**:本地首次启动时,若 `static/admin` 不存在,会自动尝试执行 `npm ci`(仅在缺少依赖时)和 `npm run build -- --outDir static/admin --emptyOutDir`(需要本机有 Node.js)。你也可以手动构建:`./scripts/build-webui.sh`
|
||||
|
||||
### 方式二:Docker 运行
|
||||
|
||||
```bash
|
||||
@@ -241,35 +243,28 @@ base64 < config.json | tr -d '\n'
|
||||
|
||||
详细部署说明请参阅 [部署指南](docs/DEPLOY.md)。
|
||||
|
||||
### 方式四:下载 Release 构建包
|
||||
### 方式四:本地源码运行
|
||||
|
||||
每次发布 Release 时,GitHub Actions 会自动构建多平台二进制包:
|
||||
**前置要求**:Go 1.26+,Node.js `20.19+` 或 `22.12+`(仅在需要构建 WebUI 时)
|
||||
|
||||
```bash
|
||||
# 下载对应平台的压缩包后
|
||||
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||
cd ds2api_<tag>_linux_amd64
|
||||
# 1. 克隆仓库
|
||||
git clone https://github.com/CJackHwang/ds2api.git
|
||||
cd ds2api
|
||||
|
||||
# 2. 配置
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json
|
||||
./ds2api
|
||||
# 编辑 config.json,填入你的 DeepSeek 账号信息和 API key
|
||||
|
||||
# 3. 启动
|
||||
go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
### 方式五:OpenCode CLI 接入
|
||||
默认本地访问地址:`http://127.0.0.1:5001`
|
||||
|
||||
1. 复制示例配置:
|
||||
服务实际绑定:`0.0.0.0:5001`,因此同一局域网设备通常也可以通过你的内网 IP 访问。
|
||||
|
||||
```bash
|
||||
cp opencode.json.example opencode.json
|
||||
```
|
||||
|
||||
2. 编辑 `opencode.json`:
|
||||
- 将 `baseURL` 改为你的 DS2API 地址(例如 `https://your-domain.com/v1`)
|
||||
- 将 `apiKey` 改为你的 DS2API key(对应 `config.keys`)
|
||||
|
||||
3. 在项目目录启动 OpenCode CLI(按你的安装方式运行 `opencode`)。
|
||||
|
||||
> 建议优先使用 OpenAI 兼容路径(`/v1/*`),即示例里的 `@ai-sdk/openai-compatible` provider。
|
||||
> 若客户端支持 `wire_api`,可分别测试 `responses` 与 `chat`,DS2API 两条链路都兼容。
|
||||
> **WebUI 自动构建**:本地首次启动时,若 `static/admin` 不存在,会自动尝试执行 `npm ci`(仅在缺少依赖时)和 `npm run build -- --outDir static/admin --emptyOutDir`(需要本机有 Node.js)。你也可以手动构建:`./scripts/build-webui.sh`
|
||||
|
||||
## 配置说明
|
||||
|
||||
@@ -278,8 +273,17 @@ cp opencode.json.example opencode.json
|
||||
```json
|
||||
{
|
||||
"keys": ["your-api-key-1", "your-api-key-2"],
|
||||
"api_keys": [
|
||||
{
|
||||
"key": "your-api-key-1",
|
||||
"name": "主 Key",
|
||||
"remark": "生产流量"
|
||||
}
|
||||
],
|
||||
"accounts": [
|
||||
{
|
||||
"name": "账号 A",
|
||||
"remark": "主账号",
|
||||
"email": "user@example.com",
|
||||
"password": "your-password"
|
||||
},
|
||||
@@ -290,8 +294,12 @@ cp opencode.json.example opencode.json
|
||||
],
|
||||
"model_aliases": {
|
||||
"gpt-4o": "deepseek-chat",
|
||||
"gpt-5": "deepseek-chat",
|
||||
"gpt-5-mini": "deepseek-chat",
|
||||
"gpt-5-codex": "deepseek-reasoner",
|
||||
"o3": "deepseek-reasoner"
|
||||
"o3": "deepseek-reasoner",
|
||||
"claude-opus-4-6": "deepseek-reasoner",
|
||||
"gemini-2.5-flash": "deepseek-chat"
|
||||
},
|
||||
"compat": {
|
||||
"wide_input_strict_output": true,
|
||||
@@ -323,7 +331,8 @@ cp opencode.json.example opencode.json
|
||||
```
|
||||
|
||||
- `keys`:API 访问密钥列表,客户端通过 `Authorization: Bearer <key>` 鉴权
|
||||
- `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录
|
||||
- `api_keys`:推荐使用的新结构化密钥列表,支持 `key` + `name` + `remark`(`keys` 仍兼容)
|
||||
- `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录;可额外填写 `name` / `remark` 便于管理
|
||||
- `token`:配置文件中即使填写也会在加载时被清空(不会从 `config.json` 读取 token);实际 token 仅在运行时内存中维护并自动刷新
|
||||
- `model_aliases`:常见模型名(如 GPT/Codex/Claude)到 DeepSeek 模型的映射
|
||||
- `compat.wide_input_strict_output`:建议保持 `true`(当前实现默认宽进严出)
|
||||
@@ -338,6 +347,8 @@ cp opencode.json.example opencode.json
|
||||
|
||||
### 环境变量
|
||||
|
||||
> 建议:长期维护请优先以 `config.json`(或其 Base64)为单一配置源。环境变量仅保留部署必需项;`DS2API_CONFIG_JSON` 主要用于 Vercel/无持久盘场景,后续可能进一步收敛。
|
||||
|
||||
| 变量 | 用途 | 默认值 |
|
||||
| --- | --- | --- |
|
||||
| `PORT` | 服务端口 | `5001` |
|
||||
@@ -347,6 +358,7 @@ cp opencode.json.example opencode.json
|
||||
| `DS2API_JWT_EXPIRE_HOURS` | Admin JWT 过期小时数 | `24` |
|
||||
| `DS2API_CONFIG_PATH` | 配置文件路径 | `config.json` |
|
||||
| `DS2API_CONFIG_JSON` | 直接注入配置(JSON 或 Base64) | — |
|
||||
| `DS2API_CHAT_HISTORY_PATH` | 服务器端对话记录文件路径 | `data/chat_history.json` |
|
||||
| `DS2API_ENV_WRITEBACK` | 环境变量模式下自动写回配置文件并切换文件模式(`1/true/yes/on`) | 关闭 |
|
||||
| `DS2API_STATIC_ADMIN_DIR` | 管理台静态文件目录 | `static/admin` |
|
||||
| `DS2API_AUTO_BUILD_WEBUI` | 启动时自动构建 WebUI | 本地开启,Vercel 关闭 |
|
||||
@@ -365,6 +377,15 @@ cp opencode.json.example opencode.json
|
||||
|
||||
> 提示:当检测到 `DS2API_CONFIG_JSON` 时,管理台会显示当前模式风险与自动持久化状态(含 `DS2API_CONFIG_PATH` 路径与模式切换说明)。
|
||||
|
||||
#### 必填 / 可选(按部署方式)
|
||||
|
||||
- **所有部署都必填**:`DS2API_ADMIN_KEY`
|
||||
- **配置来源二选一(推荐前者)**:
|
||||
- `config.json` 文件(推荐,持久化更直观)
|
||||
- `DS2API_CONFIG_JSON`(可选,适合 Vercel;支持 JSON 或 Base64)
|
||||
- **仅在环境变量配置模式建议开启**:`DS2API_ENV_WRITEBACK=1`(避免管理台改动重启后丢失)
|
||||
- 其余环境变量均为可选调优项。
|
||||
|
||||
## 鉴权模式
|
||||
|
||||
调用业务接口(`/v1/*`、`/anthropic/*`、Gemini 路由)时支持两种模式:
|
||||
@@ -395,7 +416,7 @@ Gemini 路由还可以使用 `x-goog-api-key`,或在没有认证头时使用 `
|
||||
当请求中带 `tools` 时,DS2API 会做防泄漏处理与结构化转译:
|
||||
|
||||
1. 只在**非代码块上下文**启用执行型 toolcall 识别(代码块示例默认不触发)
|
||||
2. 解析层以 XML/Markup 为最高优先级,同时兼容 JSON / ANTML / invoke / text-kv,并统一归一到内部工具调用结构
|
||||
2. 解析层当前以 XML/Markup 家族为准(`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml 变体);纯 JSON `tool_calls` 片段默认不作为可执行调用解析
|
||||
3. `responses` 流式严格使用官方 item 生命周期事件(`response.output_item.*`、`response.content_part.*`、`response.function_call_arguments.*`)
|
||||
4. `responses` 支持并执行 `tool_choice`(`auto`/`none`/`required`/强制函数);`required` 违规时非流式返回 `422`,流式返回 `response.failed`
|
||||
5. 客户端请求哪种协议,就按该协议返回工具调用(OpenAI/Claude/Gemini 各自原生结构);模型侧优先约束输出规范 XML,再由兼容层转译
|
||||
@@ -494,7 +515,7 @@ go test -v -run 'TestParseToolCalls|TestRepair' ./internal/toolcall/
|
||||
- **触发条件**:仅在 GitHub Release `published` 时触发(普通 push 不会触发)
|
||||
- **构建产物**:多平台二进制包(`linux/amd64`、`linux/arm64`、`darwin/amd64`、`darwin/arm64`、`windows/amd64`)+ `sha256sums.txt`
|
||||
- **容器镜像发布**:仅推送到 GHCR(`ghcr.io/cjackhwang/ds2api`)
|
||||
- **每个压缩包包含**:`ds2api` 可执行文件、`static/admin`、WASM 文件(同时支持内置 fallback)、配置示例、README、LICENSE
|
||||
- **每个压缩包包含**:`ds2api` 可执行文件、`static/admin`、WASM 文件(同时支持内置 fallback)、`config.example.json` 配置示例、README、LICENSE
|
||||
|
||||
## 免责声明
|
||||
|
||||
|
||||
150
README.en.md
150
README.en.md
@@ -80,29 +80,19 @@ For the full module-by-module architecture and directory responsibilities, see [
|
||||
- **Frontend**: React admin panel (`webui/`), served as static build at runtime
|
||||
- **Deployment**: local run, Docker, Vercel serverless, Linux systemd
|
||||
|
||||
### 3.X Architecture Changes (vs older releases)
|
||||
|
||||
- **Unified routing core**: all protocol entries are now centralized through `internal/server/router.go`, with OpenAI / Claude / Gemini / Admin / WebUI routes registered in one tree to avoid multi-entry drift.
|
||||
- **Unified execution chain**: Claude/Gemini entries are translated by `internal/translatorcliproxy`, then executed through `openai.ChatCompletions` for shared tool-calling and stream semantics, then translated back to the client protocol.
|
||||
- **Cleaner adapter boundaries**: `internal/adapter/{claude,gemini}` handles protocol wrappers, while `internal/adapter/openai` remains the execution core; upstream DeepSeek calls are retained only in the OpenAI core.
|
||||
- **Tool-calling parity across runtimes**: Go (`internal/toolcall`) and Vercel Node (`internal/js/helpers/stream-tool-sieve`) follow aligned parsing/anti-leak semantics across JSON / XML / invoke / text-kv inputs.
|
||||
- **Config/runtime separation**: static config (`config`) and runtime policy (`settings`) are managed independently via Admin APIs, enabling hot updates and password rotation with JWT invalidation.
|
||||
- **Streaming behavior upgrade**: `/v1/responses` and `/v1/chat/completions` now share a more consistent incremental tool-call emission strategy across SDK ecosystems.
|
||||
- **Improved operability**: `/healthz`, `/readyz`, `/admin/version`, and `/admin/dev/captures` form a tighter post-deploy diagnostics loop.
|
||||
|
||||
## Key Capabilities
|
||||
|
||||
| Capability | Details |
|
||||
| --- | --- |
|
||||
| OpenAI compatible | `GET /v1/models`, `GET /v1/models/{id}`, `POST /v1/chat/completions`, `POST /v1/responses`, `GET /v1/responses/{response_id}`, `POST /v1/embeddings` |
|
||||
| OpenAI compatible | `GET /v1/models`, `GET /v1/models/{id}`, `POST /v1/chat/completions`, `POST /v1/responses`, `GET /v1/responses/{response_id}`, `POST /v1/embeddings`, `POST /v1/files` |
|
||||
| Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` (plus shortcut paths `/v1/messages`, `/messages`) |
|
||||
| Gemini compatible | `POST /v1beta/models/{model}:generateContent`, `POST /v1beta/models/{model}:streamGenerateContent` (plus `/v1/models/{model}:*` paths) |
|
||||
| Multi-account rotation | Auto token refresh, email/mobile dual login |
|
||||
| Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency |
|
||||
| DeepSeek PoW | Pure Go high-performance solver (DeepSeekHashV1), ms-level response |
|
||||
| Tool Calling | Anti-leak handling: non-code-block feature match, early `delta.tool_calls`, structured incremental output |
|
||||
| Admin API | Config management, runtime settings hot-reload, account testing/batch test, session cleanup, import/export, Vercel sync, version check |
|
||||
| WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) |
|
||||
| Admin API | Config management, runtime settings hot-reload, proxy management, account testing/batch test, session cleanup, import/export, Vercel sync, version check |
|
||||
| WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode, with server-side conversation history) |
|
||||
| Health Probes | `GET /healthz` (liveness), `GET /readyz` (readiness) |
|
||||
|
||||
## Platform Compatibility Matrix
|
||||
@@ -118,33 +108,42 @@ For the full module-by-module architecture and directory responsibilities, see [
|
||||
|
||||
## Model Support
|
||||
|
||||
### OpenAI Endpoint
|
||||
### OpenAI Endpoint (`GET /v1/models`)
|
||||
|
||||
| Model | thinking | search |
|
||||
| --- | --- | --- |
|
||||
| `deepseek-chat` | ❌ | ❌ |
|
||||
| `deepseek-reasoner` | ✅ | ❌ |
|
||||
| `deepseek-chat-search` | ❌ | ✅ |
|
||||
| `deepseek-reasoner-search` | ✅ | ✅ |
|
||||
| Family | Model ID | thinking | search |
|
||||
| --- | --- | --- | --- |
|
||||
| default | `deepseek-chat` | ❌ | ❌ |
|
||||
| default | `deepseek-reasoner` | ✅ | ❌ |
|
||||
| default | `deepseek-chat-search` | ❌ | ✅ |
|
||||
| default | `deepseek-reasoner-search` | ✅ | ✅ |
|
||||
| expert | `deepseek-expert-chat` | ❌ | ❌ |
|
||||
| expert | `deepseek-expert-reasoner` | ✅ | ❌ |
|
||||
| expert | `deepseek-expert-chat-search` | ❌ | ✅ |
|
||||
| expert | `deepseek-expert-reasoner-search` | ✅ | ✅ |
|
||||
| vision | `deepseek-vision-chat` | ❌ | ❌ |
|
||||
| vision | `deepseek-vision-reasoner` | ✅ | ❌ |
|
||||
| vision | `deepseek-vision-chat-search` | ❌ | ✅ |
|
||||
| vision | `deepseek-vision-reasoner-search` | ✅ | ✅ |
|
||||
|
||||
### Claude Endpoint
|
||||
Besides native IDs, DS2API also accepts common aliases as input (for example `gpt-5`, `gpt-5-mini`, `gpt-5-codex`, `gpt-4.1`, `o3`, `claude-opus-4-6`, `claude-sonnet-4-5`, `gemini-2.5-pro`, `gemini-2.5-flash`), but `/v1/models` returns normalized DeepSeek native model IDs.
|
||||
|
||||
| Model | Default Mapping |
|
||||
### Claude Endpoint (`GET /anthropic/v1/models`)
|
||||
|
||||
| Current common model | Default Mapping |
|
||||
| --- | --- |
|
||||
| `claude-sonnet-4-5` | `deepseek-chat` |
|
||||
| `claude-haiku-4-5` (compatible with `claude-3-5-haiku-latest`) | `deepseek-chat` |
|
||||
| `claude-opus-4-6` | `deepseek-reasoner` |
|
||||
|
||||
Override mapping via `claude_mapping` or `claude_model_mapping` in config.
|
||||
In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility.
|
||||
|
||||
Besides the current primary aliases above, `/anthropic/v1/models` also returns Claude 4.x snapshots plus historical 3.x / 2.x / 1.x IDs and common aliases for legacy client compatibility.
|
||||
|
||||
#### Claude Code integration pitfalls (validated)
|
||||
|
||||
- Set `ANTHROPIC_BASE_URL` to the DS2API root URL (for example `http://127.0.0.1:5001`). Claude Code sends requests to `/v1/messages?beta=true`.
|
||||
- `ANTHROPIC_API_KEY` must match an entry in `keys` from `config.json`. Keeping both a regular key and an `sk-ant-*` style key improves client compatibility.
|
||||
- If your environment has proxy variables, set `NO_PROXY=127.0.0.1,localhost,<your_host_ip>` for DS2API to avoid proxy interception of local traffic.
|
||||
- If tool calls are rendered as plain text and not executed, upgrade to a build that includes multi-format Claude tool-call parsing (JSON/XML/ANTML/invoke).
|
||||
- If tool calls are rendered as plain text and not executed, first verify the model output uses supported XML/Markup tool blocks (`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use`) rather than standalone JSON `tool_calls`.
|
||||
|
||||
### Gemini Endpoint
|
||||
|
||||
@@ -152,6 +151,15 @@ The Gemini adapter maps model names to DeepSeek native models via `model_aliases
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Recommended deployment priority
|
||||
|
||||
Recommended order when choosing a deployment method:
|
||||
|
||||
1. **Download and run release binaries**: the easiest path for most users because the artifacts are already built.
|
||||
2. **Docker / GHCR image deployment**: suitable for containerized, orchestrated, or cloud environments.
|
||||
3. **Vercel deployment**: suitable if you already use Vercel and accept its platform constraints.
|
||||
4. **Run from source / build locally**: suitable for development, debugging, or when you need to modify the code yourself.
|
||||
|
||||
### Universal First Step (all deployment modes)
|
||||
|
||||
Use `config.json` as the single source of truth (recommended):
|
||||
@@ -165,47 +173,39 @@ Recommended per deployment mode:
|
||||
- Local run: read `config.json` directly
|
||||
- Docker / Vercel: generate Base64 from `config.json` and inject as `DS2API_CONFIG_JSON`, or paste raw JSON directly
|
||||
|
||||
### Option 1: Local Run
|
||||
The WebUI admin panel’s “Full configuration template” is loaded from the same `config.example.json`, so updating that file keeps the frontend template in sync.
|
||||
|
||||
**Prerequisites**: Go 1.26+, Node.js `20.19+` or `22.12+` (only if building WebUI locally)
|
||||
### Option 1: Download Release Binaries
|
||||
|
||||
GitHub Actions automatically builds multi-platform archives on each Release:
|
||||
|
||||
```bash
|
||||
# 1. Clone
|
||||
git clone https://github.com/CJackHwang/ds2api.git
|
||||
cd ds2api
|
||||
|
||||
# 2. Configure
|
||||
# After downloading the archive for your platform
|
||||
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||
cd ds2api_<tag>_linux_amd64
|
||||
cp config.example.json config.json
|
||||
# Edit config.json with your DeepSeek account info and API keys
|
||||
|
||||
# 3. Start
|
||||
go run ./cmd/ds2api
|
||||
# Edit config.json
|
||||
./ds2api
|
||||
```
|
||||
|
||||
Default local URL: `http://127.0.0.1:5001`
|
||||
|
||||
The server actually binds to `0.0.0.0:5001`, so devices on the same LAN can usually reach it through your private IP as well.
|
||||
|
||||
> **WebUI auto-build**: On first local startup, if `static/admin` is missing, DS2API will auto-run `npm ci` (only when dependencies are missing) and `npm run build -- --outDir static/admin --emptyOutDir` (requires Node.js). You can also build manually: `./scripts/build-webui.sh`
|
||||
|
||||
### Option 2: Docker
|
||||
### Option 2: Docker / GHCR
|
||||
|
||||
```bash
|
||||
# 1. Prepare env file and config file
|
||||
# Pull prebuilt image
|
||||
docker pull ghcr.io/cjackhwang/ds2api:latest
|
||||
|
||||
# Or run a pinned version
|
||||
# docker pull ghcr.io/cjackhwang/ds2api:v3.0.0
|
||||
|
||||
# Prepare env file and config file
|
||||
cp .env.example .env
|
||||
cp config.example.json config.json
|
||||
|
||||
# 2. Edit .env (at least set DS2API_ADMIN_KEY; optionally set DS2API_HOST_PORT to change the host port)
|
||||
# DS2API_ADMIN_KEY=replace-with-a-strong-secret
|
||||
|
||||
# 3. Start
|
||||
# Start with compose
|
||||
docker-compose up -d
|
||||
|
||||
# 4. View logs
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
The default `docker-compose.yml` maps host port `6011` to container port `5001`. If you want `5001` exposed directly, set `DS2API_HOST_PORT=5001` (or adjust the `ports` mapping).
|
||||
The default `docker-compose.yml` uses `ghcr.io/cjackhwang/ds2api:latest` and maps host port `6011` to container port `5001`. If you want `5001` exposed directly, set `DS2API_HOST_PORT=5001` (or adjust the `ports` mapping).
|
||||
|
||||
Rebuild after updates: `docker-compose up -d --build`
|
||||
|
||||
@@ -241,35 +241,28 @@ base64 < config.json | tr -d '\n'
|
||||
|
||||
For detailed deployment instructions, see the [Deployment Guide](docs/DEPLOY.en.md).
|
||||
|
||||
### Option 4: Download Release Binaries
|
||||
### Option 4: Local Run
|
||||
|
||||
GitHub Actions automatically builds multi-platform archives on each Release:
|
||||
**Prerequisites**: Go 1.26+, Node.js `20.19+` or `22.12+` (only if building WebUI locally)
|
||||
|
||||
```bash
|
||||
# After downloading the archive for your platform
|
||||
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||
cd ds2api_<tag>_linux_amd64
|
||||
# 1. Clone
|
||||
git clone https://github.com/CJackHwang/ds2api.git
|
||||
cd ds2api
|
||||
|
||||
# 2. Configure
|
||||
cp config.example.json config.json
|
||||
# Edit config.json
|
||||
./ds2api
|
||||
# Edit config.json with your DeepSeek account info and API keys
|
||||
|
||||
# 3. Start
|
||||
go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
### Option 5: OpenCode CLI
|
||||
Default local URL: `http://127.0.0.1:5001`
|
||||
|
||||
1. Copy the example config:
|
||||
The server actually binds to `0.0.0.0:5001`, so devices on the same LAN can usually reach it through your private IP as well.
|
||||
|
||||
```bash
|
||||
cp opencode.json.example opencode.json
|
||||
```
|
||||
|
||||
2. Edit `opencode.json`:
|
||||
- Set `baseURL` to your DS2API endpoint (for example, `https://your-domain.com/v1`)
|
||||
- Set `apiKey` to your DS2API key (from `config.keys`)
|
||||
|
||||
3. Start OpenCode CLI in the project directory (run `opencode` using your installed method).
|
||||
|
||||
> Recommended: use the OpenAI-compatible path (`/v1/*`) via `@ai-sdk/openai-compatible` as shown in the example.
|
||||
> If your client supports `wire_api`, test both `responses` and `chat`; DS2API supports both paths.
|
||||
> **WebUI auto-build**: On first local startup, if `static/admin` is missing, DS2API will auto-run `npm ci` (only when dependencies are missing) and `npm run build -- --outDir static/admin --emptyOutDir` (requires Node.js). You can also build manually: `./scripts/build-webui.sh`
|
||||
|
||||
## Configuration
|
||||
|
||||
@@ -290,8 +283,12 @@ cp opencode.json.example opencode.json
|
||||
],
|
||||
"model_aliases": {
|
||||
"gpt-4o": "deepseek-chat",
|
||||
"gpt-5": "deepseek-chat",
|
||||
"gpt-5-mini": "deepseek-chat",
|
||||
"gpt-5-codex": "deepseek-reasoner",
|
||||
"o3": "deepseek-reasoner"
|
||||
"o3": "deepseek-reasoner",
|
||||
"claude-opus-4-6": "deepseek-reasoner",
|
||||
"gemini-2.5-flash": "deepseek-chat"
|
||||
},
|
||||
"compat": {
|
||||
"wide_input_strict_output": true,
|
||||
@@ -347,6 +344,7 @@ cp opencode.json.example opencode.json
|
||||
| `DS2API_JWT_EXPIRE_HOURS` | Admin JWT TTL in hours | `24` |
|
||||
| `DS2API_CONFIG_PATH` | Config file path | `config.json` |
|
||||
| `DS2API_CONFIG_JSON` | Inline config (JSON or Base64) | — |
|
||||
| `DS2API_CHAT_HISTORY_PATH` | Server-side conversation history file path | `data/chat_history.json` |
|
||||
| `DS2API_ENV_WRITEBACK` | Auto-write env-backed config to file and transition to file mode (`1/true/yes/on`) | Disabled |
|
||||
| `DS2API_STATIC_ADMIN_DIR` | Admin static assets dir | `static/admin` |
|
||||
| `DS2API_AUTO_BUILD_WEBUI` | Auto-build WebUI on startup | Enabled locally, disabled on Vercel |
|
||||
@@ -395,7 +393,7 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
|
||||
When `tools` is present in the request, DS2API performs anti-leak handling:
|
||||
|
||||
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
|
||||
2. The parser prioritizes XML/Markup, while also accepting JSON / ANTML / invoke / text-kv, and normalizes everything into the internal tool-call structure
|
||||
2. The parser currently targets XML/Markup-family tool syntax (`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml variants); standalone JSON `tool_calls` payloads are not treated as executable calls by default
|
||||
3. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`)
|
||||
4. `responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream
|
||||
5. The output protocol follows the client request (OpenAI / Claude / Gemini native shapes); model-side prompting can prefer XML, and the compatibility layer handles the protocol-specific translation
|
||||
@@ -475,7 +473,7 @@ Workflow: `.github/workflows/release-artifacts.yml`
|
||||
- **Trigger**: only on GitHub Release `published` (normal pushes do not trigger builds)
|
||||
- **Outputs**: multi-platform archives (`linux/amd64`, `linux/arm64`, `darwin/amd64`, `darwin/arm64`, `windows/amd64`) + `sha256sums.txt`
|
||||
- **Container publishing**: GHCR only (`ghcr.io/cjackhwang/ds2api`)
|
||||
- **Each archive includes**: `ds2api` executable, `static/admin`, WASM file (with embedded fallback support), config template, README, LICENSE
|
||||
- **Each archive includes**: `ds2api` executable, `static/admin`, WASM file (with embedded fallback support), `config.example.json`-based config template, README, LICENSE
|
||||
|
||||
## Disclaimer
|
||||
|
||||
|
||||
@@ -5,14 +5,29 @@
|
||||
"your-api-key-1",
|
||||
"your-api-key-2"
|
||||
],
|
||||
"api_keys": [
|
||||
{
|
||||
"key": "your-api-key-1",
|
||||
"name": "主 API Key",
|
||||
"remark": "给 OpenAI 客户端使用"
|
||||
},
|
||||
{
|
||||
"key": "your-api-key-2",
|
||||
"name": "备用 API Key",
|
||||
"remark": "压测或临时调试"
|
||||
}
|
||||
],
|
||||
"accounts": [
|
||||
{
|
||||
"_comment": "邮箱登录方式",
|
||||
"name": "主账号",
|
||||
"remark": "优先用于生产流量",
|
||||
"email": "example1@example.com",
|
||||
"password": "your-password-1"
|
||||
},
|
||||
{
|
||||
"_comment": "邮箱登录方式 - 账号2",
|
||||
"name": "备用账号",
|
||||
"email": "example2@example.com",
|
||||
"password": "your-password-2"
|
||||
},
|
||||
@@ -34,6 +49,10 @@
|
||||
"responses": {
|
||||
"store_ttl_seconds": 900
|
||||
},
|
||||
"history_split": {
|
||||
"enabled": true,
|
||||
"trigger_after_turns": 1
|
||||
},
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
|
||||
@@ -116,7 +116,7 @@ flowchart LR
|
||||
- `internal/translatorcliproxy`: structure translation between Claude/Gemini and OpenAI.
|
||||
- `internal/deepseek`: upstream request/session/PoW/SSE handling.
|
||||
- `internal/stream` + `internal/sse`: stream parsing and incremental assembly.
|
||||
- `internal/toolcall`: JSON/XML/invoke/text-kv tool-call parsing + anti-leak sieve.
|
||||
- `internal/toolcall`: XML/Markup-family tool-call parsing + anti-leak sieve (`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml variants).
|
||||
- `internal/admin`: config/accounts/vercel sync/version/dev-capture endpoints.
|
||||
- `internal/config`: config loading/validation + runtime settings hot-reload.
|
||||
- `internal/account`: managed account pool, inflight slots, waiting queue.
|
||||
|
||||
@@ -116,7 +116,7 @@ flowchart LR
|
||||
- `internal/translatorcliproxy`:Claude/Gemini 与 OpenAI 结构互转。
|
||||
- `internal/deepseek`:上游请求、会话、PoW、SSE 消费。
|
||||
- `internal/stream` + `internal/sse`:流式解析与增量处理。
|
||||
- `internal/toolcall`:JSON/XML/invoke/text-kv 工具调用解析及防泄漏筛分。
|
||||
- `internal/toolcall`:以 XML/Markup 家族为核心的工具调用解析与防泄漏筛分(`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml 变体)。
|
||||
- `internal/admin`:配置管理、账号管理、Vercel 同步、版本检查、开发抓包。
|
||||
- `internal/config`:配置加载、校验、运行时 settings 热更新。
|
||||
- `internal/account`:托管账号池、并发槽位、等待队列。
|
||||
|
||||
@@ -10,11 +10,12 @@ Doc map: [Index](./README.md) | [Architecture](./ARCHITECTURE.en.md) | [API](../
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Recommended deployment priority](#recommended-deployment-priority)
|
||||
- [Prerequisites](#0-prerequisites)
|
||||
- [1. Local Run](#1-local-run)
|
||||
- [2. Docker Deployment](#2-docker-deployment)
|
||||
- [1. Download Release Binaries](#1-download-release-binaries)
|
||||
- [2. Docker / GHCR Deployment](#2-docker--ghcr-deployment)
|
||||
- [3. Vercel Deployment](#3-vercel-deployment)
|
||||
- [4. Download Release Binaries](#4-download-release-binaries)
|
||||
- [4. Local Run from Source](#4-local-run-from-source)
|
||||
- [5. Reverse Proxy (Nginx)](#5-reverse-proxy-nginx)
|
||||
- [6. Linux systemd Service](#6-linux-systemd-service)
|
||||
- [7. Post-Deploy Checks](#7-post-deploy-checks)
|
||||
@@ -22,6 +23,17 @@ Doc map: [Index](./README.md) | [Architecture](./ARCHITECTURE.en.md) | [API](../
|
||||
|
||||
---
|
||||
|
||||
## Recommended deployment priority
|
||||
|
||||
Recommended order when choosing a deployment method:
|
||||
|
||||
1. **Download and run release binaries**: the easiest path for most users because the artifacts are already built.
|
||||
2. **Docker / GHCR image deployment**: suitable for containerized, orchestrated, or cloud environments.
|
||||
3. **Vercel deployment**: suitable if you already use Vercel and accept its platform constraints.
|
||||
4. **Run from source / build locally**: suitable for development, debugging, or when you need to modify the code yourself.
|
||||
|
||||
---
|
||||
|
||||
## 0. Prerequisites
|
||||
|
||||
| Dependency | Minimum Version | Notes |
|
||||
@@ -48,70 +60,59 @@ Use `config.json` as the single source of truth:
|
||||
|
||||
---
|
||||
|
||||
## 1. Local Run
|
||||
## 1. Download Release Binaries
|
||||
|
||||
### 1.1 Basic Steps
|
||||
Built-in GitHub Actions workflow: `.github/workflows/release-artifacts.yml`
|
||||
|
||||
- **Trigger**: only on Release `published` (no build on normal push)
|
||||
- **Outputs**: multi-platform binary archives + `sha256sums.txt`
|
||||
- **Container publishing**: GHCR only (`ghcr.io/cjackhwang/ds2api`)
|
||||
|
||||
| Platform | Architecture | Format |
|
||||
| --- | --- | --- |
|
||||
| Linux | amd64, arm64 | `.tar.gz` |
|
||||
| macOS | amd64, arm64 | `.tar.gz` |
|
||||
| Windows | amd64 | `.zip` |
|
||||
|
||||
Each archive includes:
|
||||
|
||||
- `ds2api` executable (`ds2api.exe` on Windows)
|
||||
- `static/admin/` (built WebUI assets)
|
||||
- `config.example.json`, `.env.example`
|
||||
- `README.MD`, `README.en.md`, `LICENSE`
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
# Clone
|
||||
git clone https://github.com/CJackHwang/ds2api.git
|
||||
cd ds2api
|
||||
# 1. Download the archive for your platform
|
||||
# 2. Extract
|
||||
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||
cd ds2api_<tag>_linux_amd64
|
||||
|
||||
# Copy and edit config
|
||||
# 3. Configure
|
||||
cp config.example.json config.json
|
||||
# Open config.json and fill in:
|
||||
# - keys: your API access keys
|
||||
# - accounts: DeepSeek accounts (email or mobile + password)
|
||||
# Edit config.json
|
||||
|
||||
# Start
|
||||
go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
Default local access URL: `http://127.0.0.1:5001`; the server actually binds to `0.0.0.0:5001` (override with `PORT`).
|
||||
|
||||
### 1.2 WebUI Build
|
||||
|
||||
On first local startup, if `static/admin/` is missing, DS2API will automatically attempt to build the WebUI (requires Node.js/npm; when dependencies are missing it runs `npm ci` first, then `npm run build -- --outDir static/admin --emptyOutDir`).
|
||||
|
||||
Manual build:
|
||||
|
||||
```bash
|
||||
./scripts/build-webui.sh
|
||||
```
|
||||
|
||||
Or step by step:
|
||||
|
||||
```bash
|
||||
cd webui
|
||||
npm install
|
||||
npm run build
|
||||
# Output goes to static/admin/
|
||||
```
|
||||
|
||||
Control auto-build via environment variable:
|
||||
|
||||
```bash
|
||||
# Disable auto-build
|
||||
DS2API_AUTO_BUILD_WEBUI=false go run ./cmd/ds2api
|
||||
|
||||
# Force enable auto-build
|
||||
DS2API_AUTO_BUILD_WEBUI=true go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
### 1.3 Compile to Binary
|
||||
|
||||
```bash
|
||||
go build -o ds2api ./cmd/ds2api
|
||||
# 4. Start
|
||||
./ds2api
|
||||
```
|
||||
|
||||
### Maintainer Release Flow
|
||||
|
||||
1. Create and publish a GitHub Release (with tag, for example `vX.Y.Z`)
|
||||
2. Wait for the `Release Artifacts` workflow to complete
|
||||
3. Download the matching archive from Release Assets
|
||||
|
||||
---
|
||||
|
||||
## 2. Docker Deployment
|
||||
## 2. Docker / GHCR Deployment
|
||||
|
||||
### 2.1 Basic Steps
|
||||
|
||||
```bash
|
||||
# Pull prebuilt image
|
||||
docker pull ghcr.io/cjackhwang/ds2api:latest
|
||||
|
||||
# Copy env template and config file
|
||||
cp .env.example .env
|
||||
cp config.example.json config.json
|
||||
@@ -128,7 +129,13 @@ docker-compose up -d
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
The default `docker-compose.yml` maps host port `6011` to container port `5001`. If you want `5001` exposed directly, set `DS2API_HOST_PORT=5001` (or adjust the `ports` mapping).
|
||||
The default `docker-compose.yml` directly uses `ghcr.io/cjackhwang/ds2api:latest` and maps host port `6011` to container port `5001`. If you want `5001` exposed directly, set `DS2API_HOST_PORT=5001` (or adjust the `ports` mapping).
|
||||
|
||||
If you want a pinned version instead of `latest`, you can also pull a specific tag directly:
|
||||
|
||||
```bash
|
||||
docker pull ghcr.io/cjackhwang/ds2api:v3.0.0
|
||||
```
|
||||
|
||||
### 2.2 Update
|
||||
|
||||
@@ -350,57 +357,61 @@ If API responses return Vercel HTML `Authentication Required`:
|
||||
|
||||
---
|
||||
|
||||
## 4. Download Release Binaries
|
||||
## 4. Local Run from Source
|
||||
|
||||
Built-in GitHub Actions workflow: `.github/workflows/release-artifacts.yml`
|
||||
|
||||
- **Trigger**: only on Release `published` (no build on normal push)
|
||||
- **Outputs**: multi-platform binary archives + `sha256sums.txt`
|
||||
- **Container publishing**: GHCR only (`ghcr.io/cjackhwang/ds2api`)
|
||||
|
||||
| Platform | Architecture | Format |
|
||||
| --- | --- | --- |
|
||||
| Linux | amd64, arm64 | `.tar.gz` |
|
||||
| macOS | amd64, arm64 | `.tar.gz` |
|
||||
| Windows | amd64 | `.zip` |
|
||||
|
||||
Each archive includes:
|
||||
|
||||
- `ds2api` executable (`ds2api.exe` on Windows)
|
||||
- `static/admin/` (built WebUI assets)
|
||||
- `config.example.json`, `.env.example`
|
||||
- `README.MD`, `README.en.md`, `LICENSE`
|
||||
|
||||
### Usage
|
||||
### 4.1 Basic Steps
|
||||
|
||||
```bash
|
||||
# 1. Download the archive for your platform
|
||||
# 2. Extract
|
||||
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||
cd ds2api_<tag>_linux_amd64
|
||||
# Clone
|
||||
git clone https://github.com/CJackHwang/ds2api.git
|
||||
cd ds2api
|
||||
|
||||
# 3. Configure
|
||||
# Copy and edit config
|
||||
cp config.example.json config.json
|
||||
# Edit config.json
|
||||
# Open config.json and fill in:
|
||||
# - keys: your API access keys
|
||||
# - accounts: DeepSeek accounts (email or mobile + password)
|
||||
|
||||
# 4. Start
|
||||
./ds2api
|
||||
# Start
|
||||
go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
### Maintainer Release Flow
|
||||
Default local access URL: `http://127.0.0.1:5001`; the server actually binds to `0.0.0.0:5001` (override with `PORT`).
|
||||
|
||||
1. Create and publish a GitHub Release (with tag, for example `vX.Y.Z`)
|
||||
2. Wait for the `Release Artifacts` workflow to complete
|
||||
3. Download the matching archive from Release Assets
|
||||
### 4.2 WebUI Build
|
||||
|
||||
### Pull from GHCR (Optional)
|
||||
On first local startup, if `static/admin/` is missing, DS2API will automatically attempt to build the WebUI (requires Node.js/npm; when dependencies are missing it runs `npm ci` first, then `npm run build -- --outDir static/admin --emptyOutDir`).
|
||||
|
||||
Manual build:
|
||||
|
||||
```bash
|
||||
# latest
|
||||
docker pull ghcr.io/cjackhwang/ds2api:latest
|
||||
./scripts/build-webui.sh
|
||||
```
|
||||
|
||||
# specific version (example)
|
||||
docker pull ghcr.io/cjackhwang/ds2api:v3.0.0
|
||||
Or step by step:
|
||||
|
||||
```bash
|
||||
cd webui
|
||||
npm install
|
||||
npm run build
|
||||
# Output goes to static/admin/
|
||||
```
|
||||
|
||||
Control auto-build via environment variable:
|
||||
|
||||
```bash
|
||||
# Disable auto-build
|
||||
DS2API_AUTO_BUILD_WEBUI=false go run ./cmd/ds2api
|
||||
|
||||
# Force enable auto-build
|
||||
DS2API_AUTO_BUILD_WEBUI=true go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
### 4.3 Compile to Binary
|
||||
|
||||
```bash
|
||||
go build -o ds2api ./cmd/ds2api
|
||||
./ds2api
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
205
docs/DEPLOY.md
205
docs/DEPLOY.md
@@ -10,11 +10,12 @@
|
||||
|
||||
## 目录
|
||||
|
||||
- [部署方式优先级建议](#部署方式优先级建议)
|
||||
- [前置要求](#0-前置要求)
|
||||
- [一、本地运行](#一本地运行)
|
||||
- [二、Docker 部署](#二docker-部署)
|
||||
- [一、下载 Release 构建包](#一下载-release-构建包)
|
||||
- [二、Docker / GHCR 部署](#二docker--ghcr-部署)
|
||||
- [三、Vercel 部署](#三vercel-部署)
|
||||
- [四、下载 Release 构建包](#四下载-release-构建包)
|
||||
- [四、本地源码运行](#四本地源码运行)
|
||||
- [五、反向代理(Nginx)](#五反向代理nginx)
|
||||
- [六、Linux systemd 服务化](#六linux-systemd-服务化)
|
||||
- [七、部署后检查](#七部署后检查)
|
||||
@@ -22,6 +23,17 @@
|
||||
|
||||
---
|
||||
|
||||
## 部署方式优先级建议
|
||||
|
||||
推荐按以下顺序选择部署方式:
|
||||
|
||||
1. **下载 Release 构建包运行**:最省事,产物已编译完成,最适合大多数用户。
|
||||
2. **Docker / GHCR 镜像部署**:适合需要容器化、编排或云环境部署。
|
||||
3. **Vercel 部署**:适合已有 Vercel 环境且接受其平台约束的场景。
|
||||
4. **本地源码运行 / 自行编译**:适合开发、调试或需要自行修改代码的场景。
|
||||
|
||||
---
|
||||
|
||||
## 0. 前置要求
|
||||
|
||||
| 依赖 | 最低版本 | 说明 |
|
||||
@@ -48,70 +60,59 @@ cp config.example.json config.json
|
||||
|
||||
---
|
||||
|
||||
## 一、本地运行
|
||||
## 一、下载 Release 构建包
|
||||
|
||||
### 1.1 基本步骤
|
||||
仓库内置 GitHub Actions 工作流:`.github/workflows/release-artifacts.yml`
|
||||
|
||||
- **触发条件**:仅在 Release `published` 时触发(普通 push 不会构建)
|
||||
- **构建产物**:多平台二进制压缩包 + `sha256sums.txt`
|
||||
- **容器镜像发布**:仅发布到 GHCR(`ghcr.io/cjackhwang/ds2api`)
|
||||
|
||||
| 平台 | 架构 | 文件格式 |
|
||||
| --- | --- | --- |
|
||||
| Linux | amd64, arm64 | `.tar.gz` |
|
||||
| macOS | amd64, arm64 | `.tar.gz` |
|
||||
| Windows | amd64 | `.zip` |
|
||||
|
||||
每个压缩包包含:
|
||||
|
||||
- `ds2api` 可执行文件(Windows 为 `ds2api.exe`)
|
||||
- `static/admin/`(WebUI 构建产物)
|
||||
- `config.example.json`、`.env.example`
|
||||
- `README.MD`、`README.en.md`、`LICENSE`
|
||||
|
||||
### 使用步骤
|
||||
|
||||
```bash
|
||||
# 克隆仓库
|
||||
git clone https://github.com/CJackHwang/ds2api.git
|
||||
cd ds2api
|
||||
# 1. 下载对应平台的压缩包
|
||||
# 2. 解压
|
||||
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||
cd ds2api_<tag>_linux_amd64
|
||||
|
||||
# 复制并编辑配置
|
||||
# 3. 配置
|
||||
cp config.example.json config.json
|
||||
# 使用你喜欢的编辑器打开 config.json,填入:
|
||||
# - keys: 你的 API 访问密钥
|
||||
# - accounts: DeepSeek 账号(email 或 mobile + password)
|
||||
# 编辑 config.json
|
||||
|
||||
# 启动服务
|
||||
go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
默认本地访问地址是 `http://127.0.0.1:5001`;服务实际绑定 `0.0.0.0:5001`,可通过 `PORT` 环境变量覆盖。
|
||||
|
||||
### 1.2 WebUI 构建
|
||||
|
||||
本地首次启动时,若 `static/admin/` 不存在,服务会自动尝试构建 WebUI(需要 Node.js/npm;缺依赖时会先执行 `npm ci`,再执行 `npm run build -- --outDir static/admin --emptyOutDir`)。
|
||||
|
||||
你也可以手动构建:
|
||||
|
||||
```bash
|
||||
./scripts/build-webui.sh
|
||||
```
|
||||
|
||||
或手动执行:
|
||||
|
||||
```bash
|
||||
cd webui
|
||||
npm install
|
||||
npm run build
|
||||
# 产物输出到 static/admin/
|
||||
```
|
||||
|
||||
通过环境变量控制自动构建行为:
|
||||
|
||||
```bash
|
||||
# 强制关闭自动构建
|
||||
DS2API_AUTO_BUILD_WEBUI=false go run ./cmd/ds2api
|
||||
|
||||
# 强制开启自动构建
|
||||
DS2API_AUTO_BUILD_WEBUI=true go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
### 1.3 编译为二进制文件
|
||||
|
||||
```bash
|
||||
go build -o ds2api ./cmd/ds2api
|
||||
# 4. 启动
|
||||
./ds2api
|
||||
```
|
||||
|
||||
### 维护者发布步骤
|
||||
|
||||
1. 在 GitHub 创建并发布 Release(带 tag,如 `vX.Y.Z`)
|
||||
2. 等待 Actions 工作流 `Release Artifacts` 完成
|
||||
3. 在 Release 的 Assets 下载对应平台压缩包
|
||||
|
||||
---
|
||||
|
||||
## 二、Docker 部署
|
||||
## 二、Docker / GHCR 部署
|
||||
|
||||
### 2.1 基本步骤
|
||||
|
||||
```bash
|
||||
# 拉取预编译镜像
|
||||
docker pull ghcr.io/cjackhwang/ds2api:latest
|
||||
|
||||
# 复制环境变量模板和配置文件
|
||||
cp .env.example .env
|
||||
cp config.example.json config.json
|
||||
@@ -128,7 +129,13 @@ docker-compose up -d
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
默认 `docker-compose.yml` 会把宿主机 `6011` 映射到容器内的 `5001`。如果你希望直接对外暴露 `5001`,请设置 `DS2API_HOST_PORT=5001`(或者手动调整 `ports` 配置)。
|
||||
默认 `docker-compose.yml` 直接使用 `ghcr.io/cjackhwang/ds2api:latest`,并把宿主机 `6011` 映射到容器内的 `5001`。如果你希望直接对外暴露 `5001`,请设置 `DS2API_HOST_PORT=5001`(或者手动调整 `ports` 配置)。
|
||||
|
||||
如需固定版本,也可以直接拉取指定 tag:
|
||||
|
||||
```bash
|
||||
docker pull ghcr.io/cjackhwang/ds2api:v3.0.0
|
||||
```
|
||||
|
||||
### 2.2 更新
|
||||
|
||||
@@ -251,12 +258,22 @@ VERCEL_TEAM_ID=team_xxxxxxxxxxxx # 个人账号可留空
|
||||
| `DS2API_GLOBAL_MAX_INFLIGHT` | 全局并发上限 | `recommended_concurrency` |
|
||||
| `DS2API_ENV_WRITEBACK` | 检测到 `DS2API_CONFIG_JSON` 时自动写入 `DS2API_CONFIG_PATH`,并在成功后转为文件模式(`1/true/yes/on`) | 关闭 |
|
||||
| `DS2API_VERCEL_INTERNAL_SECRET` | 混合流式内部鉴权 | 回退用 `DS2API_ADMIN_KEY` |
|
||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease TTL | `900` |
|
||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease TTL | 默认与 `responses.store_ttl_seconds` 同步,若未设置则为 `900` |
|
||||
| `VERCEL_TOKEN` | Vercel 同步 token | — |
|
||||
| `VERCEL_PROJECT_ID` | Vercel 项目 ID | — |
|
||||
| `VERCEL_TEAM_ID` | Vercel 团队 ID | — |
|
||||
| `DS2API_VERCEL_PROTECTION_BYPASS` | 部署保护绕过密钥(内部 Node→Go 调用) | — |
|
||||
|
||||
### 3.3 运行时行为配置(通过 Admin API 设置)
|
||||
|
||||
部分运行时行为无法通过环境变量直接配置,需要在部署后通过 Admin API 设置,例如:
|
||||
|
||||
- **自动删除会话模式** (`auto_delete.mode`):支持 `none` / `single` / `all`,默认为 `none`。可通过 `PUT /admin/settings` 更新。
|
||||
- **每账号并发上限** (`account_max_inflight`):环境变量已支持,但也可通过 Admin API 热更新。
|
||||
- **全局并发上限** (`global_max_inflight`):同上。
|
||||
|
||||
详细说明参见 [API.md](../API.md#admin-接口) 中 `/admin/settings` 部分。
|
||||
|
||||
### 3.3 Vercel 架构说明
|
||||
|
||||
```text
|
||||
@@ -350,57 +367,61 @@ No Output Directory named "public" found after the Build completed.
|
||||
|
||||
---
|
||||
|
||||
## 四、下载 Release 构建包
|
||||
## 四、本地源码运行
|
||||
|
||||
仓库内置 GitHub Actions 工作流:`.github/workflows/release-artifacts.yml`
|
||||
|
||||
- **触发条件**:仅在 Release `published` 时触发(普通 push 不会构建)
|
||||
- **构建产物**:多平台二进制压缩包 + `sha256sums.txt`
|
||||
- **容器镜像发布**:仅发布到 GHCR(`ghcr.io/cjackhwang/ds2api`)
|
||||
|
||||
| 平台 | 架构 | 文件格式 |
|
||||
| --- | --- | --- |
|
||||
| Linux | amd64, arm64 | `.tar.gz` |
|
||||
| macOS | amd64, arm64 | `.tar.gz` |
|
||||
| Windows | amd64 | `.zip` |
|
||||
|
||||
每个压缩包包含:
|
||||
|
||||
- `ds2api` 可执行文件(Windows 为 `ds2api.exe`)
|
||||
- `static/admin/`(WebUI 构建产物)
|
||||
- `config.example.json`、`.env.example`
|
||||
- `README.MD`、`README.en.md`、`LICENSE`
|
||||
|
||||
### 使用步骤
|
||||
### 4.1 基本步骤
|
||||
|
||||
```bash
|
||||
# 1. 下载对应平台的压缩包
|
||||
# 2. 解压
|
||||
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||
cd ds2api_<tag>_linux_amd64
|
||||
# 克隆仓库
|
||||
git clone https://github.com/CJackHwang/ds2api.git
|
||||
cd ds2api
|
||||
|
||||
# 3. 配置
|
||||
# 复制并编辑配置
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json
|
||||
# 使用你喜欢的编辑器打开 config.json,填入:
|
||||
# - keys: 你的 API 访问密钥
|
||||
# - accounts: DeepSeek 账号(email 或 mobile + password)
|
||||
|
||||
# 4. 启动
|
||||
./ds2api
|
||||
# 启动服务
|
||||
go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
### 维护者发布步骤
|
||||
默认本地访问地址是 `http://127.0.0.1:5001`;服务实际绑定 `0.0.0.0:5001`,可通过 `PORT` 环境变量覆盖。
|
||||
|
||||
1. 在 GitHub 创建并发布 Release(带 tag,如 `vX.Y.Z`)
|
||||
2. 等待 Actions 工作流 `Release Artifacts` 完成
|
||||
3. 在 Release 的 Assets 下载对应平台压缩包
|
||||
### 4.2 WebUI 构建
|
||||
|
||||
### 拉取 GHCR 镜像(可选)
|
||||
本地首次启动时,若 `static/admin/` 不存在,服务会自动尝试构建 WebUI(需要 Node.js/npm;缺依赖时会先执行 `npm ci`,再执行 `npm run build -- --outDir static/admin --emptyOutDir`)。
|
||||
|
||||
你也可以手动构建:
|
||||
|
||||
```bash
|
||||
# latest
|
||||
docker pull ghcr.io/cjackhwang/ds2api:latest
|
||||
./scripts/build-webui.sh
|
||||
```
|
||||
|
||||
# 指定版本(示例)
|
||||
docker pull ghcr.io/cjackhwang/ds2api:v3.0.0
|
||||
或手动执行:
|
||||
|
||||
```bash
|
||||
cd webui
|
||||
npm install
|
||||
npm run build
|
||||
# 产物输出到 static/admin/
|
||||
```
|
||||
|
||||
通过环境变量控制自动构建行为:
|
||||
|
||||
```bash
|
||||
# 强制关闭自动构建
|
||||
DS2API_AUTO_BUILD_WEBUI=false go run ./cmd/ds2api
|
||||
|
||||
# 强制开启自动构建
|
||||
DS2API_AUTO_BUILD_WEBUI=true go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
### 4.3 编译为二进制文件
|
||||
|
||||
```bash
|
||||
go build -o ds2api ./cmd/ds2api
|
||||
./ds2api
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
@@ -1,74 +1,74 @@
|
||||
# Tool call parsing semantics(Go/Node 统一语义)
|
||||
|
||||
本文档描述当前代码中 `ParseToolCallsDetailed` / `parseToolCallsDetailed` 的**实际行为**,用于对齐 Go 与 Node Runtime。
|
||||
本文档描述当前代码中工具调用解析链路的**实际行为**(以 `internal/toolcall` 与 `internal/js/helpers/stream-tool-sieve` 为准)。
|
||||
|
||||
文档导航:[总览](../README.MD) / [架构说明](./ARCHITECTURE.md) / [测试指南](./TESTING.md)
|
||||
|
||||
## 1) 输出结构(当前实现)
|
||||
## 1) 当前输出结构
|
||||
|
||||
- `calls`:解析得到的工具调用列表(`name` + `input`)。
|
||||
- `sawToolCallSyntax`:检测到工具调用语法特征时为 `true`(例如 `tool_calls`、`<tool_call>`、`<function_call>`、`<invoke>`、`function.name:`)。
|
||||
- `rejectedByPolicy`:当前实现固定为 `false`(预留字段,尚未启用 allow-list 拒绝)。
|
||||
`ParseToolCallsDetailed` / `parseToolCallsDetailed` 返回:
|
||||
|
||||
- `calls`:解析出的工具调用列表(`name` + `input`)。
|
||||
- `sawToolCallSyntax`:检测到工具调用语法特征时为 `true`。
|
||||
- `rejectedByPolicy`:当前实现固定为 `false`(预留字段)。
|
||||
- `rejectedToolNames`:当前实现固定为空数组(预留字段)。
|
||||
|
||||
> 说明:`filterToolCallsDetailed` 当前仅做结构清洗,不做工具名策略拒绝。
|
||||
> 当前 `filterToolCallsDetailed` 仅做结构清洗,不做 allow-list 工具名硬拒绝。
|
||||
|
||||
## 2) 解析管线
|
||||
## 2) 解析范围(重点)
|
||||
|
||||
1. **示例保护**:若判定为 fenced code block 示例上下文,则跳过执行型解析。
|
||||
2. **候选片段构建**:从完整文本中构建候选(原文、围绕 `tool_calls` 的 JSON 片段、首尾大括号切片等)。
|
||||
3. **按序尝试解析(命中即停)**:
|
||||
- 对“明显 JSON 工具载荷候选”(以 `{`/`[` 开头且包含 `tool_calls`/`\"function\"`)先走 JSON 解析,避免 JSON 字符串内偶发 XML 片段误命中;
|
||||
- 其余候选优先 XML 解析(`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / `antml:function_call` 等);
|
||||
- JSON 解析(`{"tool_calls": [...]}`、列表、单对象);
|
||||
- Markup 解析;
|
||||
- Text-KV 回退(如 `function.name:` + `function.arguments:`)。
|
||||
4. **兜底**:候选全部失败后,再对全文做 XML / Text-KV 回退。
|
||||
当前版本的可执行解析以 **XML/Markup 家族**为主:
|
||||
|
||||
## 3) XML 能力边界(当前)
|
||||
- `<tool_call>...</tool_call>`
|
||||
- `<function_call>...</function_call>`
|
||||
- `<invoke ...>...</invoke>`(含自闭合)
|
||||
- `<tool_use>...</tool_use>`
|
||||
- antml 变体(如 `antml:function_call` / `antml:argument`)
|
||||
|
||||
当前已支持输入端的“多 XML/标记风格”解析,包括但不限于:
|
||||
并支持在这些标记块内部解析:
|
||||
|
||||
- `<tool_call><tool_name>...</tool_name><parameters>...</parameters></tool_call>`
|
||||
- `<function_call>tool</function_call><function parameter name="x">...</function parameter>`
|
||||
- `<invoke name="tool"><parameter name="x">...</parameter></invoke>`
|
||||
- `antml:function_call` / `antml:argument` / `antml:parameters`
|
||||
- `tool_use` 家族标签
|
||||
- JSON 参数字符串
|
||||
- 标签参数(`<parameter name="...">...`)
|
||||
- key/value 风格子标签
|
||||
|
||||
但**输出端仍统一转换为 OpenAI 兼容 JSON 事件/对象**(`message.tool_calls`、`delta.tool_calls`、`response.function_call_arguments.*`)。
|
||||
## 3) 不应再假设的行为
|
||||
|
||||
## 4) 关于“是否可以封装成 XML 再喂给模型”
|
||||
以下说法在当前实现中已不成立:
|
||||
|
||||
结论:**可以做,而且当前解析器已经能兼容 XML 作为输入格式之一**,但代码里并没有 `toolcall.prefer_xml_output` 这个开关。现有可调配置只有:
|
||||
1. “纯 JSON `tool_calls` 片段会被直接当作可执行工具调用解析”。
|
||||
2. “存在 `toolcall.mode` / `toolcall.early_emit_confidence` 等可配置开关可以改变解析策略”。
|
||||
|
||||
- `toolcall.mode`:`feature_match` / `off`
|
||||
- `toolcall.early_emit_confidence`:`high` / `low` / `off`
|
||||
当前策略在代码中固定为:
|
||||
|
||||
推荐思路仍然是“输入兼容层 + 输出按客户端协议渲染”:
|
||||
- 特征匹配开启(feature-match on)
|
||||
- 高置信度早发开启(early emit on)
|
||||
- policy 拒绝字段保留但未启用
|
||||
|
||||
1. **Prompt 约束层**:如果你要尝试 XML-first,可以在系统提示词里约束模型输出规范 XML tool block(例如 `<tool_calls><tool_call>...</tool_call></tool_calls>`)。
|
||||
2. **解析兼容层**:继续在 parser 中同时接受 JSON / XML / ANTML / invoke / text-kv。
|
||||
3. **协议归一层**:无论模型输出什么格式,统一落到内部 `ParsedToolCall`。
|
||||
4. **对外渲染层**:根据客户端请求协议渲染(OpenAI / Claude / Gemini 各自格式)。
|
||||
## 4) 流式与防泄漏语义
|
||||
|
||||
这样可以同时获得:
|
||||
在流式链路中(OpenAI / Claude / Gemini 统一内核):
|
||||
|
||||
- 减少模型端 JSON 转义/引号错误;
|
||||
- 不破坏现有 SDK / 客户端生态;
|
||||
- 逐步灰度(按模型、按租户、按请求开关)。
|
||||
- 工具调用片段会被优先提取为结构化增量输出;
|
||||
- 已识别的工具调用原始片段不会作为普通文本再次回流;
|
||||
- fenced code block 中的示例内容按文本处理,不作为可执行工具调用。
|
||||
|
||||
## 5) 落地建议(低风险迭代)
|
||||
## 5) 落地建议(按当前实现)
|
||||
|
||||
- 继续使用现有的 `toolcall.mode=feature_match` 和 `toolcall.early_emit_confidence=high` 作为默认策略。
|
||||
- 如果要试 XML-first,把它放在 prompt 层或上游模板层,不要假设代码里已有专门的 XML 输出开关。
|
||||
- 增加观测指标:
|
||||
- `toolcall_parse_source`(json/xml/markup/textkv);
|
||||
- `toolcall_parse_success_rate`;
|
||||
- `toolcall_malformed_rate`;
|
||||
- `toolcall_repair_rate`。
|
||||
- 先在 `responses` 链路灰度,再扩展 `chat.completions`。
|
||||
1. Prompt 里优先约束模型输出 XML/Markup 工具块。
|
||||
2. 执行器侧继续做工具名白名单与参数 schema 校验(不要依赖 parser 代替安全策略)。
|
||||
3. 需要兼容历史“纯 JSON tool_calls”模型输出时,请在上游模板层把输出规范化为 XML/Markup 风格再进入 DS2API。
|
||||
|
||||
## 6) 兼容性提醒
|
||||
## 6) 回归验证建议
|
||||
|
||||
- 上游模型若输出混合文本 + XML,仍可能出现“半结构化”噪声,需要依赖现有 sieve 增量消费策略。
|
||||
- XML 不等于安全:仍需做 tool 名、参数 schema、执行权限的服务端校验。
|
||||
可直接运行:
|
||||
|
||||
```bash
|
||||
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/toolcall/
|
||||
node --test tests/node/stream-tool-sieve.test.js
|
||||
```
|
||||
|
||||
重点覆盖:
|
||||
|
||||
- `<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / antml 变体
|
||||
- 参数 JSON 修复与解析
|
||||
- 流式增量下的工具调用提取与文本防泄漏
|
||||
|
||||
2
go.mod
2
go.mod
@@ -18,7 +18,7 @@ require (
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/net v0.52.0
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -138,77 +138,6 @@ func TestHandleClaudeStreamRealtimeThinkingDelta(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeToolSafety(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
|
||||
`data: {"p":"response/content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"search"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
for _, f := range findClaudeFrames(frames, "content_block_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["type"] == "text_delta" && strings.Contains(asString(delta["text"]), `"tool_calls"`) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in text delta: body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
foundToolUse := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" {
|
||||
foundToolUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundToolUse {
|
||||
t.Fatalf("expected tool_use block in stream, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
foundToolUseStop := false
|
||||
for _, f := range findClaudeFrames(frames, "message_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["stop_reason"] == "tool_use" {
|
||||
foundToolUseStop = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundToolUseStop {
|
||||
t.Fatalf("expected stop_reason=tool_use, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeToolDetectionFromThinkingFallback(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
|
||||
`data: {"p":"response/thinking_content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, true, false, []string{"search"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" && contentBlock["name"] == "search" {
|
||||
foundToolUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundToolUse {
|
||||
t.Fatalf("expected tool_use block from thinking fallback, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeSkipsThinkingFallbackWhenFinalTextExists(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
|
||||
@@ -96,7 +96,7 @@ func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) {
|
||||
if !containsStr(content, "<tool_calls>") || !containsStr(content, "<tool_name>search_web</tool_name>") {
|
||||
t.Fatalf("expected assistant content to include XML tool call history, got %q", content)
|
||||
}
|
||||
if !containsStr(content, `<parameters>{"query":"latest"}</parameters>`) {
|
||||
if !containsStr(content, "<parameters>\n <query><![CDATA[latest]]></query>\n </parameters>") {
|
||||
t.Fatalf("expected assistant content to include serialized parameters, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,11 +34,13 @@ func (s openAIProxyStub) ChatCompletions(w http.ResponseWriter, _ *http.Request)
|
||||
|
||||
type openAIProxyCaptureStub struct {
|
||||
seenModel string
|
||||
seenReq map[string]any
|
||||
}
|
||||
|
||||
func (s *openAIProxyCaptureStub) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
s.seenReq = req
|
||||
if m, ok := req["model"].(string); ok {
|
||||
s.seenModel = m
|
||||
}
|
||||
@@ -84,3 +86,33 @@ func TestClaudeProxyViaOpenAIPreservesClaudeMapping(t *testing.T) {
|
||||
t.Fatalf("expected mapped proxy model deepseek-reasoner, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProxyTranslatesInlineImageToOpenAIDataURL(t *testing.T) {
|
||||
openAI := &openAIProxyCaptureStub{}
|
||||
h := &Handler{OpenAI: openAI}
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":[{"type":"text","text":"hello"},{"type":"image","source":{"type":"base64","media_type":"image/png","data":"QUJDRA=="}}]}],"stream":false}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.Messages(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
messages, _ := openAI.seenReq["messages"].([]any)
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected one translated message, got %#v", openAI.seenReq)
|
||||
}
|
||||
msg, _ := messages[0].(map[string]any)
|
||||
content, _ := msg["content"].([]any)
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected translated content blocks, got %#v", msg)
|
||||
}
|
||||
imageBlock, _ := content[1].(map[string]any)
|
||||
if strings.TrimSpace(asString(imageBlock["type"])) != "image_url" {
|
||||
t.Fatalf("expected image_url block, got %#v", imageBlock)
|
||||
}
|
||||
imageURL, _ := imageBlock["image_url"].(map[string]any)
|
||||
if !strings.HasPrefix(strings.TrimSpace(asString(imageURL["url"])), "data:image/png;base64,") {
|
||||
t.Fatalf("expected translated data url, got %#v", imageBlock)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
|
||||
thinkingEnabled = false
|
||||
searchEnabled = false
|
||||
}
|
||||
finalPrompt := deepseek.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
|
||||
finalPrompt := deepseek.MessagesPrepareWithThinking(toMessageMaps(dsPayload["messages"]), thinkingEnabled)
|
||||
toolNames := extractClaudeToolNames(toolsRequested)
|
||||
if len(toolNames) == 0 && len(toolsRequested) > 0 {
|
||||
toolNames = []string{"__any_tool__"}
|
||||
|
||||
@@ -28,7 +28,7 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin
|
||||
}
|
||||
|
||||
toolsRaw := convertGeminiTools(req["tools"])
|
||||
finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "")
|
||||
finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "", thinkingEnabled)
|
||||
passThrough := collectGeminiPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
|
||||
@@ -82,11 +82,17 @@ func (s geminiOpenAIErrorStub) ChatCompletions(w http.ResponseWriter, _ *http.Re
|
||||
}
|
||||
|
||||
type geminiOpenAISuccessStub struct {
|
||||
stream bool
|
||||
body string
|
||||
stream bool
|
||||
body string
|
||||
seenReq map[string]any
|
||||
}
|
||||
|
||||
func (s geminiOpenAISuccessStub) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
|
||||
func (s *geminiOpenAISuccessStub) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
if r != nil {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
s.seenReq = req
|
||||
}
|
||||
if s.stream {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -144,7 +150,7 @@ func TestGeminiRoutesRegistered(t *testing.T) {
|
||||
func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
|
||||
h := &Handler{
|
||||
Store: testGeminiConfig{},
|
||||
OpenAI: geminiOpenAISuccessStub{
|
||||
OpenAI: &geminiOpenAISuccessStub{
|
||||
body: `{"id":"chatcmpl-1","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"eval_javascript","arguments":"{\"code\":\"1+1\"}"}}]},"finish_reason":"tool_calls"}]}`,
|
||||
},
|
||||
}
|
||||
@@ -184,7 +190,7 @@ func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
|
||||
h := &Handler{Store: testGeminiConfig{}, OpenAI: geminiOpenAISuccessStub{}}
|
||||
h := &Handler{Store: testGeminiConfig{}, OpenAI: &geminiOpenAISuccessStub{}}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
@@ -217,7 +223,7 @@ func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
|
||||
func TestStreamGenerateContentEmitsSSE(t *testing.T) {
|
||||
h := &Handler{
|
||||
Store: testGeminiConfig{},
|
||||
OpenAI: geminiOpenAISuccessStub{stream: true},
|
||||
OpenAI: &geminiOpenAISuccessStub{stream: true},
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
@@ -251,6 +257,39 @@ func TestStreamGenerateContentEmitsSSE(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProxyTranslatesInlineImageToOpenAIDataURL(t *testing.T) {
|
||||
openAI := &geminiOpenAISuccessStub{}
|
||||
h := &Handler{Store: testGeminiConfig{}, OpenAI: openAI}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
body := `{"contents":[{"role":"user","parts":[{"text":"hello"},{"inlineData":{"mimeType":"image/png","data":"QUJDRA=="}}]}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
messages, _ := openAI.seenReq["messages"].([]any)
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected one translated message, got %#v", openAI.seenReq)
|
||||
}
|
||||
msg, _ := messages[0].(map[string]any)
|
||||
content, _ := msg["content"].([]any)
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected translated content blocks, got %#v", msg)
|
||||
}
|
||||
imageBlock, _ := content[1].(map[string]any)
|
||||
if strings.TrimSpace(asString(imageBlock["type"])) != "image_url" {
|
||||
t.Fatalf("expected image_url block, got %#v", imageBlock)
|
||||
}
|
||||
imageURL, _ := imageBlock["image_url"].(map[string]any)
|
||||
if !strings.HasPrefix(strings.TrimSpace(asString(imageURL["url"])), "data:image/png;base64,") {
|
||||
t.Fatalf("expected translated data url, got %#v", imageBlock)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateContentOpenAIProxyErrorUsesGeminiEnvelope(t *testing.T) {
|
||||
h := &Handler{
|
||||
Store: testGeminiConfig{},
|
||||
|
||||
250
internal/adapter/openai/chat_history.go
Normal file
250
internal/adapter/openai/chat_history.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/prompt"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
const adminWebUISourceHeader = "X-Ds2-Source"
|
||||
const adminWebUISourceValue = "admin-webui-api-tester"
|
||||
|
||||
type chatHistorySession struct {
|
||||
store *chathistory.Store
|
||||
entryID string
|
||||
startedAt time.Time
|
||||
lastPersist time.Time
|
||||
finalPrompt string
|
||||
startParams chathistory.StartParams
|
||||
disabled bool
|
||||
}
|
||||
|
||||
func startChatHistory(store *chathistory.Store, r *http.Request, a *auth.RequestAuth, stdReq util.StandardRequest) *chatHistorySession {
|
||||
if store == nil || r == nil || a == nil {
|
||||
return nil
|
||||
}
|
||||
if !store.Enabled() {
|
||||
return nil
|
||||
}
|
||||
if !shouldCaptureChatHistory(r) {
|
||||
return nil
|
||||
}
|
||||
entry, err := store.Start(chathistory.StartParams{
|
||||
CallerID: strings.TrimSpace(a.CallerID),
|
||||
AccountID: strings.TrimSpace(a.AccountID),
|
||||
Model: strings.TrimSpace(stdReq.ResponseModel),
|
||||
Stream: stdReq.Stream,
|
||||
UserInput: extractSingleUserInput(stdReq.Messages),
|
||||
Messages: extractAllMessages(stdReq.Messages),
|
||||
FinalPrompt: stdReq.FinalPrompt,
|
||||
})
|
||||
startParams := chathistory.StartParams{
|
||||
CallerID: strings.TrimSpace(a.CallerID),
|
||||
AccountID: strings.TrimSpace(a.AccountID),
|
||||
Model: strings.TrimSpace(stdReq.ResponseModel),
|
||||
Stream: stdReq.Stream,
|
||||
UserInput: extractSingleUserInput(stdReq.Messages),
|
||||
Messages: extractAllMessages(stdReq.Messages),
|
||||
FinalPrompt: stdReq.FinalPrompt,
|
||||
}
|
||||
session := &chatHistorySession{
|
||||
store: store,
|
||||
entryID: entry.ID,
|
||||
startedAt: time.Now(),
|
||||
lastPersist: time.Now(),
|
||||
finalPrompt: stdReq.FinalPrompt,
|
||||
startParams: startParams,
|
||||
}
|
||||
if err != nil {
|
||||
if entry.ID == "" {
|
||||
config.Logger.Warn("[chat_history] start failed", "error", err)
|
||||
return nil
|
||||
}
|
||||
config.Logger.Warn("[chat_history] start persisted in memory after write failure", "error", err)
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
func shouldCaptureChatHistory(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
if isVercelStreamPrepareRequest(r) || isVercelStreamReleaseRequest(r) {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(r.Header.Get(adminWebUISourceHeader)) != adminWebUISourceValue
|
||||
}
|
||||
|
||||
func extractSingleUserInput(messages []any) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
msg, ok := messages[i].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
if role != "user" {
|
||||
continue
|
||||
}
|
||||
if normalized := strings.TrimSpace(prompt.NormalizeContent(msg["content"])); normalized != "" {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractAllMessages(messages []any) []chathistory.Message {
|
||||
out := make([]chathistory.Message, 0, len(messages))
|
||||
for _, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
content := strings.TrimSpace(prompt.NormalizeContent(msg["content"]))
|
||||
if role == "" || content == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, chathistory.Message{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) progress(thinking, content string) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
if now.Sub(s.lastPersist) < 250*time.Millisecond {
|
||||
return
|
||||
}
|
||||
s.lastPersist = now
|
||||
s.persistUpdate(chathistory.UpdateParams{
|
||||
Status: "streaming",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
StatusCode: http.StatusOK,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) success(statusCode int, thinking, content, finishReason string, usage map[string]any) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
s.persistUpdate(chathistory.UpdateParams{
|
||||
Status: "success",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
StatusCode: statusCode,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
Completed: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) error(statusCode int, message, finishReason, thinking, content string) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
s.persistUpdate(chathistory.UpdateParams{
|
||||
Status: "error",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
Error: message,
|
||||
StatusCode: statusCode,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
FinishReason: finishReason,
|
||||
Completed: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) stopped(thinking, content, finishReason string) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
s.persistUpdate(chathistory.UpdateParams{
|
||||
Status: "stopped",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
StatusCode: http.StatusOK,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
FinishReason: finishReason,
|
||||
Usage: openaifmt.BuildChatUsage(s.finalPrompt, thinking, content),
|
||||
Completed: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) retryMissingEntry() bool {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return false
|
||||
}
|
||||
entry, err := s.store.Start(s.startParams)
|
||||
if errors.Is(err, chathistory.ErrDisabled) {
|
||||
s.disabled = true
|
||||
return false
|
||||
}
|
||||
if entry.ID == "" {
|
||||
if err != nil {
|
||||
config.Logger.Warn("[chat_history] recreate missing entry failed", "error", err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
s.entryID = entry.ID
|
||||
if err != nil {
|
||||
config.Logger.Warn("[chat_history] recreate missing entry persisted in memory after write failure", "error", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) persistUpdate(params chathistory.UpdateParams) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
if _, err := s.store.Update(s.entryID, params); err != nil {
|
||||
s.handlePersistError(params, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) handlePersistError(params chathistory.UpdateParams, err error) {
|
||||
if err == nil || s == nil {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, chathistory.ErrDisabled) {
|
||||
s.disabled = true
|
||||
return
|
||||
}
|
||||
if isChatHistoryMissingError(err) {
|
||||
if s.retryMissingEntry() {
|
||||
if _, retryErr := s.store.Update(s.entryID, params); retryErr != nil {
|
||||
if errors.Is(retryErr, chathistory.ErrDisabled) || isChatHistoryMissingError(retryErr) {
|
||||
s.disabled = true
|
||||
return
|
||||
}
|
||||
config.Logger.Warn("[chat_history] retry after missing entry failed", "error", retryErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
s.disabled = true
|
||||
return
|
||||
}
|
||||
config.Logger.Warn("[chat_history] update failed", "error", err)
|
||||
}
|
||||
|
||||
func isChatHistoryMissingError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(err.Error()), "not found")
|
||||
}
|
||||
273
internal/adapter/openai/chat_history_test.go
Normal file
273
internal/adapter/openai/chat_history_test.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func newTestChatHistoryStore(t *testing.T) *chathistory.Store {
|
||||
t.Helper()
|
||||
store := chathistory.New(filepath.Join(t.TempDir(), "chat_history.json"))
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("chat history store unavailable: %v", err)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func blockChatHistoryDetailDir(t *testing.T, detailDir string) func() {
|
||||
t.Helper()
|
||||
blockedDir := detailDir + ".blocked"
|
||||
if err := os.RemoveAll(blockedDir); err != nil {
|
||||
t.Fatalf("remove blocked detail dir failed: %v", err)
|
||||
}
|
||||
if err := os.Rename(detailDir, blockedDir); err != nil {
|
||||
t.Fatalf("move detail dir aside failed: %v", err)
|
||||
}
|
||||
if err := os.RemoveAll(detailDir); err != nil {
|
||||
t.Fatalf("remove blocked detail path failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(detailDir, []byte("blocked"), 0o644); err != nil {
|
||||
t.Fatalf("write blocked detail path failed: %v", err)
|
||||
}
|
||||
var once sync.Once
|
||||
return func() {
|
||||
t.Helper()
|
||||
once.Do(func() {
|
||||
if err := os.RemoveAll(detailDir); err != nil {
|
||||
t.Fatalf("remove blocking detail path failed: %v", err)
|
||||
}
|
||||
if err := os.Rename(blockedDir, detailDir); err != nil {
|
||||
t.Fatalf("restore detail dir failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsNonStreamPersistsHistory(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello world"}`, `data: [DONE]`)},
|
||||
ChatHistory: historyStore,
|
||||
}
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"system","content":"be precise"},{"role":"user","content":"hi there"},{"role":"assistant","content":"previous answer"}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one history item, got %d", len(snapshot.Items))
|
||||
}
|
||||
item := snapshot.Items[0]
|
||||
if item.Status != "success" || item.UserInput != "hi there" {
|
||||
t.Fatalf("unexpected persisted history summary: %#v", item)
|
||||
}
|
||||
full, err := historyStore.Get(item.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("expected detail item, got %v", err)
|
||||
}
|
||||
if full.Content != "hello world" {
|
||||
t.Fatalf("expected detail content persisted, got %#v", full)
|
||||
}
|
||||
if len(full.Messages) != 3 {
|
||||
t.Fatalf("expected all request messages persisted, got %#v", full.Messages)
|
||||
}
|
||||
if full.FinalPrompt == "" {
|
||||
t.Fatalf("expected final prompt to be persisted")
|
||||
}
|
||||
if item.CallerID != "caller:test" {
|
||||
t.Fatalf("expected caller hash persisted in summary, got %#v", item.CallerID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartChatHistoryRecoversFromTransientWriteFailure(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
restore := blockChatHistoryDetailDir(t, historyStore.DetailDir())
|
||||
t.Cleanup(restore)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
a := &auth.RequestAuth{
|
||||
CallerID: "caller:test",
|
||||
AccountID: "acct:test",
|
||||
}
|
||||
stdReq := util.StandardRequest{
|
||||
ResponseModel: "deepseek-chat",
|
||||
Stream: true,
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
FinalPrompt: "hello",
|
||||
}
|
||||
|
||||
session := startChatHistory(historyStore, req, a, stdReq)
|
||||
if session == nil {
|
||||
t.Fatalf("expected session even when initial persistence fails")
|
||||
}
|
||||
if session.disabled {
|
||||
t.Fatalf("expected session to remain active after transient start failure")
|
||||
}
|
||||
if session.entryID == "" {
|
||||
t.Fatalf("expected session entry id to be retained")
|
||||
}
|
||||
if err := historyStore.Err(); err != nil {
|
||||
t.Fatalf("transient start failure should not latch store error: %v", err)
|
||||
}
|
||||
|
||||
session.lastPersist = time.Now().Add(-time.Second)
|
||||
session.progress("thinking", "partial")
|
||||
if session.disabled {
|
||||
t.Fatalf("expected session to remain active after transient update failure")
|
||||
}
|
||||
if session.entryID == "" {
|
||||
t.Fatalf("expected session entry id to remain set after update failure")
|
||||
}
|
||||
if err := historyStore.Err(); err != nil {
|
||||
t.Fatalf("transient update failure should not latch store error: %v", err)
|
||||
}
|
||||
|
||||
restore()
|
||||
|
||||
session.success(http.StatusOK, "thinking", "final answer", "stop", map[string]any{"total_tokens": 7})
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed after restore: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one persisted item after restore, got %#v", snapshot.Items)
|
||||
}
|
||||
full, err := historyStore.Get(session.entryID)
|
||||
if err != nil {
|
||||
t.Fatalf("get restored entry failed: %v", err)
|
||||
}
|
||||
if full.Status != "success" || full.Content != "final answer" {
|
||||
t.Fatalf("expected restored entry to persist final success, got %#v", full)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamContextCancelledMarksHistoryStopped(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
entry, err := historyStore.Start(chathistory.StartParams{
|
||||
CallerID: "caller:test",
|
||||
Model: "deepseek-chat",
|
||||
Stream: true,
|
||||
UserInput: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start history failed: %v", err)
|
||||
}
|
||||
session := &chatHistorySession{
|
||||
store: historyStore,
|
||||
entryID: entry.ID,
|
||||
startedAt: time.Now(),
|
||||
lastPersist: time.Now(),
|
||||
finalPrompt: "hello",
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil).WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
resp := makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, `data: [DONE]`)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-stop", "deepseek-chat", "prompt", false, false, nil, session)
|
||||
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one history item, got %d", len(snapshot.Items))
|
||||
}
|
||||
full, err := historyStore.Get(snapshot.Items[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get detail failed: %v", err)
|
||||
}
|
||||
if full.Status != "stopped" {
|
||||
t.Fatalf("expected stopped status, got %#v", full)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsSkipsAdminWebUISource(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello world"}`, `data: [DONE]`)},
|
||||
ChatHistory: historyStore,
|
||||
}
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi there"}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set(adminWebUISourceHeader, adminWebUISourceValue)
|
||||
rec := httptest.NewRecorder()
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 0 {
|
||||
t.Fatalf("expected admin webui source to be skipped, got %#v", snapshot.Items)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsSkipsHistoryWhenDisabled(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
if _, err := historyStore.SetLimit(chathistory.DisabledLimit); err != nil {
|
||||
t.Fatalf("disable history store failed: %v", err)
|
||||
}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello world"}`, `data: [DONE]`)},
|
||||
ChatHistory: historyStore,
|
||||
}
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi there"}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 0 {
|
||||
t.Fatalf("expected disabled history to stay empty, got %#v", snapshot.Items)
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,14 @@ type chatStreamRuntime struct {
|
||||
streamToolNames map[int]string
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
|
||||
finalThinking string
|
||||
finalText string
|
||||
finalFinishReason string
|
||||
finalUsage map[string]any
|
||||
finalErrorStatus int
|
||||
finalErrorMessage string
|
||||
finalErrorCode string
|
||||
}
|
||||
|
||||
func newChatStreamRuntime(
|
||||
@@ -98,9 +106,32 @@ func (s *chatStreamRuntime) sendDone() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) sendFailedChunk(status int, message, code string) {
|
||||
s.finalErrorStatus = status
|
||||
s.finalErrorMessage = message
|
||||
s.finalErrorCode = code
|
||||
s.sendChunk(map[string]any{
|
||||
"status_code": status,
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": openAIErrorType(status),
|
||||
"code": code,
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
s.sendDone()
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) resetStreamToolCallState() {
|
||||
s.streamToolCallIDs = map[int]string{}
|
||||
s.streamToolNames = map[int]string{}
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
|
||||
s.finalThinking = finalThinking
|
||||
s.finalText = finalText
|
||||
detected := toolcall.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
|
||||
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
|
||||
finishReason = "tool_calls"
|
||||
@@ -140,6 +171,7 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)},
|
||||
nil,
|
||||
))
|
||||
s.resetStreamToolCallState()
|
||||
}
|
||||
if evt.Content == "" {
|
||||
continue
|
||||
@@ -168,7 +200,24 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
if len(detected.Calls) > 0 || s.toolCallsEmitted {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
if len(detected.Calls) == 0 && !s.toolCallsEmitted && strings.TrimSpace(finalText) == "" {
|
||||
status := http.StatusTooManyRequests
|
||||
message := "Upstream model returned empty output."
|
||||
code := "upstream_empty_output"
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
message = "Upstream model returned reasoning without visible output."
|
||||
}
|
||||
if finishReason == "content_filter" {
|
||||
status = http.StatusBadRequest
|
||||
message = "Upstream content filtered the response and returned no output."
|
||||
code = "content_filter"
|
||||
}
|
||||
s.sendFailedChunk(status, message, code)
|
||||
return
|
||||
}
|
||||
usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText)
|
||||
s.finalFinishReason = finishReason
|
||||
s.finalUsage = usage
|
||||
s.sendChunk(openaifmt.BuildChatStreamChunk(
|
||||
s.completionID,
|
||||
s.created,
|
||||
@@ -184,6 +233,9 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.ContentFilter {
|
||||
if strings.TrimSpace(s.text.String()) == "" {
|
||||
return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("content_filter")}
|
||||
}
|
||||
return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReasonHandlerRequested}
|
||||
}
|
||||
if parsed.ErrorMessage != "" {
|
||||
@@ -263,6 +315,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
s.firstChunkSent = true
|
||||
}
|
||||
newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta))
|
||||
s.resetStreamToolCallState()
|
||||
continue
|
||||
}
|
||||
if evt.Content != "" {
|
||||
|
||||
31
internal/adapter/openai/citation_links.go
Normal file
31
internal/adapter/openai/citation_links.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var citationMarkerPattern = regexp.MustCompile(`(?i)\[citation:\s*(\d+)\]`)
|
||||
|
||||
func replaceCitationMarkersWithLinks(text string, links map[int]string) string {
|
||||
if strings.TrimSpace(text) == "" || len(links) == 0 {
|
||||
return text
|
||||
}
|
||||
return citationMarkerPattern.ReplaceAllStringFunc(text, func(match string) string {
|
||||
sub := citationMarkerPattern.FindStringSubmatch(match)
|
||||
if len(sub) < 2 {
|
||||
return match
|
||||
}
|
||||
idx, err := strconv.Atoi(strings.TrimSpace(sub[1]))
|
||||
if err != nil || idx <= 0 {
|
||||
return match
|
||||
}
|
||||
url := strings.TrimSpace(links[idx])
|
||||
if url == "" {
|
||||
return match
|
||||
}
|
||||
return fmt.Sprintf("[%d](%s)", idx, url)
|
||||
})
|
||||
}
|
||||
28
internal/adapter/openai/citation_links_test.go
Normal file
28
internal/adapter/openai/citation_links_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestReplaceCitationMarkersWithLinks(t *testing.T) {
|
||||
raw := "这是一条更新[citation:1],更多信息见[citation:2]。"
|
||||
links := map[int]string{
|
||||
1: "https://example.com/news-1",
|
||||
2: "https://example.com/news-2",
|
||||
}
|
||||
|
||||
got := replaceCitationMarkersWithLinks(raw, links)
|
||||
want := "这是一条更新[1](https://example.com/news-1),更多信息见[2](https://example.com/news-2)。"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceCitationMarkersWithLinksKeepsUnknownIndex(t *testing.T) {
|
||||
raw := "只有一个来源[citation:1],未知来源[citation:3]。"
|
||||
links := map[int]string{1: "https://example.com/a"}
|
||||
|
||||
got := replaceCitationMarkersWithLinks(raw, links)
|
||||
want := "只有一个来源[1](https://example.com/a),未知来源[citation:3]。"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ type AuthResolver interface {
|
||||
type DeepSeekCaller interface {
|
||||
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
UploadFile(ctx context.Context, a *auth.RequestAuth, req deepseek.UploadFileRequest, maxAttempts int) (*deepseek.UploadFileResult, error)
|
||||
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
|
||||
DeleteSessionForToken(ctx context.Context, token string, sessionID string) (*deepseek.DeleteSessionResult, error)
|
||||
DeleteAllSessionsForToken(ctx context.Context, token string) error
|
||||
@@ -33,6 +34,8 @@ type ConfigReader interface {
|
||||
EmbeddingsProvider() string
|
||||
AutoDeleteMode() string
|
||||
AutoDeleteSessions() bool
|
||||
HistorySplitEnabled() bool
|
||||
HistorySplitTriggerAfterTurns() int
|
||||
}
|
||||
|
||||
var _ AuthResolver = (*auth.Resolver)(nil)
|
||||
|
||||
@@ -3,13 +3,15 @@ package openai
|
||||
import "testing"
|
||||
|
||||
type mockOpenAIConfig struct {
|
||||
aliases map[string]string
|
||||
wideInput bool
|
||||
autoDeleteMode string
|
||||
toolMode string
|
||||
earlyEmit string
|
||||
responsesTTL int
|
||||
embedProv string
|
||||
aliases map[string]string
|
||||
wideInput bool
|
||||
autoDeleteMode string
|
||||
toolMode string
|
||||
earlyEmit string
|
||||
responsesTTL int
|
||||
embedProv string
|
||||
historySplitEnabled bool
|
||||
historySplitTurns int
|
||||
}
|
||||
|
||||
func (m mockOpenAIConfig) ModelAliases() map[string]string { return m.aliases }
|
||||
@@ -27,7 +29,14 @@ func (m mockOpenAIConfig) AutoDeleteMode() string {
|
||||
}
|
||||
return m.autoDeleteMode
|
||||
}
|
||||
func (m mockOpenAIConfig) AutoDeleteSessions() bool { return false }
|
||||
func (m mockOpenAIConfig) AutoDeleteSessions() bool { return false }
|
||||
func (m mockOpenAIConfig) HistorySplitEnabled() bool { return m.historySplitEnabled }
|
||||
func (m mockOpenAIConfig) HistorySplitTriggerAfterTurns() int {
|
||||
if m.historySplitTurns <= 0 {
|
||||
return 1
|
||||
}
|
||||
return m.historySplitTurns
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
|
||||
cfg := mockOpenAIConfig{
|
||||
|
||||
@@ -26,8 +26,13 @@ func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, openAIGeneralMaxSize)
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "too large") {
|
||||
writeOpenAIError(w, http.StatusRequestEntityTooLarge, "request body too large")
|
||||
return
|
||||
}
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
|
||||
382
internal/adapter/openai/file_inline_upload.go
Normal file
382
internal/adapter/openai/file_inline_upload.go
Normal file
@@ -0,0 +1,382 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
const maxInlineFilesPerRequest = 50
|
||||
|
||||
type inlineFileUploadError struct {
|
||||
status int
|
||||
message string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *inlineFileUploadError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if strings.TrimSpace(e.message) != "" {
|
||||
return e.message
|
||||
}
|
||||
if e.err != nil {
|
||||
return e.err.Error()
|
||||
}
|
||||
return "inline file processing failed"
|
||||
}
|
||||
|
||||
type inlineUploadState struct {
|
||||
ctx context.Context
|
||||
handler *Handler
|
||||
auth *auth.RequestAuth
|
||||
uploadedByID map[string]string
|
||||
uploadCount int
|
||||
}
|
||||
|
||||
type inlineDecodedFile struct {
|
||||
Data []byte
|
||||
ContentType string
|
||||
Filename string
|
||||
ReplacementType string
|
||||
}
|
||||
|
||||
func (h *Handler) preprocessInlineFileInputs(ctx context.Context, a *auth.RequestAuth, req map[string]any) error {
|
||||
if h == nil || h.DS == nil || len(req) == 0 {
|
||||
return nil
|
||||
}
|
||||
state := &inlineUploadState{
|
||||
ctx: ctx,
|
||||
handler: h,
|
||||
auth: a,
|
||||
uploadedByID: map[string]string{},
|
||||
}
|
||||
for _, key := range []string{"messages", "input", "attachments"} {
|
||||
if raw, ok := req[key]; ok {
|
||||
updated, err := state.walk(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req[key] = updated
|
||||
}
|
||||
}
|
||||
if refIDs := collectOpenAIRefFileIDs(req); len(refIDs) > 0 {
|
||||
req["ref_file_ids"] = stringsToAnySlice(refIDs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeOpenAIInlineFileError(w http.ResponseWriter, err error) {
|
||||
inlineErr, ok := err.(*inlineFileUploadError)
|
||||
if !ok || inlineErr == nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to process file input.")
|
||||
return
|
||||
}
|
||||
status := inlineErr.status
|
||||
if status == 0 {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
message := strings.TrimSpace(inlineErr.message)
|
||||
if message == "" {
|
||||
message = "Failed to process file input."
|
||||
}
|
||||
writeOpenAIError(w, status, message)
|
||||
}
|
||||
|
||||
func (s *inlineUploadState) walk(raw any) (any, error) {
|
||||
switch x := raw.(type) {
|
||||
case []any:
|
||||
out := make([]any, len(x))
|
||||
for i, item := range x {
|
||||
updated, err := s.walk(item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out[i] = updated
|
||||
}
|
||||
return out, nil
|
||||
case map[string]any:
|
||||
if replacement, replaced, err := s.tryUploadBlock(x); replaced || err != nil {
|
||||
return replacement, err
|
||||
}
|
||||
for _, key := range []string{"messages", "input", "attachments", "content", "files", "items", "data", "source", "file", "image_url"} {
|
||||
if nested, ok := x[key]; ok {
|
||||
updated, err := s.walk(nested)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x[key] = updated
|
||||
}
|
||||
}
|
||||
return x, nil
|
||||
default:
|
||||
return raw, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *inlineUploadState) tryUploadBlock(block map[string]any) (map[string]any, bool, error) {
|
||||
decoded, ok, err := decodeOpenAIInlineFileBlock(block)
|
||||
if err != nil {
|
||||
return nil, true, &inlineFileUploadError{status: http.StatusBadRequest, message: err.Error(), err: err}
|
||||
}
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
if s.uploadCount >= maxInlineFilesPerRequest {
|
||||
return nil, true, fmt.Errorf("exceeded maximum of %d inline files per request", maxInlineFilesPerRequest)
|
||||
}
|
||||
fileID, err := s.uploadInlineFile(decoded)
|
||||
if err != nil {
|
||||
return nil, true, &inlineFileUploadError{status: http.StatusInternalServerError, message: "Failed to upload inline file.", err: err}
|
||||
}
|
||||
s.uploadCount++
|
||||
replacement := map[string]any{
|
||||
"type": decoded.ReplacementType,
|
||||
"file_id": fileID,
|
||||
}
|
||||
if decoded.Filename != "" {
|
||||
replacement["filename"] = decoded.Filename
|
||||
}
|
||||
if decoded.ContentType != "" {
|
||||
replacement["mime_type"] = decoded.ContentType
|
||||
}
|
||||
return replacement, true, nil
|
||||
}
|
||||
|
||||
func (s *inlineUploadState) uploadInlineFile(file inlineDecodedFile) (string, error) {
|
||||
sum := sha256.Sum256(append([]byte(file.ContentType+"\x00"+file.Filename+"\x00"), file.Data...))
|
||||
cacheKey := fmt.Sprintf("%x", sum[:])
|
||||
if fileID, ok := s.uploadedByID[cacheKey]; ok && strings.TrimSpace(fileID) != "" {
|
||||
return fileID, nil
|
||||
}
|
||||
contentType := strings.TrimSpace(file.ContentType)
|
||||
if contentType == "" {
|
||||
contentType = http.DetectContentType(file.Data)
|
||||
}
|
||||
result, err := s.handler.DS.UploadFile(s.ctx, s.auth, deepseek.UploadFileRequest{
|
||||
Filename: file.Filename,
|
||||
ContentType: contentType,
|
||||
Data: file.Data,
|
||||
}, 3)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
fileID := strings.TrimSpace(result.ID)
|
||||
if fileID == "" {
|
||||
return "", fmt.Errorf("upload succeeded without file id")
|
||||
}
|
||||
s.uploadedByID[cacheKey] = fileID
|
||||
return fileID, nil
|
||||
}
|
||||
|
||||
func decodeOpenAIInlineFileBlock(block map[string]any) (inlineDecodedFile, bool, error) {
|
||||
if block == nil {
|
||||
return inlineDecodedFile{}, false, nil
|
||||
}
|
||||
if strings.TrimSpace(asString(block["file_id"])) != "" {
|
||||
return inlineDecodedFile{}, false, nil
|
||||
}
|
||||
if nested, ok := block["file"].(map[string]any); ok {
|
||||
decoded, matched, err := decodeOpenAIInlineFileBlock(nested)
|
||||
if err != nil || !matched {
|
||||
return decoded, matched, err
|
||||
}
|
||||
if decoded.Filename == "" {
|
||||
decoded.Filename = pickInlineFilename(block, decoded.ContentType, defaultInlinePrefix(decoded.ReplacementType))
|
||||
}
|
||||
return decoded, true, nil
|
||||
}
|
||||
blockType := strings.ToLower(strings.TrimSpace(asString(block["type"])))
|
||||
if raw, matched := extractInlineImageDataURL(block); matched {
|
||||
data, contentType, err := decodeInlinePayload(raw, contentTypeFromMap(block))
|
||||
if err != nil {
|
||||
return inlineDecodedFile{}, true, fmt.Errorf("invalid image input")
|
||||
}
|
||||
return inlineDecodedFile{
|
||||
Data: data,
|
||||
ContentType: contentType,
|
||||
Filename: pickInlineFilename(block, contentType, "image"),
|
||||
ReplacementType: "input_image",
|
||||
}, true, nil
|
||||
}
|
||||
if raw, matched := extractInlineFilePayload(block, blockType); matched {
|
||||
data, contentType, err := decodeInlinePayload(raw, contentTypeFromMap(block))
|
||||
if err != nil {
|
||||
return inlineDecodedFile{}, true, fmt.Errorf("invalid file input")
|
||||
}
|
||||
return inlineDecodedFile{
|
||||
Data: data,
|
||||
ContentType: contentType,
|
||||
Filename: pickInlineFilename(block, contentType, defaultInlinePrefix(blockType)),
|
||||
ReplacementType: "input_file",
|
||||
}, true, nil
|
||||
}
|
||||
return inlineDecodedFile{}, false, nil
|
||||
}
|
||||
|
||||
func extractInlineImageDataURL(block map[string]any) (string, bool) {
|
||||
imageURL := block["image_url"]
|
||||
switch x := imageURL.(type) {
|
||||
case string:
|
||||
if isDataURL(x) {
|
||||
return strings.TrimSpace(x), true
|
||||
}
|
||||
case map[string]any:
|
||||
if raw := strings.TrimSpace(asString(x["url"])); isDataURL(raw) {
|
||||
return raw, true
|
||||
}
|
||||
}
|
||||
if raw := strings.TrimSpace(asString(block["url"])); isDataURL(raw) {
|
||||
return raw, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func extractInlineFilePayload(block map[string]any, blockType string) (string, bool) {
|
||||
for _, value := range []any{block["file_data"], block["base64"], block["data"]} {
|
||||
if raw := strings.TrimSpace(asString(value)); raw != "" {
|
||||
if strings.Contains(blockType, "file") || block["file_data"] != nil || block["filename"] != nil || block["file_name"] != nil || block["name"] != nil {
|
||||
return raw, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func decodeInlinePayload(raw string, explicitContentType string) ([]byte, string, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil, "", fmt.Errorf("empty payload")
|
||||
}
|
||||
if isDataURL(raw) {
|
||||
return decodeDataURL(raw, explicitContentType)
|
||||
}
|
||||
decoded, err := decodeBase64Flexible(raw)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
contentType := strings.TrimSpace(explicitContentType)
|
||||
if contentType == "" && len(decoded) > 0 {
|
||||
contentType = http.DetectContentType(decoded)
|
||||
}
|
||||
return decoded, contentType, nil
|
||||
}
|
||||
|
||||
func decodeDataURL(raw string, explicitContentType string) ([]byte, string, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if !isDataURL(raw) {
|
||||
return nil, "", fmt.Errorf("unsupported data url")
|
||||
}
|
||||
header, payload, ok := strings.Cut(raw, ",")
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("invalid data url")
|
||||
}
|
||||
meta := strings.TrimSpace(strings.TrimPrefix(header, "data:"))
|
||||
contentType := strings.TrimSpace(explicitContentType)
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
if meta != "" {
|
||||
parts := strings.Split(meta, ";")
|
||||
if len(parts) > 0 && strings.TrimSpace(parts[0]) != "" {
|
||||
contentType = strings.TrimSpace(parts[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.Contains(strings.ToLower(meta), ";base64") {
|
||||
decoded, err := decodeBase64Flexible(payload)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return decoded, contentType, nil
|
||||
}
|
||||
decoded, err := url.PathUnescape(payload)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return []byte(decoded), contentType, nil
|
||||
}
|
||||
|
||||
func decodeBase64Flexible(raw string) ([]byte, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
for _, enc := range []*base64.Encoding{base64.StdEncoding, base64.RawStdEncoding, base64.URLEncoding, base64.RawURLEncoding} {
|
||||
decoded, err := enc.DecodeString(raw)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("invalid base64 payload")
|
||||
}
|
||||
|
||||
func contentTypeFromMap(block map[string]any) string {
|
||||
for _, value := range []any{block["mime_type"], block["mimeType"], block["content_type"], block["contentType"], block["media_type"], block["mediaType"]} {
|
||||
if contentType := strings.TrimSpace(asString(value)); contentType != "" {
|
||||
return contentType
|
||||
}
|
||||
}
|
||||
if imageURL, ok := block["image_url"].(map[string]any); ok {
|
||||
for _, value := range []any{imageURL["mime_type"], imageURL["mimeType"], imageURL["content_type"], imageURL["contentType"]} {
|
||||
if contentType := strings.TrimSpace(asString(value)); contentType != "" {
|
||||
return contentType
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func pickInlineFilename(block map[string]any, contentType string, prefix string) string {
|
||||
for _, value := range []any{block["filename"], block["file_name"], block["name"]} {
|
||||
if name := strings.TrimSpace(asString(value)); name != "" {
|
||||
return filepath.Base(name)
|
||||
}
|
||||
}
|
||||
if prefix == "" {
|
||||
prefix = "upload"
|
||||
}
|
||||
ext := ".bin"
|
||||
if parsedType := strings.TrimSpace(contentType); parsedType != "" {
|
||||
if comma := strings.Index(parsedType, ";"); comma >= 0 {
|
||||
parsedType = strings.TrimSpace(parsedType[:comma])
|
||||
}
|
||||
if exts, err := mime.ExtensionsByType(parsedType); err == nil && len(exts) > 0 && strings.TrimSpace(exts[0]) != "" {
|
||||
ext = exts[0]
|
||||
}
|
||||
}
|
||||
return prefix + ext
|
||||
}
|
||||
|
||||
func defaultInlinePrefix(blockType string) string {
|
||||
blockType = strings.ToLower(strings.TrimSpace(blockType))
|
||||
if strings.Contains(blockType, "image") {
|
||||
return "image"
|
||||
}
|
||||
return "upload"
|
||||
}
|
||||
|
||||
func isDataURL(raw string) bool {
|
||||
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(raw)), "data:")
|
||||
}
|
||||
|
||||
func stringsToAnySlice(items []string) []any {
|
||||
out := make([]any, 0, len(items))
|
||||
for _, item := range items {
|
||||
trimmed := strings.TrimSpace(item)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
274
internal/adapter/openai/file_inline_upload_test.go
Normal file
274
internal/adapter/openai/file_inline_upload_test.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
type inlineUploadDSStub struct {
|
||||
uploadCalls []deepseek.UploadFileRequest
|
||||
lastCtx context.Context
|
||||
completionReq map[string]any
|
||||
createSession string
|
||||
uploadErr error
|
||||
completionResp *http.Response
|
||||
}
|
||||
|
||||
func (m *inlineUploadDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||
if strings.TrimSpace(m.createSession) == "" {
|
||||
return "session-id", nil
|
||||
}
|
||||
return m.createSession, nil
|
||||
}
|
||||
|
||||
func (m *inlineUploadDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||
return "pow", nil
|
||||
}
|
||||
|
||||
func (m *inlineUploadDSStub) UploadFile(ctx context.Context, _ *auth.RequestAuth, req deepseek.UploadFileRequest, _ int) (*deepseek.UploadFileResult, error) {
|
||||
m.lastCtx = ctx
|
||||
m.uploadCalls = append(m.uploadCalls, req)
|
||||
if m.uploadErr != nil {
|
||||
return nil, m.uploadErr
|
||||
}
|
||||
return &deepseek.UploadFileResult{
|
||||
ID: "file-inline-1",
|
||||
Filename: req.Filename,
|
||||
Bytes: int64(len(req.Data)),
|
||||
Status: "uploaded",
|
||||
Purpose: req.Purpose,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *inlineUploadDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, payload map[string]any, _ string, _ int) (*http.Response, error) {
|
||||
m.completionReq = payload
|
||||
if m.completionResp != nil {
|
||||
return m.completionResp, nil
|
||||
}
|
||||
return makeOpenAISSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"ok"}`,
|
||||
`data: [DONE]`,
|
||||
), nil
|
||||
}
|
||||
|
||||
func (m *inlineUploadDSStub) DeleteSessionForToken(_ context.Context, _ string, _ string) (*deepseek.DeleteSessionResult, error) {
|
||||
return &deepseek.DeleteSessionResult{Success: true}, nil
|
||||
}
|
||||
|
||||
func (m *inlineUploadDSStub) DeleteAllSessionsForToken(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPreprocessInlineFileInputsReplacesDataURLAndCollectsRefFileIDs(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{DS: ds}
|
||||
req := map[string]any{
|
||||
"messages": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{"url": "data:image/png;base64,QUJDRA=="},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := h.preprocessInlineFileInputs(ctx, &auth.RequestAuth{DeepSeekToken: "token"}, req); err != nil {
|
||||
t.Fatalf("preprocess failed: %v", err)
|
||||
}
|
||||
if len(ds.uploadCalls) != 1 {
|
||||
t.Fatalf("expected 1 upload, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
if ds.lastCtx != ctx {
|
||||
t.Fatalf("expected upload to use request context")
|
||||
}
|
||||
if ds.uploadCalls[0].ContentType != "image/png" {
|
||||
t.Fatalf("expected image/png, got %q", ds.uploadCalls[0].ContentType)
|
||||
}
|
||||
if ds.uploadCalls[0].Filename != "image.png" {
|
||||
t.Fatalf("expected inferred filename image.png, got %q", ds.uploadCalls[0].Filename)
|
||||
}
|
||||
messages, _ := req["messages"].([]any)
|
||||
first, _ := messages[0].(map[string]any)
|
||||
content, _ := first["content"].([]any)
|
||||
block, _ := content[0].(map[string]any)
|
||||
if block["type"] != "input_image" {
|
||||
t.Fatalf("expected input_image replacement, got %#v", block)
|
||||
}
|
||||
if block["file_id"] != "file-inline-1" {
|
||||
t.Fatalf("expected file-inline-1 replacement id, got %#v", block)
|
||||
}
|
||||
refIDs, _ := req["ref_file_ids"].([]any)
|
||||
if len(refIDs) != 1 || refIDs[0] != "file-inline-1" {
|
||||
t.Fatalf("unexpected ref_file_ids: %#v", req["ref_file_ids"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreprocessInlineFileInputsDeduplicatesIdenticalPayloads(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{DS: ds}
|
||||
req := map[string]any{
|
||||
"messages": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,QUJDRA=="}},
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,QUJDRA=="}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := h.preprocessInlineFileInputs(context.Background(), &auth.RequestAuth{DeepSeekToken: "token"}, req); err != nil {
|
||||
t.Fatalf("preprocess failed: %v", err)
|
||||
}
|
||||
if len(ds.uploadCalls) != 1 {
|
||||
t.Fatalf("expected deduplicated single upload, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
refIDs, _ := req["ref_file_ids"].([]any)
|
||||
if len(refIDs) != 1 || refIDs[0] != "file-inline-1" {
|
||||
t.Fatalf("unexpected ref_file_ids after dedupe: %#v", req["ref_file_ids"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsUploadsInlineFilesBeforeCompletion(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{Store: mockOpenAIConfig{wideInput: true}, Auth: streamStatusAuthStub{}, DS: ds}
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":[{"type":"input_text","text":"hi"},{"type":"image_url","image_url":{"url":"data:image/png;base64,QUJDRA=="}}]}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(ds.uploadCalls) != 1 {
|
||||
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
if ds.completionReq == nil {
|
||||
t.Fatal("expected completion payload to be captured")
|
||||
}
|
||||
refIDs, _ := ds.completionReq["ref_file_ids"].([]any)
|
||||
if len(refIDs) != 1 || refIDs[0] != "file-inline-1" {
|
||||
t.Fatalf("unexpected completion ref_file_ids: %#v", ds.completionReq["ref_file_ids"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesUploadsInlineFilesBeforeCompletion(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{Store: mockOpenAIConfig{wideInput: true}, Auth: streamStatusAuthStub{}, DS: ds}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
reqBody := `{"model":"deepseek-chat","input":[{"role":"user","content":[{"type":"input_text","text":"hi"},{"type":"input_image","image_url":{"url":"data:image/png;base64,QUJDRA=="}}]}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(ds.uploadCalls) != 1 {
|
||||
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
refIDs, _ := ds.completionReq["ref_file_ids"].([]any)
|
||||
if len(refIDs) != 1 || refIDs[0] != "file-inline-1" {
|
||||
t.Fatalf("unexpected completion ref_file_ids: %#v", ds.completionReq["ref_file_ids"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsInlineUploadFailureReturnsBadRequest(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{Store: mockOpenAIConfig{wideInput: true}, Auth: streamStatusAuthStub{}, DS: ds}
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":[{"type":"image_url","image_url":{"url":"data:image/png;base64,%%%"}}]}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if ds.completionReq != nil {
|
||||
t.Fatalf("did not expect completion call on upload decode error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesInlineUploadFailureReturnsInternalServerError(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{uploadErr: errors.New("boom")}
|
||||
h := &Handler{Store: mockOpenAIConfig{wideInput: true}, Auth: streamStatusAuthStub{}, DS: ds}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
reqBody := `{"model":"deepseek-chat","input":[{"role":"user","content":[{"type":"image_url","image_url":{"url":"data:image/png;base64,QUJDRA=="}}]}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if ds.completionReq != nil {
|
||||
t.Fatalf("did not expect completion call after upload failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVercelPrepareUploadsInlineFilesBeforeLeasePayload(t *testing.T) {
|
||||
t.Setenv("VERCEL", "1")
|
||||
t.Setenv("DS2API_VERCEL_INTERNAL_SECRET", "stream-secret")
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{Store: mockOpenAIConfig{wideInput: true}, Auth: streamStatusAuthStub{}, DS: ds}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":[{"type":"input_text","text":"hi"},{"type":"image_url","image_url":{"url":"data:image/png;base64,QUJDRA=="}}]}],"stream":true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions?__stream_prepare=1", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("X-Ds2-Internal-Token", "stream-secret")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(ds.uploadCalls) != 1 {
|
||||
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
|
||||
}
|
||||
payload, _ := out["payload"].(map[string]any)
|
||||
if payload == nil {
|
||||
t.Fatalf("expected payload in prepare response, got %#v", out)
|
||||
}
|
||||
refIDs, _ := payload["ref_file_ids"].([]any)
|
||||
if len(refIDs) != 1 || refIDs[0] != "file-inline-1" {
|
||||
t.Fatalf("unexpected payload ref_file_ids: %#v", payload["ref_file_ids"])
|
||||
}
|
||||
}
|
||||
94
internal/adapter/openai/file_refs.go
Normal file
94
internal/adapter/openai/file_refs.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func collectOpenAIRefFileIDs(req map[string]any) []string {
|
||||
if len(req) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, 4)
|
||||
seen := map[string]struct{}{}
|
||||
for _, key := range []string{
|
||||
"ref_file_ids",
|
||||
"file_ids",
|
||||
"attachments",
|
||||
"messages",
|
||||
"input",
|
||||
} {
|
||||
raw := req[key]
|
||||
if raw == nil {
|
||||
continue
|
||||
}
|
||||
// Skip top-level strings for 'messages' and 'input' as they are likely plain text content,
|
||||
// not file IDs. String file IDs are expected in 'ref_file_ids' or 'file_ids'.
|
||||
if key == "messages" || key == "input" {
|
||||
if _, ok := raw.(string); ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
appendOpenAIRefFileIDs(&out, seen, raw)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func appendOpenAIRefFileIDs(out *[]string, seen map[string]struct{}, raw any) {
|
||||
switch x := raw.(type) {
|
||||
case string:
|
||||
addOpenAIRefFileID(out, seen, x)
|
||||
case []string:
|
||||
for _, item := range x {
|
||||
addOpenAIRefFileID(out, seen, item)
|
||||
}
|
||||
case []any:
|
||||
for _, item := range x {
|
||||
appendOpenAIRefFileIDs(out, seen, item)
|
||||
}
|
||||
case map[string]any:
|
||||
if fileID := strings.TrimSpace(asString(x["file_id"])); fileID != "" {
|
||||
addOpenAIRefFileID(out, seen, fileID)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(strings.TrimSpace(asString(x["type"]))), "file") {
|
||||
if fileID := strings.TrimSpace(asString(x["id"])); fileID != "" {
|
||||
addOpenAIRefFileID(out, seen, fileID)
|
||||
}
|
||||
}
|
||||
if fileMap, ok := x["file"].(map[string]any); ok {
|
||||
if fileID := strings.TrimSpace(asString(fileMap["file_id"])); fileID != "" {
|
||||
addOpenAIRefFileID(out, seen, fileID)
|
||||
}
|
||||
if fileID := strings.TrimSpace(asString(fileMap["id"])); fileID != "" {
|
||||
addOpenAIRefFileID(out, seen, fileID)
|
||||
}
|
||||
}
|
||||
// Recurse into potential containers. Note: we do NOT recurse into 'content' or 'input'
|
||||
// if they are plain strings (handled by the top-level switch), but they are usually
|
||||
// nested inside the map branch anyway.
|
||||
// To be safe, we only recurse into these known container keys.
|
||||
for _, key := range []string{"ref_file_ids", "file_ids", "attachments", "messages", "input", "content", "files", "items", "data", "source"} {
|
||||
if nested, ok := x[key]; ok {
|
||||
// If it's a message content that is a string, we must NOT treat it as an ID.
|
||||
if key == "content" || key == "input" {
|
||||
if _, ok := nested.(string); ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
appendOpenAIRefFileIDs(out, seen, nested)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addOpenAIRefFileID(out *[]string, seen map[string]struct{}, fileID string) {
|
||||
fileID = strings.TrimSpace(fileID)
|
||||
if fileID == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[fileID]; ok {
|
||||
return
|
||||
}
|
||||
seen[fileID] = struct{}{}
|
||||
*out = append(*out, fileID)
|
||||
}
|
||||
202
internal/adapter/openai/files_route_test.go
Normal file
202
internal/adapter/openai/files_route_test.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
type managedFilesAuthStub struct{}
|
||||
|
||||
func (managedFilesAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
|
||||
return &auth.RequestAuth{
|
||||
UseConfigToken: true,
|
||||
DeepSeekToken: "managed-token",
|
||||
CallerID: "caller:test",
|
||||
AccountID: "acct-123",
|
||||
TriedAccounts: map[string]bool{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (managedFilesAuthStub) DetermineCaller(_ *http.Request) (*auth.RequestAuth, error) {
|
||||
return &auth.RequestAuth{
|
||||
UseConfigToken: true,
|
||||
DeepSeekToken: "managed-token",
|
||||
CallerID: "caller:test",
|
||||
AccountID: "acct-123",
|
||||
TriedAccounts: map[string]bool{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (managedFilesAuthStub) Release(_ *auth.RequestAuth) {}
|
||||
|
||||
type filesRouteDSStub struct {
|
||||
lastReq deepseek.UploadFileRequest
|
||||
upload *deepseek.UploadFileResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *filesRouteDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *filesRouteDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *filesRouteDSStub) UploadFile(_ context.Context, _ *auth.RequestAuth, req deepseek.UploadFileRequest, _ int) (*deepseek.UploadFileResult, error) {
|
||||
m.lastReq = req
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
if m.upload != nil {
|
||||
return m.upload, nil
|
||||
}
|
||||
return &deepseek.UploadFileResult{ID: "file-123", Filename: req.Filename, Bytes: int64(len(req.Data)), Purpose: req.Purpose, Status: "uploaded"}, nil
|
||||
}
|
||||
|
||||
func (m *filesRouteDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *filesRouteDSStub) DeleteSessionForToken(_ context.Context, _ string, _ string) (*deepseek.DeleteSessionResult, error) {
|
||||
return &deepseek.DeleteSessionResult{Success: true}, nil
|
||||
}
|
||||
|
||||
func (m *filesRouteDSStub) DeleteAllSessionsForToken(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newMultipartUploadRequest(t *testing.T, purpose string, filename string, data []byte) *http.Request {
|
||||
t.Helper()
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
if purpose != "" {
|
||||
if err := writer.WriteField("purpose", purpose); err != nil {
|
||||
t.Fatalf("write purpose failed: %v", err)
|
||||
}
|
||||
}
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
t.Fatalf("create form file failed: %v", err)
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
t.Fatalf("write file failed: %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("close writer failed: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/files", &body)
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req
|
||||
}
|
||||
|
||||
func TestFilesRouteUploadSuccess(t *testing.T) {
|
||||
ds := &filesRouteDSStub{}
|
||||
h := &Handler{Store: mockOpenAIConfig{wideInput: true}, Auth: streamStatusAuthStub{}, DS: ds}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
req := newMultipartUploadRequest(t, "assistants", "notes.txt", []byte("hello world"))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if ds.lastReq.Filename != "notes.txt" {
|
||||
t.Fatalf("expected filename notes.txt, got %q", ds.lastReq.Filename)
|
||||
}
|
||||
if ds.lastReq.Purpose != "assistants" {
|
||||
t.Fatalf("expected purpose assistants, got %q", ds.lastReq.Purpose)
|
||||
}
|
||||
if string(ds.lastReq.Data) != "hello world" {
|
||||
t.Fatalf("unexpected uploaded data: %q", string(ds.lastReq.Data))
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
|
||||
}
|
||||
if out["object"] != "file" {
|
||||
t.Fatalf("expected file object, got %#v", out)
|
||||
}
|
||||
if out["id"] != "file-123" {
|
||||
t.Fatalf("expected file id file-123, got %#v", out["id"])
|
||||
}
|
||||
if out["filename"] != "notes.txt" {
|
||||
t.Fatalf("expected filename notes.txt, got %#v", out["filename"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesRouteUploadIncludesAccountIDForManagedAccount(t *testing.T) {
|
||||
ds := &filesRouteDSStub{}
|
||||
h := &Handler{Store: mockOpenAIConfig{wideInput: true}, Auth: managedFilesAuthStub{}, DS: ds}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
req := newMultipartUploadRequest(t, "assistants", "notes.txt", []byte("hello world"))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
|
||||
}
|
||||
if out["account_id"] != "acct-123" {
|
||||
t.Fatalf("expected account_id acct-123, got %#v", out["account_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesRouteRejectsNonMultipart(t *testing.T) {
|
||||
h := &Handler{Store: mockOpenAIConfig{wideInput: true}, Auth: streamStatusAuthStub{}, DS: &filesRouteDSStub{}}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/files", bytes.NewBufferString(`{"purpose":"assistants"}`))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesRouteRequiresFileField(t *testing.T) {
|
||||
h := &Handler{Store: mockOpenAIConfig{wideInput: true}, Auth: streamStatusAuthStub{}, DS: &filesRouteDSStub{}}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
if err := writer.WriteField("purpose", "assistants"); err != nil {
|
||||
t.Fatalf("write field failed: %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("close writer failed: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/files", &body)
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
@@ -43,42 +44,69 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, openAIGeneralMaxSize)
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "too large") {
|
||||
writeOpenAIError(w, http.StatusRequestEntityTooLarge, "request body too large")
|
||||
return
|
||||
}
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if err := h.preprocessInlineFileInputs(r.Context(), a, req); err != nil {
|
||||
writeOpenAIInlineFileError(w, err)
|
||||
return
|
||||
}
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
stdReq, err = h.applyHistorySplit(r.Context(), a, stdReq)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
historySession := startChatHistory(h.ChatHistory, r, a, stdReq)
|
||||
|
||||
sessionID, err = h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
}
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusInternalServerError, "Failed to get completion.", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
if stdReq.Stream {
|
||||
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
|
||||
return
|
||||
}
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
h.handleNonStream(w, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
|
||||
}
|
||||
|
||||
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
|
||||
@@ -114,30 +142,52 @@ func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAu
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if historySession != nil {
|
||||
historySession.error(resp.StatusCode, string(body), "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
_ = ctx
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
|
||||
stripReferenceMarkers := h.compatStripReferenceMarkers()
|
||||
finalThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers)
|
||||
finalText := cleanVisibleOutput(result.Text, stripReferenceMarkers)
|
||||
if writeUpstreamEmptyOutputError(w, finalThinking, finalText, result.ContentFilter) {
|
||||
if searchEnabled {
|
||||
finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks)
|
||||
}
|
||||
if shouldWriteUpstreamEmptyOutputError(finalText) {
|
||||
status, message, code := upstreamEmptyOutputDetail(result.ContentFilter, finalText, finalThinking)
|
||||
if historySession != nil {
|
||||
historySession.error(status, message, code, finalThinking, finalText)
|
||||
}
|
||||
writeUpstreamEmptyOutputError(w, finalText, result.ContentFilter)
|
||||
return
|
||||
}
|
||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
if choices, ok := respBody["choices"].([]map[string]any); ok && len(choices) > 0 {
|
||||
if fr, _ := choices[0]["finish_reason"].(string); strings.TrimSpace(fr) != "" {
|
||||
finishReason = fr
|
||||
}
|
||||
}
|
||||
if historySession != nil {
|
||||
historySession.success(http.StatusOK, finalThinking, finalText, finishReason, openaifmt.BuildChatUsage(finalPrompt, finalThinking, finalText))
|
||||
}
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if historySession != nil {
|
||||
historySession.error(resp.StatusCode, string(body), "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
@@ -188,13 +238,32 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
OnKeepAlive: func() {
|
||||
streamRuntime.sendKeepAlive()
|
||||
},
|
||||
OnParsed: streamRuntime.onParsed,
|
||||
OnParsed: func(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
decision := streamRuntime.onParsed(parsed)
|
||||
if historySession != nil {
|
||||
historySession.progress(streamRuntime.thinking.String(), streamRuntime.text.String())
|
||||
}
|
||||
return decision
|
||||
},
|
||||
OnFinalize: func(reason streamengine.StopReason, _ error) {
|
||||
if string(reason) == "content_filter" {
|
||||
streamRuntime.finalize("content_filter")
|
||||
} else {
|
||||
streamRuntime.finalize("stop")
|
||||
}
|
||||
if historySession == nil {
|
||||
return
|
||||
}
|
||||
streamRuntime.finalize("stop")
|
||||
if streamRuntime.finalErrorMessage != "" {
|
||||
historySession.error(streamRuntime.finalErrorStatus, streamRuntime.finalErrorMessage, streamRuntime.finalErrorCode, streamRuntime.thinking.String(), streamRuntime.text.String())
|
||||
return
|
||||
}
|
||||
historySession.success(http.StatusOK, streamRuntime.finalThinking, streamRuntime.finalText, streamRuntime.finalFinishReason, streamRuntime.finalUsage)
|
||||
},
|
||||
OnContextDone: func() {
|
||||
if historySession != nil {
|
||||
historySession.stopped(streamRuntime.thinking.String(), streamRuntime.text.String(), string(streamengine.StopReasonContextCancelled))
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -27,6 +27,10 @@ func (m *autoDeleteModeDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _
|
||||
return "pow", nil
|
||||
}
|
||||
|
||||
func (m *autoDeleteModeDSStub) UploadFile(_ context.Context, _ *auth.RequestAuth, _ deepseek.UploadFileRequest, _ int) (*deepseek.UploadFileResult, error) {
|
||||
return &deepseek.UploadFileResult{ID: "file-id", Filename: "file.txt", Bytes: 1, Status: "uploaded"}, nil
|
||||
}
|
||||
|
||||
func (m *autoDeleteModeDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
|
||||
return m.resp, nil
|
||||
}
|
||||
|
||||
104
internal/adapter/openai/handler_files.go
Normal file
104
internal/adapter/openai/handler_files.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
const openAIUploadMaxMemory = 32 << 20
|
||||
|
||||
func (h *Handler) UploadFile(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeOpenAIError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
if !strings.HasPrefix(strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))), "multipart/form-data") {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "content-type must be multipart/form-data")
|
||||
return
|
||||
}
|
||||
// Enforce a hard cap on the total request body size to prevent OOM
|
||||
r.Body = http.MaxBytesReader(w, r.Body, openAIUploadMaxSize)
|
||||
if err := r.ParseMultipartForm(openAIUploadMaxMemory); err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "too large") {
|
||||
writeOpenAIError(w, http.StatusRequestEntityTooLarge, "file size exceeds limit")
|
||||
return
|
||||
}
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid multipart form")
|
||||
return
|
||||
}
|
||||
if r.MultipartForm != nil {
|
||||
defer func() { _ = r.MultipartForm.RemoveAll() }()
|
||||
}
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "file is required")
|
||||
return
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "failed to read uploaded file")
|
||||
return
|
||||
}
|
||||
contentType := strings.TrimSpace(header.Header.Get("Content-Type"))
|
||||
if contentType == "" && len(data) > 0 {
|
||||
contentType = http.DetectContentType(data)
|
||||
}
|
||||
result, err := h.DS.UploadFile(r.Context(), a, deepseek.UploadFileRequest{
|
||||
Filename: header.Filename,
|
||||
ContentType: contentType,
|
||||
Purpose: strings.TrimSpace(r.FormValue("purpose")),
|
||||
Data: data,
|
||||
}, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to upload file.")
|
||||
return
|
||||
}
|
||||
if result != nil && result.AccountID == "" {
|
||||
result.AccountID = a.AccountID
|
||||
}
|
||||
writeJSON(w, http.StatusOK, buildOpenAIFileObject(result))
|
||||
}
|
||||
|
||||
func buildOpenAIFileObject(result *deepseek.UploadFileResult) map[string]any {
|
||||
if result == nil {
|
||||
obj := map[string]any{
|
||||
"id": "",
|
||||
"object": "file",
|
||||
"bytes": 0,
|
||||
"created_at": time.Now().Unix(),
|
||||
"filename": "",
|
||||
"purpose": "",
|
||||
"status": "uploaded",
|
||||
"status_details": nil,
|
||||
}
|
||||
return obj
|
||||
}
|
||||
obj := map[string]any{
|
||||
"id": result.ID,
|
||||
"object": "file",
|
||||
"bytes": result.Bytes,
|
||||
"created_at": time.Now().Unix(),
|
||||
"filename": result.Filename,
|
||||
"purpose": result.Purpose,
|
||||
"status": result.Status,
|
||||
"status_details": nil,
|
||||
}
|
||||
if result.AccountID != "" {
|
||||
obj["account_id"] = result.AccountID
|
||||
}
|
||||
return obj
|
||||
}
|
||||
@@ -9,18 +9,27 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// openAIUploadMaxSize limits total multipart request body size (100 MiB).
|
||||
openAIUploadMaxSize = 100 << 20
|
||||
// openAIGeneralMaxSize limits total JSON request body size (100 MiB).
|
||||
openAIGeneralMaxSize = 100 << 20
|
||||
)
|
||||
|
||||
// writeJSON is a package-internal alias kept to avoid mass-renaming across
|
||||
// every call-site in this package.
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
ChatHistory *chathistory.Store
|
||||
|
||||
leaseMu sync.Mutex
|
||||
streamLeases map[string]streamLease
|
||||
@@ -46,6 +55,7 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Post("/v1/chat/completions", h.ChatCompletions)
|
||||
r.Post("/v1/responses", h.Responses)
|
||||
r.Get("/v1/responses/{response_id}", h.GetResponseByID)
|
||||
r.Post("/v1/files", h.UploadFile)
|
||||
r.Post("/v1/embeddings", h.Embeddings)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -59,21 +57,6 @@ func parseSSEDataFrames(t *testing.T, body string) ([]map[string]any, bool) {
|
||||
return frames, done
|
||||
}
|
||||
|
||||
func streamHasRawToolJSONContent(frames []map[string]any) bool {
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
content, _ := delta["content"].(string)
|
||||
if strings.Contains(content, `"tool_calls"`) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func streamHasToolCallsDelta(frames []map[string]any) bool {
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
@@ -101,180 +84,7 @@ func streamFinishReason(frames []map[string]any) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func streamToolCallArgumentChunks(frames []map[string]any) []string {
|
||||
out := make([]string, 0, 4)
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
fn, _ := tcm["function"].(map[string]any)
|
||||
if args, ok := fn["arguments"].(string); ok && args != "" {
|
||||
out = append(out, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
if len(choices) != 1 {
|
||||
t.Fatalf("unexpected choices: %#v", out["choices"])
|
||||
}
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if msg["content"] != nil {
|
||||
t.Fatalf("expected content nil, got %#v", msg["content"])
|
||||
}
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %#v", msg["tool_calls"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/thinking_content","v":"先想一下"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if msg["reasoning_content"] != "先想一下" {
|
||||
t.Fatalf("expected reasoning_content, got %#v", msg["reasoning_content"])
|
||||
}
|
||||
if msg["content"] != nil {
|
||||
t.Fatalf("expected content nil, got %#v", msg["content"])
|
||||
}
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamUnknownToolIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected tool_calls for unknown schema name, got %#v", msg["tool_calls"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamEmbeddedToolCallExamplePromotesToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"请勿执行。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected one tool_call field for embedded example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected raw tool_calls json stripped from content, got %#v", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}",
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] == "tool_calls" {
|
||||
t.Fatalf("expected fenced example to remain content-only, got finish_reason=%#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 0 {
|
||||
t.Fatalf("expected no tool_call field for fenced example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected fenced example content preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
// Backward-compatible alias for historical test name used in CI logs.
|
||||
func TestHandleNonStreamFencedToolCallExamplePromotesToolCall(t *testing.T) {
|
||||
TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t)
|
||||
}
|
||||
|
||||
func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -283,7 +93,7 @@ func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid-empty", "deepseek-chat", "prompt", false, nil)
|
||||
h.handleNonStream(rec, resp, "cid-empty", "deepseek-chat", "prompt", false, false, nil, nil)
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected status 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -302,7 +112,7 @@ func TestHandleNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutp
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid-empty-filtered", "deepseek-chat", "prompt", false, nil)
|
||||
h.handleNonStream(rec, resp, "cid-empty-filtered", "deepseek-chat", "prompt", false, false, nil, nil)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400 for filtered upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -313,190 +123,22 @@ func TestHandleNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutp
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
func TestHandleNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
|
||||
`data: {"p":"response/content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/thinking_content","v":"Only thinking"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid3", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
h.handleNonStream(rec, resp, "cid-thinking-only", "deepseek-reasoner", "prompt", true, false, nil, nil)
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected status 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
foundToolIndex := false
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
if _, ok := tcm["index"].(float64); ok {
|
||||
foundToolIndex = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundToolIndex {
|
||||
t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallLargeArgumentsStillIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
large := strings.Repeat("a", 9000)
|
||||
payload := fmt.Sprintf(`{"tool_calls":[{"name":"search","input":{"q":"%s"}}]}`, large)
|
||||
splitAt := len(payload) / 2
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[:splitAt]),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[splitAt:]),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid3-large", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/thinking_content","v":"思考中"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid4", "deepseek-reasoner", "prompt", true, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
foundToolIndex := false
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
if _, ok := tcm["index"].(float64); ok {
|
||||
foundToolIndex = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundToolIndex {
|
||||
t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
hasThinkingDelta := false
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if _, ok := delta["reasoning_content"]; ok {
|
||||
hasThinkingDelta = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasThinkingDelta {
|
||||
t.Fatalf("expected reasoning_content delta in reasoner stream: %s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamUnknownToolEmitsToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid5", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for unknown schema name, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name: %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamUnknownToolNoArgsEmitsToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\"}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid5b", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name (no args): %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if asString(errObj["code"]) != "upstream_empty_output" {
|
||||
t.Fatalf("expected code=upstream_empty_output, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,7 +152,7 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid6", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
h.handleStream(rec, req, resp, "cid6", "deepseek-chat", "prompt", false, false, []string{"search"}, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -538,287 +180,6 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"请勿执行。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"我将调用工具。"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7b", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "我将调用工具。") {
|
||||
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
||||
}
|
||||
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7c", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "接下来我会继续说明。") {
|
||||
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
|
||||
}
|
||||
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamFencedToolCallSnippetPromotesToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "下面是调用示例:\n```json\n"),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\n仅示例,不要执行。"),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7f", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for fenced snippet, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected raw fenced tool_calls snippet stripped from content, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), "```json") || strings.Contains(got, "\n```\n") {
|
||||
t.Fatalf("expected consumed fenced tool payload to not leave empty code fence, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamStandaloneToolCallAfterClosedFenceKeepsFence(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "先给一个代码示例:\n```text\nhello\n```\n"),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7g", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for standalone payload, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "```") {
|
||||
t.Fatalf("expected closed fence before standalone tool json to be preserved, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
spaces := strings.Repeat(" ", 200)
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{`+spaces+`"}`,
|
||||
`data: {"p":"response/content","v":"\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"后置正文C。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid8", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "后置正文C。") {
|
||||
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"前置正文D。"}`,
|
||||
`data: {"p":"response/content","v":"{'tool_calls':[{'name':'search','input':{'q':'go'}}]}"}`,
|
||||
`data: {"p":"response/content","v":"后置正文E。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid9", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta for invalid json, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") {
|
||||
t.Fatalf("expected pre/post plain text to remain, got=%q", content.String())
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -828,7 +189,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid10", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
h.handleStream(rec, req, resp, "cid10", "deepseek-chat", "prompt", false, false, []string{"search"}, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -853,107 +214,50 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallArgumentsEmitAsSingleCompletedChunk(t *testing.T) {
|
||||
func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
|
||||
`data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`,
|
||||
`data: {"p":"response/content","v":"前置文本\n<tool_calls>\n <tool_call>\n <tool_name>read_file</tool_name>\n <parameters>{\"path\":\"README.MD\"}</parameters>\n </tool_call>\n</tool_calls>"}`,
|
||||
`data: {"p":"response/content","v":"中间文本\n<tool_calls>\n <tool_call>\n <tool_name>search</tool_name>\n <parameters>{\"q\":\"golang\"}</parameters>\n </tool_call>\n</tool_calls>"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
h.handleStream(rec, req, resp, "cid-multi", "deepseek-chat", "prompt", false, false, []string{"read_file", "search"}, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
argChunks := streamToolCallArgumentChunks(frames)
|
||||
if len(argChunks) == 0 {
|
||||
t.Fatalf("expected tool call arguments chunk, got=%v body=%s", argChunks, rec.Body.String())
|
||||
}
|
||||
joined := strings.Join(argChunks, "")
|
||||
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
|
||||
t.Fatalf("unexpected merged arguments stream: %q", joined)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`,
|
||||
`data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
foundSearch := false
|
||||
foundEval := false
|
||||
foundIndex1 := false
|
||||
toolCallsDeltaLens := make([]int, 0, 2)
|
||||
ids := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{})
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
rawToolCalls, hasToolCalls := delta["tool_calls"]
|
||||
if !hasToolCalls {
|
||||
continue
|
||||
}
|
||||
toolCalls, _ := rawToolCalls.([]any)
|
||||
toolCallsDeltaLens = append(toolCallsDeltaLens, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 {
|
||||
foundIndex1 = true
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, rawCall := range toolCalls {
|
||||
call, _ := rawCall.(map[string]any)
|
||||
id := asString(call["id"])
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
fn, _ := tcm["function"].(map[string]any)
|
||||
name, _ := fn["name"].(string)
|
||||
switch name {
|
||||
case "search_web":
|
||||
foundSearch = true
|
||||
case "eval_javascript":
|
||||
foundEval = true
|
||||
case "search_webeval_javascript":
|
||||
t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String())
|
||||
}
|
||||
if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String())
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundSearch || !foundEval {
|
||||
t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String())
|
||||
|
||||
if len(ids) != 2 {
|
||||
t.Fatalf("expected two distinct tool call ids, got %#v body=%s", ids, rec.Body.String())
|
||||
}
|
||||
if len(toolCallsDeltaLens) != 1 || toolCallsDeltaLens[0] != 2 {
|
||||
t.Fatalf("expected exactly one tool_calls delta with two calls, got lens=%v body=%s", toolCallsDeltaLens, rec.Body.String())
|
||||
}
|
||||
if !foundIndex1 {
|
||||
t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
if ids[0] == ids[1] {
|
||||
t.Fatalf("expected distinct tool call ids across blocks, got %#v body=%s", ids, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
289
internal/adapter/openai/history_split.go
Normal file
289
internal/adapter/openai/history_split.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
historySplitFilename = "HISTORY.txt"
|
||||
historySplitContentType = "text/plain; charset=utf-8"
|
||||
historySplitPurpose = "assistants"
|
||||
)
|
||||
|
||||
func (h *Handler) applyHistorySplit(ctx context.Context, a *auth.RequestAuth, stdReq util.StandardRequest) (util.StandardRequest, error) {
|
||||
if h == nil || h.DS == nil || h.Store == nil || a == nil {
|
||||
return stdReq, nil
|
||||
}
|
||||
if !h.Store.HistorySplitEnabled() {
|
||||
return stdReq, nil
|
||||
}
|
||||
|
||||
promptMessages, historyMessages := splitOpenAIHistoryMessages(stdReq.Messages, h.Store.HistorySplitTriggerAfterTurns())
|
||||
if len(historyMessages) == 0 {
|
||||
return stdReq, nil
|
||||
}
|
||||
|
||||
reasoningContent := extractHistorySplitReasoningContent(historyMessages)
|
||||
historyText := buildOpenAIHistoryTranscript(historyMessages)
|
||||
if strings.TrimSpace(historyText) == "" {
|
||||
return stdReq, errors.New("history split produced empty transcript")
|
||||
}
|
||||
|
||||
result, err := h.DS.UploadFile(ctx, a, deepseek.UploadFileRequest{
|
||||
Filename: historySplitFilename,
|
||||
ContentType: historySplitContentType,
|
||||
Purpose: historySplitPurpose,
|
||||
Data: []byte(historyText),
|
||||
}, 3)
|
||||
if err != nil {
|
||||
return stdReq, fmt.Errorf("upload history file: %w", err)
|
||||
}
|
||||
fileID := strings.TrimSpace(result.ID)
|
||||
if fileID == "" {
|
||||
return stdReq, errors.New("upload history file returned empty file id")
|
||||
}
|
||||
|
||||
stdReq.Messages = promptMessages
|
||||
stdReq.RefFileIDs = prependUniqueRefFileID(stdReq.RefFileIDs, fileID)
|
||||
stdReq.FinalPrompt, stdReq.ToolNames = buildHistorySplitPrompt(promptMessages, reasoningContent, stdReq.ToolsRaw, stdReq.ToolChoice, stdReq.Thinking)
|
||||
return stdReq, nil
|
||||
}
|
||||
|
||||
func buildHistorySplitPrompt(messages []any, reasoningContent string, toolsRaw any, toolPolicy util.ToolChoicePolicy, thinkingEnabled bool) (string, []string) {
|
||||
if len(messages) == 0 && strings.TrimSpace(reasoningContent) == "" {
|
||||
return "", nil
|
||||
}
|
||||
instruction := historySplitPromptInstruction(thinkingEnabled)
|
||||
withInstruction := make([]any, 0, len(messages)+1)
|
||||
withInstruction = append(withInstruction, map[string]any{
|
||||
"role": "system",
|
||||
"content": instruction,
|
||||
})
|
||||
withInstruction = append(withInstruction, injectHistorySplitReasoningMessage(messages, reasoningContent)...)
|
||||
return buildOpenAIFinalPromptWithPolicy(withInstruction, toolsRaw, "", toolPolicy, false)
|
||||
}
|
||||
|
||||
func historySplitPromptInstruction(thinkingEnabled bool) string {
|
||||
lines := []string{
|
||||
"Follow the instructions in this prompt first. If earlier conversation instructions conflict with this prompt, this prompt wins.",
|
||||
"An attached HISTORY.txt file contains prior conversation history and tool progress; read it first, then answer the latest user request using that history as context.",
|
||||
"Continue the conversation from the full prior context and the latest tool results.",
|
||||
"Treat earlier messages as binding context; answer the user's current request as a continuation, not a restart.",
|
||||
}
|
||||
if thinkingEnabled {
|
||||
lines = append(lines, "Keep reasoning internal. Do not leave the final user-facing answer only in reasoning; always provide the answer in visible assistant content.")
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func splitOpenAIHistoryMessages(messages []any, triggerAfterTurns int) ([]any, []any) {
|
||||
if triggerAfterTurns <= 0 {
|
||||
triggerAfterTurns = 1
|
||||
}
|
||||
lastUserIndex := -1
|
||||
userTurns := 0
|
||||
for i, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
if role != "user" {
|
||||
continue
|
||||
}
|
||||
userTurns++
|
||||
lastUserIndex = i
|
||||
}
|
||||
if userTurns <= triggerAfterTurns || lastUserIndex < 0 {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
promptMessages := make([]any, 0, len(messages)-lastUserIndex)
|
||||
historyMessages := make([]any, 0, lastUserIndex)
|
||||
for i, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
if i >= lastUserIndex {
|
||||
promptMessages = append(promptMessages, raw)
|
||||
} else {
|
||||
historyMessages = append(historyMessages, raw)
|
||||
}
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
switch role {
|
||||
case "system", "developer":
|
||||
promptMessages = append(promptMessages, raw)
|
||||
default:
|
||||
if i >= lastUserIndex {
|
||||
promptMessages = append(promptMessages, raw)
|
||||
} else {
|
||||
historyMessages = append(historyMessages, raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(promptMessages) == 0 {
|
||||
return messages, nil
|
||||
}
|
||||
return promptMessages, historyMessages
|
||||
}
|
||||
|
||||
func buildOpenAIHistoryTranscript(messages []any) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("# HISTORY.txt\n")
|
||||
b.WriteString("Prior conversation history and tool progress.\n\n")
|
||||
|
||||
entry := 0
|
||||
for _, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
content := buildOpenAIHistoryEntry(role, msg)
|
||||
if strings.TrimSpace(content) == "" {
|
||||
continue
|
||||
}
|
||||
entry++
|
||||
fmt.Fprintf(&b, "=== %d. %s ===\n%s\n\n", entry, strings.ToUpper(roleLabelForHistory(role)), content)
|
||||
}
|
||||
return strings.TrimSpace(b.String()) + "\n"
|
||||
}
|
||||
|
||||
func buildOpenAIHistoryEntry(role string, msg map[string]any) string {
|
||||
switch role {
|
||||
case "assistant":
|
||||
return strings.TrimSpace(buildAssistantHistoryContent(msg))
|
||||
case "tool", "function":
|
||||
return strings.TrimSpace(buildToolHistoryContent(msg))
|
||||
case "user":
|
||||
return strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"]))
|
||||
default:
|
||||
return strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"]))
|
||||
}
|
||||
}
|
||||
|
||||
func buildAssistantHistoryContent(msg map[string]any) string {
|
||||
return strings.TrimSpace(buildAssistantContentForPrompt(msg))
|
||||
}
|
||||
|
||||
func buildToolHistoryContent(msg map[string]any) string {
|
||||
content := strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"]))
|
||||
parts := make([]string, 0, 2)
|
||||
if name := strings.TrimSpace(asString(msg["name"])); name != "" {
|
||||
parts = append(parts, "name="+name)
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(msg["tool_call_id"])); callID != "" {
|
||||
parts = append(parts, "tool_call_id="+callID)
|
||||
}
|
||||
header := ""
|
||||
if len(parts) > 0 {
|
||||
header = "[" + strings.Join(parts, " ") + "]"
|
||||
}
|
||||
switch {
|
||||
case header != "" && content != "":
|
||||
return header + "\n" + content
|
||||
case header != "":
|
||||
return header
|
||||
default:
|
||||
return content
|
||||
}
|
||||
}
|
||||
|
||||
func extractHistorySplitReasoningContent(messages []any) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
msg, ok := messages[i].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
if role != "assistant" {
|
||||
continue
|
||||
}
|
||||
reasoning := strings.TrimSpace(normalizeOpenAIReasoningContentForPrompt(msg["reasoning_content"]))
|
||||
if reasoning == "" {
|
||||
reasoning = strings.TrimSpace(extractOpenAIReasoningContentFromMessage(msg["content"]))
|
||||
}
|
||||
if reasoning != "" {
|
||||
return reasoning
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func injectHistorySplitReasoningMessage(messages []any, reasoningContent string) []any {
|
||||
reasoningContent = strings.TrimSpace(reasoningContent)
|
||||
if reasoningContent == "" {
|
||||
return messages
|
||||
}
|
||||
reasoningMsg := map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": reasoningContent,
|
||||
}
|
||||
lastUserIndex := lastOpenAIUserMessageIndex(messages)
|
||||
if lastUserIndex < 0 {
|
||||
out := make([]any, 0, len(messages)+1)
|
||||
out = append(out, reasoningMsg)
|
||||
out = append(out, messages...)
|
||||
return out
|
||||
}
|
||||
out := make([]any, 0, len(messages)+1)
|
||||
for i, raw := range messages {
|
||||
if i == lastUserIndex {
|
||||
out = append(out, reasoningMsg)
|
||||
}
|
||||
out = append(out, raw)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func lastOpenAIUserMessageIndex(messages []any) int {
|
||||
last := -1
|
||||
for i, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(asString(msg["role"]))) == "user" {
|
||||
last = i
|
||||
}
|
||||
}
|
||||
return last
|
||||
}
|
||||
|
||||
func roleLabelForHistory(role string) string {
|
||||
role = strings.ToLower(strings.TrimSpace(role))
|
||||
switch role {
|
||||
case "function":
|
||||
return "tool"
|
||||
case "":
|
||||
return "unknown"
|
||||
default:
|
||||
return role
|
||||
}
|
||||
}
|
||||
|
||||
func prependUniqueRefFileID(existing []string, fileID string) []string {
|
||||
fileID = strings.TrimSpace(fileID)
|
||||
if fileID == "" {
|
||||
return existing
|
||||
}
|
||||
out := make([]string, 0, len(existing)+1)
|
||||
out = append(out, fileID)
|
||||
for _, id := range existing {
|
||||
trimmed := strings.TrimSpace(id)
|
||||
if trimmed == "" || strings.EqualFold(trimmed, fileID) {
|
||||
continue
|
||||
}
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
return out
|
||||
}
|
||||
353
internal/adapter/openai/history_split_test.go
Normal file
353
internal/adapter/openai/history_split_test.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func historySplitTestMessages() []any {
|
||||
toolCalls := []any{
|
||||
map[string]any{
|
||||
"name": "search",
|
||||
"arguments": map[string]any{"query": "docs"},
|
||||
},
|
||||
}
|
||||
return []any{
|
||||
map[string]any{"role": "system", "content": "system instructions"},
|
||||
map[string]any{"role": "user", "content": "first user turn"},
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "hidden reasoning",
|
||||
"tool_calls": toolCalls,
|
||||
},
|
||||
map[string]any{
|
||||
"role": "tool",
|
||||
"name": "search",
|
||||
"tool_call_id": "call-1",
|
||||
"content": "tool result",
|
||||
},
|
||||
map[string]any{"role": "user", "content": "latest user turn"},
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIHistoryTranscriptPreservesOrderAndToolHistory(t *testing.T) {
|
||||
promptMessages, historyMessages := splitOpenAIHistoryMessages(historySplitTestMessages(), 1)
|
||||
if len(promptMessages) != 2 {
|
||||
t.Fatalf("expected 2 prompt messages, got %d", len(promptMessages))
|
||||
}
|
||||
if len(historyMessages) != 3 {
|
||||
t.Fatalf("expected 3 history messages, got %d", len(historyMessages))
|
||||
}
|
||||
|
||||
transcript := buildOpenAIHistoryTranscript(historyMessages)
|
||||
if !strings.Contains(transcript, "first user turn") {
|
||||
t.Fatalf("expected user history in transcript, got %s", transcript)
|
||||
}
|
||||
if !strings.Contains(transcript, "<tool_calls>") {
|
||||
t.Fatalf("expected assistant tool_calls in transcript, got %s", transcript)
|
||||
}
|
||||
if !strings.Contains(transcript, "tool_call_id=call-1") {
|
||||
t.Fatalf("expected tool call id in transcript, got %s", transcript)
|
||||
}
|
||||
if !strings.Contains(transcript, "[reasoning_content]") {
|
||||
t.Fatalf("expected reasoning block in HISTORY.txt, got %s", transcript)
|
||||
}
|
||||
if !strings.Contains(transcript, "hidden reasoning") {
|
||||
t.Fatalf("expected reasoning text in HISTORY.txt, got %s", transcript)
|
||||
}
|
||||
|
||||
userIdx := strings.Index(transcript, "=== 1. USER ===")
|
||||
assistantIdx := strings.Index(transcript, "=== 2. ASSISTANT ===")
|
||||
toolIdx := strings.Index(transcript, "=== 3. TOOL ===")
|
||||
if userIdx < 0 || assistantIdx < 0 || toolIdx < 0 {
|
||||
t.Fatalf("expected ordered role sections, got %s", transcript)
|
||||
}
|
||||
if userIdx >= assistantIdx || assistantIdx >= toolIdx {
|
||||
t.Fatalf("expected USER -> ASSISTANT -> TOOL order, got %s", transcript)
|
||||
}
|
||||
if reasoningIdx := strings.Index(transcript, "[reasoning_content]"); reasoningIdx < 0 || reasoningIdx > strings.Index(transcript, "<tool_calls>") {
|
||||
t.Fatalf("expected reasoning block before tool calls, got %s", transcript)
|
||||
}
|
||||
reasoning := extractHistorySplitReasoningContent(historyMessages)
|
||||
if reasoning != "hidden reasoning" {
|
||||
t.Fatalf("expected latest assistant reasoning to be extracted, got %q", reasoning)
|
||||
}
|
||||
|
||||
finalPrompt, _ := buildHistorySplitPrompt(promptMessages, reasoning, nil, util.DefaultToolChoicePolicy(), false)
|
||||
if !strings.Contains(finalPrompt, "latest user turn") {
|
||||
t.Fatalf("expected latest user turn in final prompt, got %s", finalPrompt)
|
||||
}
|
||||
if strings.Contains(finalPrompt, "first user turn") {
|
||||
t.Fatalf("expected earlier history to be removed from final prompt, got %s", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "[reasoning_content]") || !strings.Contains(finalPrompt, "hidden reasoning") {
|
||||
t.Fatalf("expected latest assistant reasoning to be attached to prompt, got %s", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "HISTORY.txt") {
|
||||
t.Fatalf("expected history instruction in final prompt, got %s", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "Follow the instructions in this prompt first") {
|
||||
t.Fatalf("expected stronger prompt override in final prompt, got %s", finalPrompt)
|
||||
}
|
||||
if strings.Index(finalPrompt, "Follow the instructions in this prompt first") > strings.Index(finalPrompt, "Continue the conversation") {
|
||||
t.Fatalf("expected history split instruction before continuity instructions, got %s", finalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitOpenAIHistoryMessagesUsesLatestUserTurn(t *testing.T) {
|
||||
toolCalls := []any{
|
||||
map[string]any{
|
||||
"name": "search",
|
||||
"arguments": map[string]any{"query": "docs"},
|
||||
},
|
||||
}
|
||||
messages := []any{
|
||||
map[string]any{"role": "system", "content": "system instructions"},
|
||||
map[string]any{"role": "user", "content": "first user turn"},
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": toolCalls,
|
||||
},
|
||||
map[string]any{
|
||||
"role": "tool",
|
||||
"name": "search",
|
||||
"tool_call_id": "call-1",
|
||||
"content": "tool result",
|
||||
},
|
||||
map[string]any{"role": "user", "content": "middle user turn"},
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "middle assistant turn",
|
||||
},
|
||||
map[string]any{"role": "user", "content": "latest user turn"},
|
||||
}
|
||||
|
||||
promptMessages, historyMessages := splitOpenAIHistoryMessages(messages, 1)
|
||||
if len(promptMessages) == 0 || len(historyMessages) == 0 {
|
||||
t.Fatalf("expected both prompt and history messages, got prompt=%d history=%d", len(promptMessages), len(historyMessages))
|
||||
}
|
||||
reasoning := extractHistorySplitReasoningContent(historyMessages)
|
||||
if reasoning != "" {
|
||||
t.Fatalf("expected no reasoning in this fixture, got %q", reasoning)
|
||||
}
|
||||
|
||||
promptText, _ := buildHistorySplitPrompt(promptMessages, reasoning, nil, util.DefaultToolChoicePolicy(), false)
|
||||
if !strings.Contains(promptText, "latest user turn") {
|
||||
t.Fatalf("expected latest user turn in prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Contains(promptText, "middle user turn") {
|
||||
t.Fatalf("expected middle user turn to be split into history, got %s", promptText)
|
||||
}
|
||||
|
||||
historyText := buildOpenAIHistoryTranscript(historyMessages)
|
||||
if !strings.Contains(historyText, "middle user turn") {
|
||||
t.Fatalf("expected middle user turn in HISTORY.txt, got %s", historyText)
|
||||
}
|
||||
if strings.Contains(historyText, "latest user turn") {
|
||||
t.Fatalf("expected latest user turn to remain in prompt, got %s", historyText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHistorySplitSkipsFirstTurn(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{
|
||||
wideInput: true,
|
||||
historySplitEnabled: true,
|
||||
historySplitTurns: 1,
|
||||
},
|
||||
DS: ds,
|
||||
}
|
||||
req := map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
|
||||
out, err := h.applyHistorySplit(context.Background(), &auth.RequestAuth{DeepSeekToken: "token"}, stdReq)
|
||||
if err != nil {
|
||||
t.Fatalf("apply history split failed: %v", err)
|
||||
}
|
||||
if len(ds.uploadCalls) != 0 {
|
||||
t.Fatalf("expected no upload on first turn, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
if out.FinalPrompt != stdReq.FinalPrompt {
|
||||
t.Fatalf("expected prompt unchanged on first turn")
|
||||
}
|
||||
if len(out.RefFileIDs) != len(stdReq.RefFileIDs) {
|
||||
t.Fatalf("expected ref files unchanged on first turn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsHistorySplitUploadsHistoryAndKeepsLatestPrompt(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{
|
||||
wideInput: true,
|
||||
historySplitEnabled: true,
|
||||
historySplitTurns: 1,
|
||||
},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: ds,
|
||||
}
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": historySplitTestMessages(),
|
||||
"stream": false,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(string(reqBody)))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(ds.uploadCalls) != 1 {
|
||||
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
upload := ds.uploadCalls[0]
|
||||
if upload.Filename != "HISTORY.txt" {
|
||||
t.Fatalf("unexpected upload filename: %q", upload.Filename)
|
||||
}
|
||||
if upload.ContentType != "text/plain; charset=utf-8" {
|
||||
t.Fatalf("unexpected content type: %q", upload.ContentType)
|
||||
}
|
||||
if upload.Purpose != "assistants" {
|
||||
t.Fatalf("unexpected purpose: %q", upload.Purpose)
|
||||
}
|
||||
historyText := string(upload.Data)
|
||||
if !strings.Contains(historyText, "first user turn") || !strings.Contains(historyText, "tool result") {
|
||||
t.Fatalf("expected older turns in HISTORY.txt, got %s", historyText)
|
||||
}
|
||||
if strings.Contains(historyText, "latest user turn") {
|
||||
t.Fatalf("expected latest turn to remain in prompt, got %s", historyText)
|
||||
}
|
||||
if ds.completionReq == nil {
|
||||
t.Fatal("expected completion payload to be captured")
|
||||
}
|
||||
promptText, _ := ds.completionReq["prompt"].(string)
|
||||
if !strings.Contains(promptText, "latest user turn") {
|
||||
t.Fatalf("expected latest turn in completion prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Contains(promptText, "first user turn") {
|
||||
t.Fatalf("expected historical turns removed from completion prompt, got %s", promptText)
|
||||
}
|
||||
if !strings.Contains(promptText, "[reasoning_content]") || !strings.Contains(promptText, "hidden reasoning") {
|
||||
t.Fatalf("expected latest assistant reasoning to be attached to completion prompt, got %s", promptText)
|
||||
}
|
||||
if !strings.Contains(promptText, "HISTORY.txt") {
|
||||
t.Fatalf("expected history instruction in completion prompt, got %s", promptText)
|
||||
}
|
||||
if !strings.Contains(promptText, "Follow the instructions in this prompt first") {
|
||||
t.Fatalf("expected stronger prompt override in completion prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Index(promptText, "Follow the instructions in this prompt first") > strings.Index(promptText, "Continue the conversation") {
|
||||
t.Fatalf("expected history split instruction before continuity instructions, got %s", promptText)
|
||||
}
|
||||
refIDs, _ := ds.completionReq["ref_file_ids"].([]any)
|
||||
if len(refIDs) == 0 || refIDs[0] != "file-inline-1" {
|
||||
t.Fatalf("expected uploaded history file to be first ref_file_id, got %#v", ds.completionReq["ref_file_ids"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesHistorySplitUploadsHistoryAndKeepsLatestPrompt(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{
|
||||
wideInput: true,
|
||||
historySplitEnabled: true,
|
||||
historySplitTurns: 1,
|
||||
},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: ds,
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": historySplitTestMessages(),
|
||||
"stream": false,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(string(reqBody)))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(ds.uploadCalls) != 1 {
|
||||
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
if ds.completionReq == nil {
|
||||
t.Fatal("expected completion payload to be captured")
|
||||
}
|
||||
promptText, _ := ds.completionReq["prompt"].(string)
|
||||
if !strings.Contains(promptText, "latest user turn") {
|
||||
t.Fatalf("expected latest turn in completion prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Contains(promptText, "first user turn") {
|
||||
t.Fatalf("expected historical turns removed from completion prompt, got %s", promptText)
|
||||
}
|
||||
if !strings.Contains(promptText, "[reasoning_content]") || !strings.Contains(promptText, "hidden reasoning") {
|
||||
t.Fatalf("expected latest assistant reasoning to be attached to completion prompt, got %s", promptText)
|
||||
}
|
||||
if !strings.Contains(promptText, "Follow the instructions in this prompt first") {
|
||||
t.Fatalf("expected stronger prompt override in completion prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Index(promptText, "Follow the instructions in this prompt first") > strings.Index(promptText, "Continue the conversation") {
|
||||
t.Fatalf("expected history split instruction before continuity instructions, got %s", promptText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsHistorySplitUploadFailureReturnsInternalServerError(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{uploadErr: context.DeadlineExceeded}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{
|
||||
wideInput: true,
|
||||
historySplitEnabled: true,
|
||||
historySplitTurns: 1,
|
||||
},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: ds,
|
||||
}
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": historySplitTestMessages(),
|
||||
"stream": false,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(string(reqBody)))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if ds.completionReq != nil {
|
||||
t.Fatalf("did not expect completion payload on upload failure")
|
||||
}
|
||||
}
|
||||
@@ -2,13 +2,21 @@ package openai
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var emptyJSONFencePattern = regexp.MustCompile("(?is)```json\\s*```")
|
||||
var leakedToolCallArrayPattern = regexp.MustCompile(`(?is)\[\{\s*"function"\s*:\s*\{[\s\S]*?\}\s*,\s*"id"\s*:\s*"call[^"]*"\s*,\s*"type"\s*:\s*"function"\s*}\]`)
|
||||
var leakedToolResultBlobPattern = regexp.MustCompile(`(?is)<\s*\|\s*tool\s*\|\s*>\s*\{[\s\S]*?"tool_call_id"\s*:\s*"call[^"]*"\s*}`)
|
||||
|
||||
// leakedMetaMarkerPattern matches DeepSeek special tokens in BOTH forms:
|
||||
var leakedThinkTagPattern = regexp.MustCompile(`(?is)</?\s*think\s*>`)
|
||||
|
||||
// leakedBOSMarkerPattern matches DeepSeek BOS markers in BOTH forms:
|
||||
// - ASCII underscore: <|begin_of_sentence|>
|
||||
// - U+2581 variant: <|begin▁of▁sentence|>
|
||||
var leakedBOSMarkerPattern = regexp.MustCompile(`(?i)<[|\|]\s*begin[_▁]of[_▁]sentence\s*[|\|]>`)
|
||||
|
||||
// leakedMetaMarkerPattern matches the remaining DeepSeek special tokens in BOTH forms:
|
||||
// - ASCII underscore: <|end_of_sentence|>, <|end_of_toolresults|>, <|end_of_instructions|>
|
||||
// - U+2581 variant: <|end▁of▁sentence|>, <|end▁of▁toolresults|>, <|end▁of▁instructions|>
|
||||
var leakedMetaMarkerPattern = regexp.MustCompile(`(?i)<[|\|]\s*(?:assistant|tool|end[_▁]of[_▁]sentence|end[_▁]of[_▁]thinking|end[_▁]of[_▁]toolresults|end[_▁]of[_▁]instructions)\s*[|\|]>`)
|
||||
@@ -35,11 +43,48 @@ func sanitizeLeakedOutput(text string) string {
|
||||
out := emptyJSONFencePattern.ReplaceAllString(text, "")
|
||||
out = leakedToolCallArrayPattern.ReplaceAllString(out, "")
|
||||
out = leakedToolResultBlobPattern.ReplaceAllString(out, "")
|
||||
out = stripDanglingThinkSuffix(out)
|
||||
out = leakedThinkTagPattern.ReplaceAllString(out, "")
|
||||
out = leakedBOSMarkerPattern.ReplaceAllString(out, "")
|
||||
out = leakedMetaMarkerPattern.ReplaceAllString(out, "")
|
||||
out = sanitizeLeakedAgentXMLBlocks(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func stripDanglingThinkSuffix(text string) string {
|
||||
matches := leakedThinkTagPattern.FindAllStringIndex(text, -1)
|
||||
if len(matches) == 0 {
|
||||
return text
|
||||
}
|
||||
depth := 0
|
||||
lastOpen := -1
|
||||
for _, loc := range matches {
|
||||
tag := strings.ToLower(text[loc[0]:loc[1]])
|
||||
compact := strings.ReplaceAll(strings.ReplaceAll(strings.TrimSpace(tag), " ", ""), "\t", "")
|
||||
if strings.HasPrefix(compact, "</") {
|
||||
if depth > 0 {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
lastOpen = -1
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if depth == 0 {
|
||||
lastOpen = loc[0]
|
||||
}
|
||||
depth++
|
||||
}
|
||||
if depth == 0 || lastOpen < 0 {
|
||||
return text
|
||||
}
|
||||
prefix := text[:lastOpen]
|
||||
if strings.TrimSpace(prefix) == "" {
|
||||
return ""
|
||||
}
|
||||
return prefix
|
||||
}
|
||||
|
||||
func sanitizeLeakedAgentXMLBlocks(text string) string {
|
||||
out := text
|
||||
for _, pattern := range leakedAgentXMLBlockPatterns {
|
||||
|
||||
@@ -26,6 +26,22 @@ func TestSanitizeLeakedOutputRemovesStandaloneMetaMarkers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeLeakedOutputRemovesThinkAndBosMarkers(t *testing.T) {
|
||||
raw := "A<think>B</think>C<|begin▁of▁sentence|>D<| begin_of_sentence |>E<|begin_of_sentence|>F"
|
||||
got := sanitizeLeakedOutput(raw)
|
||||
if got != "ABCDEF" {
|
||||
t.Fatalf("unexpected sanitize result for think/BOS markers: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeLeakedOutputRemovesDanglingThinkBlock(t *testing.T) {
|
||||
raw := "Answer prefix<think>internal reasoning that never closes"
|
||||
got := sanitizeLeakedOutput(raw)
|
||||
if got != "Answer prefix" {
|
||||
t.Fatalf("unexpected sanitize result for dangling think block: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeLeakedOutputRemovesAgentXMLLeaks(t *testing.T) {
|
||||
raw := "Done.<attempt_completion><result>Some final answer</result></attempt_completion>"
|
||||
got := sanitizeLeakedOutput(raw)
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"ds2api/internal/prompt"
|
||||
)
|
||||
|
||||
const assistantReasoningLabel = "reasoning_content"
|
||||
|
||||
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
|
||||
_ = traceID
|
||||
out := make([]map[string]any, 0, len(raw))
|
||||
@@ -55,17 +57,95 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
||||
|
||||
func buildAssistantContentForPrompt(msg map[string]any) string {
|
||||
content := strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"]))
|
||||
toolHistory := prompt.FormatToolCallsForPrompt(msg["tool_calls"])
|
||||
switch {
|
||||
case content == "" && toolHistory == "":
|
||||
return ""
|
||||
case content == "":
|
||||
return toolHistory
|
||||
case toolHistory == "":
|
||||
return content
|
||||
default:
|
||||
return content + "\n\n" + toolHistory
|
||||
reasoning := strings.TrimSpace(normalizeOpenAIReasoningContentForPrompt(msg["reasoning_content"]))
|
||||
if reasoning == "" {
|
||||
reasoning = strings.TrimSpace(extractOpenAIReasoningContentFromMessage(msg["content"]))
|
||||
}
|
||||
toolHistory := prompt.FormatToolCallsForPrompt(msg["tool_calls"])
|
||||
parts := make([]string, 0, 3)
|
||||
if reasoning != "" {
|
||||
parts = append(parts, formatPromptLabeledBlock(assistantReasoningLabel, reasoning))
|
||||
}
|
||||
if content != "" {
|
||||
parts = append(parts, content)
|
||||
}
|
||||
if toolHistory != "" {
|
||||
parts = append(parts, toolHistory)
|
||||
}
|
||||
switch len(parts) {
|
||||
case 0:
|
||||
return ""
|
||||
case 1:
|
||||
return parts[0]
|
||||
default:
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpenAIReasoningContentForPrompt(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return x
|
||||
case []any:
|
||||
return strings.Join(extractOpenAIReasoningPartsFromItems(x), "\n")
|
||||
case map[string]any:
|
||||
return extractOpenAIReasoningTextFromItem(x)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningContentFromMessage(v any) string {
|
||||
switch x := v.(type) {
|
||||
case []any:
|
||||
return strings.Join(extractOpenAIReasoningPartsFromItems(x), "\n")
|
||||
case map[string]any:
|
||||
return extractOpenAIReasoningTextFromItem(x)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningPartsFromItems(items []any) []string {
|
||||
parts := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
if text := extractOpenAIReasoningTextFromItemMap(item); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningTextFromItemMap(item any) string {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return extractOpenAIReasoningTextFromItem(m)
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningTextFromItem(m map[string]any) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(asString(m["type"]))) {
|
||||
case "reasoning", "thinking":
|
||||
for _, key := range []string{"text", "thinking", "content"} {
|
||||
if text := strings.TrimSpace(asString(m[key])); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func formatPromptLabeledBlock(label, text string) string {
|
||||
label = strings.TrimSpace(label)
|
||||
text = strings.TrimSpace(text)
|
||||
if label == "" {
|
||||
return text
|
||||
}
|
||||
return "[" + label + "]\n" + text + "\n[/" + label + "]"
|
||||
}
|
||||
|
||||
func buildToolContentForPrompt(msg map[string]any) string {
|
||||
|
||||
@@ -296,3 +296,31 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantArrayContentFallbackWhenTextE
|
||||
t.Fatalf("expected content fallback text preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_AssistantReasoningContentPreserved(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "visible answer",
|
||||
"reasoning_content": "internal reasoning",
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized assistant message, got %#v", normalized)
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(content, "[reasoning_content]") {
|
||||
t.Fatalf("expected labeled reasoning block in assistant content, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "internal reasoning") {
|
||||
t.Fatalf("expected reasoning text in assistant content, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "visible answer") {
|
||||
t.Fatalf("expected visible answer in assistant content, got %q", content)
|
||||
}
|
||||
if reasoningIdx := strings.Index(content, "[reasoning_content]"); reasoningIdx < 0 || reasoningIdx > strings.Index(content, "visible answer") {
|
||||
t.Fatalf("expected reasoning block before visible answer, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,22 +5,22 @@ import (
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
|
||||
return buildOpenAIFinalPromptWithPolicy(messagesRaw, toolsRaw, traceID, util.DefaultToolChoicePolicy())
|
||||
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string, thinkingEnabled bool) (string, []string) {
|
||||
return buildOpenAIFinalPromptWithPolicy(messagesRaw, toolsRaw, traceID, util.DefaultToolChoicePolicy(), thinkingEnabled)
|
||||
}
|
||||
|
||||
func buildOpenAIFinalPromptWithPolicy(messagesRaw []any, toolsRaw any, traceID string, toolPolicy util.ToolChoicePolicy) (string, []string) {
|
||||
func buildOpenAIFinalPromptWithPolicy(messagesRaw []any, toolsRaw any, traceID string, toolPolicy util.ToolChoicePolicy, thinkingEnabled bool) (string, []string) {
|
||||
messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID)
|
||||
toolNames := []string{}
|
||||
if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
|
||||
messages, toolNames = injectToolPrompt(messages, tools, toolPolicy)
|
||||
}
|
||||
return deepseek.MessagesPrepare(messages), toolNames
|
||||
return deepseek.MessagesPrepareWithThinking(messages, thinkingEnabled), toolNames
|
||||
}
|
||||
|
||||
// BuildPromptForAdapter exposes the OpenAI-compatible prompt building flow so
|
||||
// other protocol adapters (for example Gemini) can reuse the same tool/history
|
||||
// normalization logic and remain behavior-compatible with chat/completions.
|
||||
func BuildPromptForAdapter(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
|
||||
return buildOpenAIFinalPrompt(messagesRaw, toolsRaw, traceID)
|
||||
func BuildPromptForAdapter(messagesRaw []any, toolsRaw any, traceID string, thinkingEnabled bool) (string, []string) {
|
||||
return buildOpenAIFinalPrompt(messagesRaw, toolsRaw, traceID, thinkingEnabled)
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes
|
||||
},
|
||||
}
|
||||
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "")
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "", false)
|
||||
if len(toolNames) != 1 || toolNames[0] != "get_weather" {
|
||||
t.Fatalf("unexpected tool names: %#v", toolNames)
|
||||
}
|
||||
@@ -73,8 +73,8 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
|
||||
},
|
||||
}
|
||||
|
||||
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "")
|
||||
if !strings.Contains(finalPrompt, "Remember: Output ONLY the <tool_calls>...</tool_calls> XML block when calling tools.") {
|
||||
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "", false)
|
||||
if !strings.Contains(finalPrompt, "Remember: The ONLY valid way to use tools is the <tool_calls> XML block at the end of your response.") {
|
||||
t.Fatalf("vercel prepare finalPrompt missing final tool-call anchor instruction: %q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "TOOL CALL FORMAT") {
|
||||
@@ -87,3 +87,17 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
|
||||
t.Fatalf("vercel prepare finalPrompt should not require fenced tool calls: %q", finalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIFinalPromptWithThinkingAddsContinuationContract(t *testing.T) {
|
||||
messages := []any{
|
||||
map[string]any{"role": "user", "content": "继续回答上一个问题"},
|
||||
}
|
||||
|
||||
finalPrompt, _ := buildOpenAIFinalPrompt(messages, nil, "", true)
|
||||
if !strings.Contains(finalPrompt, "Continue the conversation from the full prior context") {
|
||||
t.Fatalf("expected continuation contract in thinking prompt, got=%q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "final user-facing answer only in reasoning") {
|
||||
t.Fatalf("expected visible-answer contract in thinking prompt, got=%q", finalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,6 +156,33 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItemPreservesConcatenatedA
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectOpenAIRefFileIDs(t *testing.T) {
|
||||
got := collectOpenAIRefFileIDs(map[string]any{
|
||||
"ref_file_ids": []any{"file-top", "file-dup"},
|
||||
"attachments": []any{
|
||||
map[string]any{"file_id": "file-attachment"},
|
||||
},
|
||||
"input": []any{
|
||||
map[string]any{
|
||||
"type": "message",
|
||||
"content": []any{
|
||||
map[string]any{"type": "input_file", "file_id": "file-input"},
|
||||
map[string]any{"type": "input_file", "id": "file-dup"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
want := []string{"file-top", "file-dup", "file-attachment", "file-input"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("expected %d file ids, got %#v", len(want), got)
|
||||
}
|
||||
for i, id := range want {
|
||||
if got[i] != id {
|
||||
t.Fatalf("unexpected file ids at %d: got=%#v want=%#v", i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmbeddingInputs(t *testing.T) {
|
||||
got := extractEmbeddingInputs([]any{"a", "b"})
|
||||
if len(got) != 2 || got[0] != "a" || got[1] != "b" {
|
||||
|
||||
@@ -65,17 +65,31 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, openAIGeneralMaxSize)
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "too large") {
|
||||
writeOpenAIError(w, http.StatusRequestEntityTooLarge, "request body too large")
|
||||
return
|
||||
}
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if err := h.preprocessInlineFileInputs(r.Context(), a, req); err != nil {
|
||||
writeOpenAIInlineFileError(w, err)
|
||||
return
|
||||
}
|
||||
traceID := requestTraceID(r)
|
||||
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, traceID)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
stdReq, err = h.applyHistorySplit(r.Context(), a, stdReq)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
@@ -103,10 +117,10 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
return
|
||||
}
|
||||
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -117,7 +131,10 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
stripReferenceMarkers := h.compatStripReferenceMarkers()
|
||||
sanitizedThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers)
|
||||
sanitizedText := cleanVisibleOutput(result.Text, stripReferenceMarkers)
|
||||
if writeUpstreamEmptyOutputError(w, sanitizedThinking, sanitizedText, result.ContentFilter) {
|
||||
if searchEnabled {
|
||||
sanitizedText = replaceCitationMarkersWithLinks(sanitizedText, result.CitationLinks)
|
||||
}
|
||||
if writeUpstreamEmptyOutputError(w, sanitizedText, result.ContentFilter) {
|
||||
return
|
||||
}
|
||||
textParsed := toolcall.ParseStandaloneToolCallsDetailed(sanitizedText, toolNames)
|
||||
|
||||
@@ -99,12 +99,36 @@ func newResponsesStreamRuntime(
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) failResponse(message, code string) {
|
||||
s.failed = true
|
||||
failedResp := map[string]any{
|
||||
"id": s.responseID,
|
||||
"type": "response",
|
||||
"object": "response",
|
||||
"model": s.model,
|
||||
"status": "failed",
|
||||
"output": []any{},
|
||||
"output_text": "",
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": "invalid_request_error",
|
||||
"code": code,
|
||||
"param": nil,
|
||||
},
|
||||
}
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(failedResp)
|
||||
}
|
||||
s.sendEvent("response.failed", openaifmt.BuildResponsesFailedPayload(s.responseID, s.model, message, code))
|
||||
s.sendDone()
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) finalize() {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
|
||||
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true, true)
|
||||
}
|
||||
|
||||
textParsed := toolcall.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
|
||||
@@ -121,28 +145,16 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
s.closeMessageItem()
|
||||
|
||||
if s.toolChoice.IsRequired() && len(detected) == 0 {
|
||||
s.failed = true
|
||||
message := "tool_choice requires at least one valid tool call."
|
||||
failedResp := map[string]any{
|
||||
"id": s.responseID,
|
||||
"type": "response",
|
||||
"object": "response",
|
||||
"model": s.model,
|
||||
"status": "failed",
|
||||
"output": []any{},
|
||||
"output_text": "",
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": "invalid_request_error",
|
||||
"code": "tool_choice_violation",
|
||||
"param": nil,
|
||||
},
|
||||
s.failResponse("tool_choice requires at least one valid tool call.", "tool_choice_violation")
|
||||
return
|
||||
}
|
||||
if len(detected) == 0 && strings.TrimSpace(finalText) == "" {
|
||||
code := "upstream_empty_output"
|
||||
message := "Upstream model returned empty output."
|
||||
if finalThinking != "" {
|
||||
message = "Upstream model returned reasoning without visible output."
|
||||
}
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(failedResp)
|
||||
}
|
||||
s.sendEvent("response.failed", openaifmt.BuildResponsesFailedPayload(s.responseID, s.model, message, "tool_choice_violation"))
|
||||
s.sendDone()
|
||||
s.failResponse(message, code)
|
||||
return
|
||||
}
|
||||
s.closeIncompleteFunctionItems()
|
||||
@@ -212,7 +224,7 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
s.emitTextDelta(trimmed)
|
||||
continue
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, trimmed, s.toolNames), true)
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, trimmed, s.toolNames), true, true)
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (s *responsesStreamRuntime) sendDone() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
|
||||
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool, resetAfterToolCalls bool) {
|
||||
for _, evt := range events {
|
||||
if emitContent && evt.Content != "" {
|
||||
s.emitTextDelta(evt.Content)
|
||||
@@ -56,6 +56,9 @@ func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEven
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.emitFunctionCallDoneEvents(evt.ToolCalls)
|
||||
if resetAfterToolCalls {
|
||||
s.resetStreamToolCallState()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,6 +152,16 @@ func (s *responsesStreamRuntime) ensureToolCallID(callIndex int) string {
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) resetStreamToolCallState() {
|
||||
s.streamToolCallIDs = map[int]string{}
|
||||
s.functionItemIDs = map[int]string{}
|
||||
s.functionOutputIDs = map[int]int{}
|
||||
s.functionArgs = map[int]string{}
|
||||
s.functionDone = map[int]bool{}
|
||||
s.functionAdded = map[int]bool{}
|
||||
s.functionNames = map[int]string{}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureFunctionOutputIndex(callIndex int) int {
|
||||
if idx, ok := s.functionOutputIDs[callIndex]; ok {
|
||||
return idx
|
||||
|
||||
@@ -12,149 +12,6 @@ import (
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
rawToolJSON := `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`
|
||||
streamBody := sseLine(rawToolJSON) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
outputText, _ := responseObj["output_text"].(string)
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected empty output_text for tool_calls response, got output_text=%q", outputText)
|
||||
}
|
||||
output, _ := responseObj["output"].([]any)
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
|
||||
}
|
||||
hasFunctionCall := false
|
||||
hasLegacyWrapper := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m["type"] == "function_call" {
|
||||
hasFunctionCall = true
|
||||
}
|
||||
if m["type"] == "tool_calls" {
|
||||
hasLegacyWrapper = true
|
||||
}
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatalf("expected function_call item, got %#v", responseObj["output"])
|
||||
}
|
||||
if hasLegacyWrapper {
|
||||
t.Fatalf("did not expect legacy tool_calls wrapper, got %#v", responseObj["output"])
|
||||
}
|
||||
if strings.Contains(outputText, `"tool_calls"`) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.output_item.added") {
|
||||
t.Fatalf("expected response.output_item.added event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.output_item.done") {
|
||||
t.Fatalf("expected response.output_item.done event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
|
||||
}
|
||||
if strings.Contains(body, "event: response.output_tool_call.delta") || strings.Contains(body, "event: response.output_tool_call.done") {
|
||||
t.Fatalf("legacy response.output_tool_call.* event must not appear, body=%s", body)
|
||||
}
|
||||
|
||||
addedPayloads := extractAllSSEEventPayloads(body, "response.output_item.added")
|
||||
hasFunctionCallAdded := false
|
||||
for _, payload := range addedPayloads {
|
||||
item, _ := payload["item"].(map[string]any)
|
||||
if item == nil || asString(item["type"]) != "function_call" {
|
||||
continue
|
||||
}
|
||||
hasFunctionCallAdded = true
|
||||
if asString(item["arguments"]) != "" {
|
||||
t.Fatalf("expected in-progress function_call.arguments to start empty string, got %#v", item["arguments"])
|
||||
}
|
||||
}
|
||||
if !hasFunctionCallAdded {
|
||||
t.Fatalf("expected function_call output_item.added payload, body=%s", body)
|
||||
}
|
||||
|
||||
donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done")
|
||||
if !ok {
|
||||
t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body)
|
||||
}
|
||||
doneCallID := strings.TrimSpace(asString(donePayload["call_id"]))
|
||||
if doneCallID == "" {
|
||||
t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload)
|
||||
}
|
||||
completed, ok := extractSSEEventPayload(body, "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed payload, body=%s", body)
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
var completedCallID string
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil || m["type"] != "function_call" {
|
||||
continue
|
||||
}
|
||||
completedCallID = strings.TrimSpace(asString(m["call_id"]))
|
||||
if completedCallID != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
if completedCallID == "" {
|
||||
t.Fatalf("expected function_call.call_id in completed output, output=%#v", output)
|
||||
}
|
||||
if completedCallID != doneCallID {
|
||||
t.Fatalf("expected completed call_id to match stream done call_id, done=%q completed=%q", doneCallID, completedCallID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
@@ -181,51 +38,6 @@ func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
|
||||
sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
if len(donePayloads) != 2 {
|
||||
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
|
||||
}
|
||||
seenNames := map[string]string{}
|
||||
for _, payload := range donePayloads {
|
||||
name := strings.TrimSpace(asString(payload["name"]))
|
||||
callID := strings.TrimSpace(asString(payload["call_id"]))
|
||||
if name != "search_web" && name != "eval_javascript" {
|
||||
t.Fatalf("unexpected tool name in done payload: %#v", payload)
|
||||
}
|
||||
if callID == "" {
|
||||
t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
|
||||
}
|
||||
seenNames[name] = callID
|
||||
}
|
||||
if seenNames["search_web"] == seenNames["eval_javascript"] {
|
||||
t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
@@ -297,120 +109,54 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamThinkingAndMixedToolExampleEmitsFunctionCall(t *testing.T) {
|
||||
func TestHandleResponsesStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(path, value string) string {
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": path,
|
||||
"v": value,
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("response/thinking_content", "thinking...") +
|
||||
sseLine("response/content", "先读取文件。") +
|
||||
sseLine("response/content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
|
||||
streamBody := sseLine("前置文本\n<tool_calls>\n <tool_call>\n <tool_name>read_file</tool_name>\n <parameters>{\"path\":\"README.MD\"}</parameters>\n </tool_call>\n</tool_calls>") +
|
||||
sseLine("中间文本\n<tool_calls>\n <tool_call>\n <tool_name>search</tool_name>\n <parameters>{\"q\":\"golang\"}</parameters>\n </tool_call>\n</tool_calls>") +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file", "search"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
|
||||
if len(addedPayloads) < 1 {
|
||||
t.Fatalf("expected at least one output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
|
||||
body := rec.Body.String()
|
||||
doneEvents := extractSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
if len(doneEvents) < 2 {
|
||||
t.Fatalf("expected at least two function call done events, got %d body=%s", len(doneEvents), body)
|
||||
}
|
||||
|
||||
completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed payload, body=%s", rec.Body.String())
|
||||
}
|
||||
responseObj, _ := completedPayload["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
hasMessage := false
|
||||
hasFunctionCall := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
ids := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{})
|
||||
for _, payload := range doneEvents {
|
||||
callID := asString(payload["call_id"])
|
||||
if callID == "" {
|
||||
continue
|
||||
}
|
||||
if asString(m["type"]) == "message" {
|
||||
hasMessage = true
|
||||
if _, ok := seen[callID]; ok {
|
||||
continue
|
||||
}
|
||||
if asString(m["type"]) == "function_call" {
|
||||
hasFunctionCall = true
|
||||
}
|
||||
}
|
||||
if !hasMessage {
|
||||
t.Fatalf("expected message output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatalf("expected function_call output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
seen[callID] = struct{}{}
|
||||
ids = append(ids, callID)
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
if len(ids) != 2 {
|
||||
t.Fatalf("expected two distinct call ids, got %#v body=%s", ids, body)
|
||||
}
|
||||
policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, policy, "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected function_call events for tool_choice=none, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamMalformedToolJSONFallsBackToText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
// invalid JSON (NaN) should remain plain text in strict mode.
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "event: response.function_call_arguments.delta") || strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("did not expect function_call events for malformed payload in strict mode, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.output_text.delta") {
|
||||
t.Fatalf("expected response.output_text.delta for malformed payload, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("expected response.completed event, body=%s", body)
|
||||
if ids[0] == ids[1] {
|
||||
t.Fatalf("expected distinct call ids across blocks, got %#v body=%s", ids, body)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -448,7 +194,7 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
|
||||
func TestHandleResponsesStreamFailsWhenUpstreamHasOnlyThinking(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -461,53 +207,13 @@ func TestHandleResponsesStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *te
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
|
||||
sseLine("response/content", "plain text only") +
|
||||
"data: [DONE]\n"
|
||||
streamBody := sseLine("response/thinking_content", "Only thinking") + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
|
||||
}
|
||||
if strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("did not expect response.completed after failure, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
@@ -516,31 +222,13 @@ func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
|
||||
if strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("did not expect response.completed, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamAllowsUnknownToolName(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
payload, ok := extractSSEEventPayload(body, "response.failed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.failed payload, body=%s", body)
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"not_in_schema","input":{"q":"go"}}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected function_call events for unknown tool, body=%s", body)
|
||||
errObj, _ := payload["error"].(map[string]any)
|
||||
if asString(errObj["code"]) != "upstream_empty_output" {
|
||||
t.Fatalf("expected code=upstream_empty_output, got %#v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -559,7 +247,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -586,7 +274,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -597,36 +285,6 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" +
|
||||
`data: [DONE]` + "\n",
|
||||
)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, policy, "")
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for tool_choice=none handling, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
output, _ := out["output"].([]any)
|
||||
foundFunctionCall := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m != nil && m["type"] == "function_call" {
|
||||
foundFunctionCall = true
|
||||
}
|
||||
}
|
||||
if !foundFunctionCall {
|
||||
t.Fatalf("expected function_call output item for tool_choice=none, got %#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -638,7 +296,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T)
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -660,7 +318,7 @@ func TestHandleResponsesNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWi
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for filtered empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -671,6 +329,28 @@ func TestHandleResponsesNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWi
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`data: {"p":"response/thinking_content","v":"Only thinking"}` + "\n" +
|
||||
`data: [DONE]` + "\n",
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if asString(errObj["code"]) != "upstream_empty_output" {
|
||||
t.Fatalf("expected code=upstream_empty_output, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
matched := false
|
||||
@@ -697,10 +377,10 @@ func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
|
||||
func extractSSEEventPayloads(body, targetEvent string) []map[string]any {
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
matched := false
|
||||
out := make([]map[string]any, 0, 2)
|
||||
out := make([]map[string]any, 0, 4)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
|
||||
@@ -24,9 +24,10 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
|
||||
responseModel = resolvedModel
|
||||
}
|
||||
toolPolicy := util.DefaultToolChoicePolicy()
|
||||
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
|
||||
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy, thinkingEnabled)
|
||||
toolNames = ensureToolDetectionEnabled(toolNames, req["tools"])
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
refFileIDs := collectOpenAIRefFileIDs(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
Surface: "openai_chat",
|
||||
@@ -34,12 +35,14 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
|
||||
ResolvedModel: resolvedModel,
|
||||
ResponseModel: responseModel,
|
||||
Messages: messagesRaw,
|
||||
ToolsRaw: req["tools"],
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
ToolChoice: toolPolicy,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
RefFileIDs: refFileIDs,
|
||||
PassThrough: passThrough,
|
||||
}, nil
|
||||
}
|
||||
@@ -74,12 +77,13 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
|
||||
if err != nil {
|
||||
return util.StandardRequest{}, err
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
|
||||
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy, thinkingEnabled)
|
||||
toolNames = ensureToolDetectionEnabled(toolNames, req["tools"])
|
||||
if !toolPolicy.IsNone() {
|
||||
toolPolicy.Allowed = namesToSet(toolNames)
|
||||
}
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
refFileIDs := collectOpenAIRefFileIDs(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
Surface: "openai_responses",
|
||||
@@ -87,12 +91,14 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
|
||||
ResolvedModel: resolvedModel,
|
||||
ResponseModel: model,
|
||||
Messages: messagesRaw,
|
||||
ToolsRaw: req["tools"],
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
ToolChoice: toolPolicy,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
RefFileIDs: refFileIDs,
|
||||
PassThrough: passThrough,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -41,6 +41,36 @@ func TestNormalizeOpenAIChatRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIChatRequestCollectsRefFileIDs(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-5-codex",
|
||||
"messages": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "input_text", "text": "hello"},
|
||||
map[string]any{"type": "input_file", "file_id": "file-msg"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"attachments": []any{
|
||||
map[string]any{"file_id": "file-attachment"},
|
||||
},
|
||||
"ref_file_ids": []any{"file-top", "file-attachment"},
|
||||
}
|
||||
n, err := normalizeOpenAIChatRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if len(n.RefFileIDs) != 3 {
|
||||
t.Fatalf("expected 3 distinct file ids, got %#v", n.RefFileIDs)
|
||||
}
|
||||
if n.RefFileIDs[0] != "file-top" || n.RefFileIDs[1] != "file-attachment" || n.RefFileIDs[2] != "file-msg" {
|
||||
t.Fatalf("unexpected file ids: %#v", n.RefFileIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
|
||||
@@ -50,6 +50,10 @@ func (m streamStatusDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int
|
||||
return "pow", nil
|
||||
}
|
||||
|
||||
func (m streamStatusDSStub) UploadFile(_ context.Context, _ *auth.RequestAuth, _ deepseek.UploadFileRequest, _ int) (*deepseek.UploadFileResult, error) {
|
||||
return &deepseek.UploadFileResult{ID: "file-id", Filename: "file.txt", Bytes: 1, Status: "uploaded"}, nil
|
||||
}
|
||||
|
||||
func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
|
||||
return m.resp, nil
|
||||
}
|
||||
@@ -142,53 +146,6 @@ func TestResponsesStreamStatusCapturedAs200(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
|
||||
statuses := make([]int, 0, 1)
|
||||
content, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": "我来调用工具\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}",
|
||||
})
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse("data: "+string(content), "data: [DONE]")},
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
r.Use(captureStatusMiddleware(&statuses))
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","input":"请调用工具","tools":[{"type":"function","function":{"name":"read_file","description":"read","parameters":{"type":"object","properties":{"path":{"type":"string"}}}}}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(statuses) != 1 || statuses[0] != http.StatusOK {
|
||||
t.Fatalf("expected captured status 200, got %#v", statuses)
|
||||
}
|
||||
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
|
||||
}
|
||||
outputText, _ := out["output_text"].(string)
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
|
||||
}
|
||||
output, _ := out["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected one output item, got %#v", output)
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected function_call output item, got %#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsStreamContentFilterStopsNormallyWithoutLeak(t *testing.T) {
|
||||
statuses := make([]int, 0, 1)
|
||||
h := &Handler{
|
||||
@@ -239,6 +196,49 @@ func TestChatCompletionsStreamContentFilterStopsNormallyWithoutLeak(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsStreamEmitsFailureFrameWhenUpstreamOutputEmpty(t *testing.T) {
|
||||
statuses := make([]int, 0, 1)
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse("data: [DONE]")},
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
r.Use(captureStatusMiddleware(&statuses))
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(statuses) != 1 || statuses[0] != http.StatusOK {
|
||||
t.Fatalf("expected captured status 200, got %#v", statuses)
|
||||
}
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if len(frames) != 1 {
|
||||
t.Fatalf("expected one failure frame, got %#v body=%s", frames, rec.Body.String())
|
||||
}
|
||||
last := frames[0]
|
||||
statusCode, ok := last["status_code"].(float64)
|
||||
if !ok || int(statusCode) != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected status_code=429, got %#v body=%s", last["status_code"], rec.Body.String())
|
||||
}
|
||||
errObj, _ := last["error"].(map[string]any)
|
||||
if asString(errObj["code"]) != "upstream_empty_output" {
|
||||
t.Fatalf("expected code=upstream_empty_output, got %#v", last)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesStreamUsageIgnoresBatchAccumulatedTokenUsage(t *testing.T) {
|
||||
statuses := make([]int, 0, 1)
|
||||
h := &Handler{
|
||||
|
||||
@@ -60,7 +60,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
if pending == "" {
|
||||
break
|
||||
}
|
||||
start := findToolSegmentStart(pending)
|
||||
start := findToolSegmentStart(state, pending)
|
||||
if start >= 0 {
|
||||
prefix := pending[:start]
|
||||
if prefix != "" {
|
||||
@@ -74,7 +74,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
continue
|
||||
}
|
||||
|
||||
safe, hold := splitSafeContentForToolDetection(pending)
|
||||
safe, hold := splitSafeContentForToolDetection(state, pending)
|
||||
if safe == "" {
|
||||
break
|
||||
}
|
||||
@@ -114,14 +114,10 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
||||
} else {
|
||||
content := state.capture.String()
|
||||
if content != "" {
|
||||
// If the captured text looks like an incomplete XML tool call block,
|
||||
// swallow it to prevent leaking raw XML tags to the client.
|
||||
if hasOpenXMLToolTag(content) {
|
||||
// Drop it silently — incomplete tool call.
|
||||
} else {
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
}
|
||||
// If capture never resolved into a real tool call, release the
|
||||
// buffered text instead of swallowing it.
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
}
|
||||
}
|
||||
state.capture.Reset()
|
||||
@@ -130,100 +126,57 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
||||
}
|
||||
if state.pending.Len() > 0 {
|
||||
content := state.pending.String()
|
||||
// Safety: if pending contains XML tool tag fragments (e.g. "tool_calls>"
|
||||
// from a split closing tag), swallow them instead of leaking.
|
||||
if hasOpenXMLToolTag(content) || looksLikeXMLToolTagFragment(content) {
|
||||
// Drop it — likely an incomplete tool call fragment.
|
||||
} else {
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
}
|
||||
// If pending never resolved into a real tool call, release it as text.
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
state.pending.Reset()
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
func splitSafeContentForToolDetection(s string) (safe, hold string) {
|
||||
func splitSafeContentForToolDetection(state *toolStreamSieveState, s string) (safe, hold string) {
|
||||
if s == "" {
|
||||
return "", ""
|
||||
}
|
||||
suspiciousStart := findSuspiciousPrefixStart(s)
|
||||
if suspiciousStart < 0 {
|
||||
return s, ""
|
||||
}
|
||||
if suspiciousStart > 0 {
|
||||
return s[:suspiciousStart], s[suspiciousStart:]
|
||||
}
|
||||
// If suspicious content starts at position 0, keep holding until we can
|
||||
// parse a complete tool JSON block or reach stream flush.
|
||||
return "", s
|
||||
}
|
||||
|
||||
func findSuspiciousPrefixStart(s string) int {
|
||||
start := -1
|
||||
indices := []int{
|
||||
strings.LastIndex(s, "{"),
|
||||
strings.LastIndex(s, "["),
|
||||
strings.LastIndex(s, "```"),
|
||||
}
|
||||
for _, idx := range indices {
|
||||
if idx > start {
|
||||
start = idx
|
||||
if xmlIdx := findPartialXMLToolTagStart(s); xmlIdx >= 0 {
|
||||
if insideCodeFenceWithState(state, s[:xmlIdx]) {
|
||||
return s, ""
|
||||
}
|
||||
if xmlIdx > 0 {
|
||||
return s[:xmlIdx], s[xmlIdx:]
|
||||
}
|
||||
return "", s
|
||||
}
|
||||
// Also check for partial XML tool tag at end of string.
|
||||
if xmlIdx := findPartialXMLToolTagStart(s); xmlIdx >= 0 && xmlIdx > start {
|
||||
start = xmlIdx
|
||||
}
|
||||
return start
|
||||
return s, ""
|
||||
}
|
||||
|
||||
func findToolSegmentStart(s string) int {
|
||||
func findToolSegmentStart(state *toolStreamSieveState, s string) int {
|
||||
if s == "" {
|
||||
return -1
|
||||
}
|
||||
lower := strings.ToLower(s)
|
||||
keywords := []string{"tool_calls", "\"function\"", "function.name:", "\"tool_use\""}
|
||||
bestKeyIdx := -1
|
||||
for _, kw := range keywords {
|
||||
idx := strings.Index(lower, kw)
|
||||
if idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx) {
|
||||
bestKeyIdx = idx
|
||||
offset := 0
|
||||
for {
|
||||
bestKeyIdx := -1
|
||||
matchedTag := ""
|
||||
for _, tag := range xmlToolTagsToDetect {
|
||||
idx := strings.Index(lower[offset:], tag)
|
||||
if idx >= 0 {
|
||||
idx += offset
|
||||
if bestKeyIdx < 0 || idx < bestKeyIdx {
|
||||
bestKeyIdx = idx
|
||||
matchedTag = tag
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if fnKeyIdx := findQuotedFunctionCallKeyStart(s); fnKeyIdx >= 0 && (bestKeyIdx < 0 || fnKeyIdx < bestKeyIdx) {
|
||||
bestKeyIdx = fnKeyIdx
|
||||
}
|
||||
// Also detect XML tool call tags.
|
||||
for _, tag := range xmlToolTagsToDetect {
|
||||
idx := strings.Index(lower, tag)
|
||||
if idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx) {
|
||||
bestKeyIdx = idx
|
||||
if bestKeyIdx < 0 {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
if bestKeyIdx < 0 {
|
||||
return -1
|
||||
}
|
||||
// For XML tags, the '<' is itself the segment start.
|
||||
if bestKeyIdx < len(s) && s[bestKeyIdx] == '<' {
|
||||
if fenceStart, ok := openFenceStartBefore(s, bestKeyIdx); ok {
|
||||
return fenceStart
|
||||
if !insideCodeFenceWithState(state, s[:bestKeyIdx]) {
|
||||
return bestKeyIdx
|
||||
}
|
||||
return bestKeyIdx
|
||||
offset = bestKeyIdx + len(matchedTag)
|
||||
}
|
||||
start := strings.LastIndex(s[:bestKeyIdx], "{")
|
||||
if start < 0 {
|
||||
start = bestKeyIdx
|
||||
}
|
||||
// If the keyword matched inside an XML tag (e.g. "tool_calls" in "<tool_calls>"),
|
||||
// back up past the '<' to capture the full tag.
|
||||
if start > 0 && s[start-1] == '<' {
|
||||
start--
|
||||
}
|
||||
if fenceStart, ok := openFenceStartBefore(s, start); ok {
|
||||
return fenceStart
|
||||
}
|
||||
return start
|
||||
}
|
||||
|
||||
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []toolcall.ParsedToolCall, suffix string, ready bool) {
|
||||
@@ -232,7 +185,7 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
// Try XML tool call extraction first.
|
||||
// XML tool call extraction only.
|
||||
if xmlPrefix, xmlCalls, xmlSuffix, xmlReady := consumeXMLToolCapture(captured, toolNames); xmlReady {
|
||||
return xmlPrefix, xmlCalls, xmlSuffix, true
|
||||
}
|
||||
@@ -240,45 +193,5 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
if hasOpenXMLToolTag(captured) {
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
lower := strings.ToLower(captured)
|
||||
keyIdx := -1
|
||||
keywords := []string{"tool_calls", "\"function\"", "function.name:", "\"tool_use\""}
|
||||
for _, kw := range keywords {
|
||||
idx := strings.Index(lower, kw)
|
||||
if idx >= 0 && (keyIdx < 0 || idx < keyIdx) {
|
||||
keyIdx = idx
|
||||
}
|
||||
}
|
||||
if fnKeyIdx := findQuotedFunctionCallKeyStart(captured); fnKeyIdx >= 0 && (keyIdx < 0 || fnKeyIdx < keyIdx) {
|
||||
keyIdx = fnKeyIdx
|
||||
}
|
||||
|
||||
if keyIdx < 0 {
|
||||
return "", nil, "", false
|
||||
}
|
||||
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
start = keyIdx
|
||||
}
|
||||
obj, end, ok := extractJSONObjectFrom(captured, start)
|
||||
if !ok {
|
||||
return "", nil, "", false
|
||||
}
|
||||
prefixPart := captured[:start]
|
||||
suffixPart := captured[end:]
|
||||
parsed := toolcall.ParseStandaloneToolCallsDetailed(obj, toolNames)
|
||||
if len(parsed.Calls) == 0 {
|
||||
if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {
|
||||
// Parsed as tool-call payload but rejected by schema/policy:
|
||||
// consume it to avoid leaking raw tool_calls JSON to user content.
|
||||
return prefixPart, nil, suffixPart, true
|
||||
}
|
||||
// If it has obvious keywords but failed to parse even after loose repair,
|
||||
// we still might want to intercept it if it looks like an attempt at tool call.
|
||||
// For now, keep the original logic but rely on loose JSON repair.
|
||||
return captured, nil, "", true
|
||||
}
|
||||
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
|
||||
return prefixPart, parsed.Calls, suffixPart, true
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func findQuotedFunctionCallKeyStart(s string) int {
|
||||
lower := strings.ToLower(s)
|
||||
quotedIdx := findFunctionCallKeyStart(lower, `"functioncall"`)
|
||||
bareIdx := findFunctionCallKeyStart(lower, "functioncall")
|
||||
|
||||
// Prefer the quoted JSON key whenever we have a structural match.
|
||||
// Bare-key detection is only for loose payloads where the quoted form
|
||||
// is absent.
|
||||
if quotedIdx >= 0 {
|
||||
return quotedIdx
|
||||
}
|
||||
return bareIdx
|
||||
}
|
||||
|
||||
func findFunctionCallKeyStart(lower, key string) int {
|
||||
for from := 0; from < len(lower); {
|
||||
rel := strings.Index(lower[from:], key)
|
||||
if rel < 0 {
|
||||
return -1
|
||||
}
|
||||
idx := from + rel
|
||||
if isInsideJSONString(lower, idx) {
|
||||
from = idx + 1
|
||||
continue
|
||||
}
|
||||
if !hasJSONObjectContextPrefix(lower[:idx]) {
|
||||
from = idx + 1
|
||||
continue
|
||||
}
|
||||
if !hasJSONKeyBoundary(lower, idx, len(key)) {
|
||||
from = idx + 1
|
||||
continue
|
||||
}
|
||||
j := idx + len(key)
|
||||
for j < len(lower) && (lower[j] == ' ' || lower[j] == '\t' || lower[j] == '\r' || lower[j] == '\n') {
|
||||
j++
|
||||
}
|
||||
if j < len(lower) && lower[j] == ':' {
|
||||
k := j + 1
|
||||
for k < len(lower) && (lower[k] == ' ' || lower[k] == '\t' || lower[k] == '\r' || lower[k] == '\n') {
|
||||
k++
|
||||
}
|
||||
if k < len(lower) && lower[k] != '{' {
|
||||
from = idx + 1
|
||||
continue
|
||||
}
|
||||
return idx
|
||||
}
|
||||
from = idx + 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func isInsideJSONString(s string, idx int) bool {
|
||||
inString := false
|
||||
escaped := false
|
||||
for i := 0; i < idx; i++ {
|
||||
c := s[i]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if c == '\\' && inString {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if c == '"' {
|
||||
inString = !inString
|
||||
}
|
||||
}
|
||||
return inString
|
||||
}
|
||||
|
||||
func hasJSONObjectContextPrefix(prefix string) bool {
|
||||
return strings.LastIndex(prefix, "{") >= 0
|
||||
}
|
||||
|
||||
func hasJSONKeyBoundary(s string, idx, keyLen int) bool {
|
||||
if idx > 0 {
|
||||
prev := s[idx-1]
|
||||
if isLowerAlphaNumeric(prev) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if end := idx + keyLen; end < len(s) {
|
||||
next := s[end]
|
||||
if isLowerAlphaNumeric(next) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isLowerAlphaNumeric(b byte) bool {
|
||||
return (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') || b == '_'
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFindQuotedFunctionCallKeyStart_PrefersEarlierBareKey(t *testing.T) {
|
||||
input := `{functionCall:{"name":"a","arguments":"{}"},"message":"literal text: \"functionCall\": not a key"}`
|
||||
|
||||
got := findQuotedFunctionCallKeyStart(input)
|
||||
want := 1
|
||||
if got != want {
|
||||
t.Fatalf("findQuotedFunctionCallKeyStart() = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindQuotedFunctionCallKeyStart_PrefersEarlierQuotedKey(t *testing.T) {
|
||||
input := `{"functionCall":{"name":"a","arguments":"{}"},"note":"functionCall appears in prose"}`
|
||||
|
||||
got := findQuotedFunctionCallKeyStart(input)
|
||||
want := 1
|
||||
if got != want {
|
||||
t.Fatalf("findQuotedFunctionCallKeyStart() = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
@@ -2,48 +2,6 @@ package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
if start < 0 || start >= len(text) || text[start] != '{' {
|
||||
return "", 0, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
end := i + 1
|
||||
return text[start:end], end, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func trimWrappingJSONFence(prefix, suffix string) (string, string) {
|
||||
trimmedPrefix := strings.TrimRight(prefix, " \t\r\n")
|
||||
fenceIdx := strings.LastIndex(trimmedPrefix, "```")
|
||||
@@ -67,18 +25,3 @@ func trimWrappingJSONFence(prefix, suffix string) (string, string) {
|
||||
consumedLeading := len(suffix) - len(trimmedSuffix)
|
||||
return trimmedPrefix[:fenceIdx], suffix[consumedLeading+3:]
|
||||
}
|
||||
|
||||
func openFenceStartBefore(s string, pos int) (int, bool) {
|
||||
if pos <= 0 || pos > len(s) {
|
||||
return -1, false
|
||||
}
|
||||
segment := s[:pos]
|
||||
lastFence := strings.LastIndex(segment, "```")
|
||||
if lastFence < 0 {
|
||||
return -1, false
|
||||
}
|
||||
if strings.Count(segment, "```")%2 == 1 {
|
||||
return lastFence, true
|
||||
}
|
||||
return -1, false
|
||||
}
|
||||
|
||||
@@ -6,19 +6,21 @@ import (
|
||||
)
|
||||
|
||||
type toolStreamSieveState struct {
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
recentTextTail string
|
||||
pendingToolRaw string
|
||||
pendingToolCalls []toolcall.ParsedToolCall
|
||||
disableDeltas bool
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
toolArgsSent int
|
||||
toolArgsString bool
|
||||
toolArgsDone bool
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
codeFenceStack []int
|
||||
codeFencePendingTicks int
|
||||
codeFenceLineStart bool
|
||||
pendingToolRaw string
|
||||
pendingToolCalls []toolcall.ParsedToolCall
|
||||
disableDeltas bool
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
toolArgsSent int
|
||||
toolArgsString bool
|
||||
toolArgsDone bool
|
||||
}
|
||||
|
||||
type toolStreamEvent struct {
|
||||
@@ -33,9 +35,6 @@ type toolCallDelta struct {
|
||||
Arguments string
|
||||
}
|
||||
|
||||
// Keep in sync with JS TOOL_SIEVE_CONTEXT_TAIL_LIMIT.
|
||||
const toolSieveContextTailLimit = 2048
|
||||
|
||||
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
||||
s.disableDeltas = false
|
||||
s.toolNameSent = false
|
||||
@@ -47,19 +46,112 @@ func (s *toolStreamSieveState) resetIncrementalToolState() {
|
||||
}
|
||||
|
||||
func (s *toolStreamSieveState) noteText(content string) {
|
||||
if content == "" {
|
||||
if !hasMeaningfulText(content) {
|
||||
return
|
||||
}
|
||||
s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
|
||||
updateCodeFenceState(s, content)
|
||||
}
|
||||
|
||||
func appendTail(prev, next string, max int) string {
|
||||
if max <= 0 {
|
||||
return ""
|
||||
}
|
||||
combined := prev + next
|
||||
if len(combined) <= max {
|
||||
return combined
|
||||
}
|
||||
return combined[len(combined)-max:]
|
||||
func hasMeaningfulText(text string) bool {
|
||||
return strings.TrimSpace(text) != ""
|
||||
}
|
||||
|
||||
func insideCodeFenceWithState(state *toolStreamSieveState, text string) bool {
|
||||
if state == nil {
|
||||
return insideCodeFence(text)
|
||||
}
|
||||
simulated := simulateCodeFenceState(
|
||||
state.codeFenceStack,
|
||||
state.codeFencePendingTicks,
|
||||
state.codeFenceLineStart,
|
||||
text,
|
||||
)
|
||||
return len(simulated.stack) > 0
|
||||
}
|
||||
|
||||
func insideCodeFence(text string) bool {
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
return len(simulateCodeFenceState(nil, 0, true, text).stack) > 0
|
||||
}
|
||||
|
||||
func updateCodeFenceState(state *toolStreamSieveState, text string) {
|
||||
if state == nil || !hasMeaningfulText(text) {
|
||||
return
|
||||
}
|
||||
next := simulateCodeFenceState(
|
||||
state.codeFenceStack,
|
||||
state.codeFencePendingTicks,
|
||||
state.codeFenceLineStart,
|
||||
text,
|
||||
)
|
||||
state.codeFenceStack = next.stack
|
||||
state.codeFencePendingTicks = next.pendingTicks
|
||||
state.codeFenceLineStart = next.lineStart
|
||||
}
|
||||
|
||||
type codeFenceSimulation struct {
|
||||
stack []int
|
||||
pendingTicks int
|
||||
lineStart bool
|
||||
}
|
||||
|
||||
func simulateCodeFenceState(stack []int, pendingTicks int, lineStart bool, text string) codeFenceSimulation {
|
||||
chunk := text
|
||||
nextStack := append([]int(nil), stack...)
|
||||
ticks := pendingTicks
|
||||
atLineStart := lineStart
|
||||
|
||||
flushTicks := func() {
|
||||
if ticks > 0 {
|
||||
if atLineStart && ticks >= 3 {
|
||||
applyFenceMarker(&nextStack, ticks)
|
||||
}
|
||||
atLineStart = false
|
||||
ticks = 0
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(chunk); i++ {
|
||||
ch := chunk[i]
|
||||
if ch == '`' {
|
||||
ticks++
|
||||
continue
|
||||
}
|
||||
flushTicks()
|
||||
switch ch {
|
||||
case '\n', '\r':
|
||||
atLineStart = true
|
||||
case ' ', '\t':
|
||||
if atLineStart {
|
||||
continue
|
||||
}
|
||||
atLineStart = false
|
||||
default:
|
||||
atLineStart = false
|
||||
}
|
||||
}
|
||||
|
||||
return codeFenceSimulation{
|
||||
stack: nextStack,
|
||||
pendingTicks: ticks,
|
||||
lineStart: atLineStart,
|
||||
}
|
||||
}
|
||||
|
||||
func applyFenceMarker(stack *[]int, ticks int) {
|
||||
if stack == nil || ticks <= 0 {
|
||||
return
|
||||
}
|
||||
if len(*stack) == 0 {
|
||||
*stack = append(*stack, ticks)
|
||||
return
|
||||
}
|
||||
top := (*stack)[len(*stack)-1]
|
||||
if ticks >= top {
|
||||
*stack = (*stack)[:len(*stack)-1]
|
||||
return
|
||||
}
|
||||
*stack = append(*stack, ticks)
|
||||
}
|
||||
|
||||
@@ -26,8 +26,8 @@ var xmlToolCallTagPairs = []struct{ open, close string }{
|
||||
{"<invoke", "</invoke>"},
|
||||
{"<tool_use", "</tool_use>"},
|
||||
// Agent-style: these are XML "tool call" patterns from coding agents.
|
||||
// They get captured → parsed. If parsing fails, the block is consumed
|
||||
// (swallowed) to prevent raw XML from leaking to the client.
|
||||
// They get captured → parsed. If parsing fails, the raw XML is preserved
|
||||
// so the caller can still see the original text.
|
||||
{"<attempt_completion", "</attempt_completion>"},
|
||||
{"<ask_followup_question", "</ask_followup_question>"},
|
||||
{"<new_task", "</new_task>"},
|
||||
@@ -73,31 +73,12 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
|
||||
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
|
||||
return prefixPart, parsed, suffixPart, true
|
||||
}
|
||||
// If this block does not look like an executable tool-call payload,
|
||||
// pass it through as normal content (e.g. user-requested XML snippets).
|
||||
if !looksLikeExecutableXMLToolCallBlock(xmlBlock, pair.open) {
|
||||
return prefixPart + xmlBlock, nil, suffixPart, true
|
||||
}
|
||||
// Looks like XML tool syntax but failed to parse — consume it to avoid leak.
|
||||
return prefixPart, nil, suffixPart, true
|
||||
// If this block failed to become a tool call, pass it through as text.
|
||||
return prefixPart + xmlBlock, nil, suffixPart, true
|
||||
}
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
func looksLikeExecutableXMLToolCallBlock(xmlBlock, openTag string) bool {
|
||||
lower := strings.ToLower(xmlBlock)
|
||||
// Agent wrapper tags are always treated as internal tool-call wrappers.
|
||||
switch openTag {
|
||||
case "<attempt_completion", "<ask_followup_question", "<new_task":
|
||||
return true
|
||||
}
|
||||
return strings.Contains(lower, "<tool_name") ||
|
||||
strings.Contains(lower, "<parameters") ||
|
||||
strings.Contains(lower, `"tool"`) ||
|
||||
strings.Contains(lower, `"tool_name"`) ||
|
||||
strings.Contains(lower, `"name"`)
|
||||
}
|
||||
|
||||
// hasOpenXMLToolTag returns true if captured text contains an XML tool opening tag
|
||||
// whose SPECIFIC closing tag has not appeared yet.
|
||||
func hasOpenXMLToolTag(captured string) bool {
|
||||
@@ -137,32 +118,3 @@ func findPartialXMLToolTagStart(s string) int {
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// looksLikeXMLToolTagFragment returns true if s looks like a fragment from a
|
||||
// split XML tool call tag — for example "tool_calls>" or "/tool_call>\n".
|
||||
// These fragments arise when '<' was consumed separately and the tail remains.
|
||||
func looksLikeXMLToolTagFragment(s string) bool {
|
||||
trimmed := strings.TrimSpace(s)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
// Check for closing tag tails like "tool_calls>" or "/tool_calls>"
|
||||
fragments := []string{
|
||||
"tool_calls>", "tool_call>", "/tool_calls>", "/tool_call>",
|
||||
"function_calls>", "function_call>", "/function_calls>", "/function_call>",
|
||||
"invoke>", "/invoke>", "tool_use>", "/tool_use>",
|
||||
"tool_name>", "/tool_name>", "parameters>", "/parameters>",
|
||||
// Agent-style tag fragments
|
||||
"attempt_completion>", "/attempt_completion>",
|
||||
"ask_followup_question>", "/ask_followup_question>",
|
||||
"new_task>", "/new_task>",
|
||||
"result>", "/result>",
|
||||
}
|
||||
for _, f := range fragments {
|
||||
if strings.Contains(lower, f) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -42,6 +42,49 @@ func TestProcessToolSieveInterceptsXMLToolCallWithoutLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveHandlesLongXMLToolCall(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
const toolName = "write_to_file"
|
||||
payload := strings.Repeat("x", 4096)
|
||||
splitAt := len(payload) / 2
|
||||
chunks := []string{
|
||||
"<tool_calls>\n <tool_call>\n <tool_name>" + toolName + "</tool_name>\n <parameters>\n <content><![CDATA[",
|
||||
payload[:splitAt],
|
||||
payload[splitAt:],
|
||||
"]]></content>\n </parameters>\n </tool_call>\n</tool_calls>",
|
||||
}
|
||||
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
events = append(events, processToolSieveChunk(&state, c, []string{toolName})...)
|
||||
}
|
||||
events = append(events, flushToolSieve(&state, []string{toolName})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
var gotPayload any
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent.WriteString(evt.Content)
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 && gotPayload == nil {
|
||||
gotPayload = evt.ToolCalls[0].Input["content"]
|
||||
}
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if toolCalls != 1 {
|
||||
t.Fatalf("expected one long XML tool call, got %d events=%#v", toolCalls, events)
|
||||
}
|
||||
if textContent.Len() != 0 {
|
||||
t.Fatalf("expected no leaked text for long XML tool call, got %q", textContent.String())
|
||||
}
|
||||
got, _ := gotPayload.(string)
|
||||
if got != payload {
|
||||
t.Fatalf("expected long XML payload to survive intact, got len=%d want=%d", len(got), len(payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveXMLWithLeadingText(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Model outputs some prose then an XML tool call.
|
||||
@@ -121,6 +164,105 @@ func TestProcessToolSieveNonToolXMLKeepsSuffixForToolParsing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSievePassesThroughMalformedExecutableXMLBlock(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
chunk := `<tool_call><parameters>{"path":"README.md"}</parameters></tool_call>`
|
||||
events := processToolSieveChunk(&state, chunk, []string{"read_file"})
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
for _, evt := range events {
|
||||
textContent.WriteString(evt.Content)
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if toolCalls != 0 {
|
||||
t.Fatalf("expected malformed executable-looking XML to stay text, got %d events=%#v", toolCalls, events)
|
||||
}
|
||||
if textContent.String() != chunk {
|
||||
t.Fatalf("expected malformed executable-looking XML to pass through unchanged, got %q", textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSievePassesThroughFencedXMLToolCallExamples(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
input := strings.Join([]string{
|
||||
"Before first example.\n```",
|
||||
"xml\n<tool_call><tool_name>read_file</tool_name><parameters>{\"path\":\"README.md\"}</parameters></tool_call>\n```\n",
|
||||
"Between examples.\n```xml\n",
|
||||
"<tool_call><tool_name>search</tool_name><parameters>{\"q\":\"golang\"}</parameters></tool_call>\n",
|
||||
"```\nAfter examples.",
|
||||
}, "")
|
||||
|
||||
chunks := []string{
|
||||
"Before first example.\n```",
|
||||
"xml\n<tool_call><tool_name>read_file</tool_name><parameters>{\"path\":\"README.md\"}</parameters></tool_call>\n```\n",
|
||||
"Between examples.\n```xml\n",
|
||||
"<tool_call><tool_name>search</tool_name><parameters>{\"q\":\"golang\"}</parameters></tool_call>\n",
|
||||
"```\nAfter examples.",
|
||||
}
|
||||
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
events = append(events, processToolSieveChunk(&state, c, []string{"read_file", "search"})...)
|
||||
}
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file", "search"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent.WriteString(evt.Content)
|
||||
}
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if toolCalls != 0 {
|
||||
t.Fatalf("expected fenced XML examples to stay text, got %d tool calls events=%#v", toolCalls, events)
|
||||
}
|
||||
if textContent.String() != input {
|
||||
t.Fatalf("expected fenced XML examples to pass through unchanged, got %q", textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveKeepsPartialXMLTagInsideFencedExample(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
input := strings.Join([]string{
|
||||
"Example:\n```xml\n<tool_ca",
|
||||
"ll><tool_name>read_file</tool_name><parameters>{\"path\":\"README.md\"}</parameters></tool_call>\n```\n",
|
||||
"Done.",
|
||||
}, "")
|
||||
|
||||
chunks := []string{
|
||||
"Example:\n```xml\n<tool_ca",
|
||||
"ll><tool_name>read_file</tool_name><parameters>{\"path\":\"README.md\"}</parameters></tool_call>\n```\n",
|
||||
"Done.",
|
||||
}
|
||||
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
events = append(events, processToolSieveChunk(&state, c, []string{"read_file"})...)
|
||||
}
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent.WriteString(evt.Content)
|
||||
}
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if toolCalls != 0 {
|
||||
t.Fatalf("expected partial fenced XML to stay text, got %d tool calls events=%#v", toolCalls, events)
|
||||
}
|
||||
if textContent.String() != input {
|
||||
t.Fatalf("expected partial fenced XML to pass through unchanged, got %q", textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSievePartialXMLTagHeldBack(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Chunk ends with a partial XML tool tag.
|
||||
@@ -147,15 +289,16 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
|
||||
want int
|
||||
}{
|
||||
{"tool_calls_tag", "some text <tool_calls>\n", 10},
|
||||
{"gemini_function_call_json", `some text {"functionCall":{"name":"search","args":{"q":"latest"}}}`, 10},
|
||||
{"tool_call_tag", "prefix <tool_call>\n", 7},
|
||||
{"invoke_tag", "text <invoke name=\"foo\">body</invoke>", 5},
|
||||
{"xml_inside_code_fence", "```xml\n<tool_call><tool_name>read_file</tool_name></tool_call>\n```", -1},
|
||||
{"function_call_tag", "<function_call name=\"foo\">body</function_call>", 0},
|
||||
{"no_xml", "just plain text", -1},
|
||||
{"gemini_json_no_detect", `some text {"functionCall":{"name":"search"}}`, -1},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := findToolSegmentStart(tc.input)
|
||||
got := findToolSegmentStart(nil, tc.input)
|
||||
if got != tc.want {
|
||||
t.Fatalf("findToolSegmentStart(%q) = %d, want %d", tc.input, got, tc.want)
|
||||
}
|
||||
@@ -163,81 +306,6 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolSegmentStartIgnoresFunctionCallProse(t *testing.T) {
|
||||
input := "Please explain the functionCall API field and how clients should parse it."
|
||||
if got := findToolSegmentStart(input); got != -1 {
|
||||
t.Fatalf("expected no tool segment start for prose, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolSegmentStartDetectsQuotedFunctionCallKey(t *testing.T) {
|
||||
input := `prefix {"functionCall": {"name":"search_web","args":{"query":"x"}}}`
|
||||
want := strings.Index(input, "{")
|
||||
if got := findToolSegmentStart(input); got != want {
|
||||
t.Fatalf("expected JSON object start %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolSegmentStartDetectsLooseFunctionCallKey(t *testing.T) {
|
||||
input := `prefix {functionCall: {"name":"search_web","args":{"query":"x"}}}`
|
||||
want := strings.Index(input, "{")
|
||||
if got := findToolSegmentStart(input); got != want {
|
||||
t.Fatalf("expected JSON object start %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolSegmentStartPrefersQuotedFunctionCallOverEarlierBareProse(t *testing.T) {
|
||||
input := `prefix {note} functionCall: docs hint {"functionCall":{"name":"search_web","args":{"query":"x"}}}`
|
||||
want := strings.Index(input, `{"functionCall"`)
|
||||
if got := findToolSegmentStart(input); got != want {
|
||||
t.Fatalf("expected quoted functionCall JSON start %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolSegmentStartIgnoresLooseFunctionCallProse(t *testing.T) {
|
||||
input := "Please explain why functionCall: is used in documentation examples."
|
||||
if got := findToolSegmentStart(input); got != -1 {
|
||||
t.Fatalf("expected no tool segment start for prose, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveDoesNotBufferFunctionCallProse(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
chunk := "Please explain the functionCall API field and keep streaming this sentence."
|
||||
events := processToolSieveChunk(&state, chunk, []string{"search_web"})
|
||||
var text string
|
||||
for _, evt := range events {
|
||||
text += evt.Content
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
t.Fatalf("expected no tool calls for prose, got %#v", evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
if text != chunk {
|
||||
t.Fatalf("expected prose to pass through immediately, got %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveDetectsGeminiFunctionCallPayload(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
events := processToolSieveChunk(&state, `{"functionCall":{"name":"search_web","args":{"query":"latest"}}}`, []string{"search_web"})
|
||||
events = append(events, flushToolSieve(&state, []string{"search_web"})...)
|
||||
|
||||
var textContent string
|
||||
var toolCalls int
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent += evt.Content
|
||||
}
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
if toolCalls != 1 {
|
||||
t.Fatalf("expected one tool call from functionCall payload, got events=%#v", events)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(textContent), "functioncall") {
|
||||
t.Fatalf("functionCall json leaked into text content: %q", textContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPartialXMLToolTagStart(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
@@ -344,8 +412,8 @@ func TestProcessToolSieveTokenByTokenXMLNoLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Test that flushToolSieve on incomplete XML does NOT leak the raw XML content.
|
||||
func TestFlushToolSieveIncompleteXMLDoesNotLeak(t *testing.T) {
|
||||
// Test that flushToolSieve on incomplete XML falls back to raw text.
|
||||
func TestFlushToolSieveIncompleteXMLFallsBackToText(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// XML block starts but stream ends before completion.
|
||||
chunks := []string{
|
||||
@@ -367,8 +435,8 @@ func TestFlushToolSieveIncompleteXMLDoesNotLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(textContent, "<tool_call") {
|
||||
t.Fatalf("incomplete XML leaked on flush: %q", textContent)
|
||||
if textContent != strings.Join(chunks, "") {
|
||||
t.Fatalf("expected incomplete XML to fall back to raw text, got %q", textContent)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -405,10 +473,10 @@ func TestOpeningXMLTagNotLeakedAsContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveInterceptsAttemptCompletionLeak(t *testing.T) {
|
||||
func TestProcessToolSieveFallsBackToRawAttemptCompletion(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Simulate an agent outputting attempt_completion XML tag
|
||||
// which shouldn't leak to text output, even if it fails to parse as a valid tool.
|
||||
// Simulate an agent outputting attempt_completion XML tag.
|
||||
// If it does not parse as a tool call, it should fall back to raw text.
|
||||
chunks := []string{
|
||||
"Done with task.\n",
|
||||
"<attempt_completion>\n",
|
||||
@@ -432,7 +500,7 @@ func TestProcessToolSieveInterceptsAttemptCompletionLeak(t *testing.T) {
|
||||
t.Fatalf("expected leading text to be emitted, got %q", textContent)
|
||||
}
|
||||
|
||||
if strings.Contains(textContent, "<attempt_completion>") || strings.Contains(textContent, "result>") {
|
||||
t.Fatalf("agent XML tag content leaked to text: %q", textContent)
|
||||
if textContent != strings.Join(chunks, "") {
|
||||
t.Fatalf("expected agent XML to fall back to raw text, got %q", textContent)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,14 +2,26 @@ package openai
|
||||
|
||||
import "net/http"
|
||||
|
||||
func writeUpstreamEmptyOutputError(w http.ResponseWriter, thinking, text string, contentFilter bool) bool {
|
||||
if thinking != "" || text != "" {
|
||||
func shouldWriteUpstreamEmptyOutputError(text string) bool {
|
||||
return text == ""
|
||||
}
|
||||
|
||||
func upstreamEmptyOutputDetail(contentFilter bool, text, thinking string) (int, string, string) {
|
||||
_ = text
|
||||
if contentFilter {
|
||||
return http.StatusBadRequest, "Upstream content filtered the response and returned no output.", "content_filter"
|
||||
}
|
||||
if thinking != "" {
|
||||
return http.StatusTooManyRequests, "Upstream model returned reasoning without visible output.", "upstream_empty_output"
|
||||
}
|
||||
return http.StatusTooManyRequests, "Upstream model returned empty output.", "upstream_empty_output"
|
||||
}
|
||||
|
||||
func writeUpstreamEmptyOutputError(w http.ResponseWriter, text string, contentFilter bool) bool {
|
||||
if !shouldWriteUpstreamEmptyOutputError(text) {
|
||||
return false
|
||||
}
|
||||
if contentFilter {
|
||||
writeOpenAIErrorWithCode(w, http.StatusBadRequest, "Upstream content filtered the response and returned no output.", "content_filter")
|
||||
return true
|
||||
}
|
||||
writeOpenAIErrorWithCode(w, http.StatusTooManyRequests, "Upstream model returned empty output.", "upstream_empty_output")
|
||||
status, message, code := upstreamEmptyOutputDetail(contentFilter, text, "")
|
||||
writeOpenAIErrorWithCode(w, status, message, code)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -52,6 +52,10 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if err := h.preprocessInlineFileInputs(r.Context(), a, req); err != nil {
|
||||
writeOpenAIInlineFileError(w, err)
|
||||
return
|
||||
}
|
||||
if !util.ToBool(req["stream"]) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
|
||||
return
|
||||
|
||||
@@ -33,6 +33,8 @@ type ConfigStore interface {
|
||||
RuntimeGlobalMaxInflight(defaultSize int) int
|
||||
RuntimeTokenRefreshIntervalHours() int
|
||||
AutoDeleteMode() string
|
||||
HistorySplitEnabled() bool
|
||||
HistorySplitTriggerAfterTurns() int
|
||||
CompatStripReferenceMarkers() bool
|
||||
AutoDeleteSessions() bool
|
||||
}
|
||||
|
||||
@@ -2,13 +2,16 @@ package admin
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/chathistory"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigStore
|
||||
Pool PoolController
|
||||
DS DeepSeekCaller
|
||||
OpenAI OpenAIChatCaller
|
||||
Store ConfigStore
|
||||
Pool PoolController
|
||||
DS DeepSeekCaller
|
||||
OpenAI OpenAIChatCaller
|
||||
ChatHistory *chathistory.Store
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
@@ -25,6 +28,7 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
pr.Post("/config/import", h.configImport)
|
||||
pr.Get("/config/export", h.configExport)
|
||||
pr.Post("/keys", h.addKey)
|
||||
pr.Put("/keys/{key}", h.updateKey)
|
||||
pr.Delete("/keys/{key}", h.deleteKey)
|
||||
pr.Get("/proxies", h.listProxies)
|
||||
pr.Post("/proxies", h.addProxy)
|
||||
@@ -33,6 +37,7 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
pr.Post("/proxies/test", h.testProxy)
|
||||
pr.Get("/accounts", h.listAccounts)
|
||||
pr.Post("/accounts", h.addAccount)
|
||||
pr.Put("/accounts/{identifier}", h.updateAccount)
|
||||
pr.Delete("/accounts/{identifier}", h.deleteAccount)
|
||||
pr.Put("/accounts/{identifier}/proxy", h.updateAccountProxy)
|
||||
pr.Get("/queue/status", h.queueStatus)
|
||||
@@ -50,6 +55,11 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
pr.Get("/export", h.exportConfig)
|
||||
pr.Get("/dev/captures", h.getDevCaptures)
|
||||
pr.Delete("/dev/captures", h.clearDevCaptures)
|
||||
pr.Get("/chat-history", h.getChatHistory)
|
||||
pr.Get("/chat-history/{id}", h.getChatHistoryItem)
|
||||
pr.Delete("/chat-history", h.clearChatHistory)
|
||||
pr.Delete("/chat-history/{id}", h.deleteChatHistoryItem)
|
||||
pr.Put("/chat-history/settings", h.updateChatHistorySettings)
|
||||
pr.Get("/version", h.getVersion)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -21,8 +21,8 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
if pageSize < 1 {
|
||||
pageSize = 1
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
if pageSize > 5000 {
|
||||
pageSize = 5000
|
||||
}
|
||||
accounts := h.Store.Snapshot().Accounts
|
||||
reverseAccounts(accounts)
|
||||
@@ -32,6 +32,8 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
for _, acc := range accounts {
|
||||
id := strings.ToLower(acc.Identifier())
|
||||
if strings.Contains(id, q) ||
|
||||
strings.Contains(strings.ToLower(acc.Name), q) ||
|
||||
strings.Contains(strings.ToLower(acc.Remark), q) ||
|
||||
strings.Contains(strings.ToLower(acc.Email), q) ||
|
||||
strings.Contains(strings.ToLower(acc.Mobile), q) {
|
||||
filtered = append(filtered, acc)
|
||||
@@ -66,6 +68,8 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
items = append(items, map[string]any{
|
||||
"identifier": acc.Identifier(),
|
||||
"name": acc.Name,
|
||||
"remark": acc.Remark,
|
||||
"email": acc.Email,
|
||||
"mobile": acc.Mobile,
|
||||
"proxy_id": acc.ProxyID,
|
||||
@@ -112,6 +116,46 @@ func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||
}
|
||||
|
||||
func (h *Handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
identifier := chi.URLParam(r, "identifier")
|
||||
if decoded, err := url.PathUnescape(identifier); err == nil {
|
||||
identifier = decoded
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
name, nameOK := fieldStringOptional(req, "name")
|
||||
remark, remarkOK := fieldStringOptional(req, "remark")
|
||||
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
for i, acc := range c.Accounts {
|
||||
if !accountMatchesIdentifier(acc, identifier) {
|
||||
continue
|
||||
}
|
||||
if nameOK {
|
||||
c.Accounts[i].Name = name
|
||||
}
|
||||
if remarkOK {
|
||||
c.Accounts[i].Remark = remark
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return newRequestError("账号不存在")
|
||||
})
|
||||
if err != nil {
|
||||
if detail, ok := requestErrorDetail(err); ok {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": detail})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
identifier := chi.URLParam(r, "identifier")
|
||||
if decoded, err := url.PathUnescape(identifier); err == nil {
|
||||
|
||||
88
internal/admin/handler_accounts_crud_test.go
Normal file
88
internal/admin/handler_accounts_crud_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func TestListAccountsPageSizeCapIs5000(t *testing.T) {
|
||||
accounts := make([]string, 0, 150)
|
||||
for i := range 150 {
|
||||
accounts = append(accounts, fmt.Sprintf(`{"email":"u%d@example.com","password":"pwd"}`, i))
|
||||
}
|
||||
raw := fmt.Sprintf(`{"accounts":[%s]}`, strings.Join(accounts, ","))
|
||||
router := newHTTPAdminHarness(t, raw, &testingDSMock{})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, adminReq(http.MethodGet, "/accounts?page=1&page_size=200", nil))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
items, _ := payload["items"].([]any)
|
||||
if len(items) != 150 {
|
||||
t.Fatalf("expected all 150 accounts with page_size=200, got %d", len(items))
|
||||
}
|
||||
if ps, _ := payload["page_size"].(float64); ps != 200 {
|
||||
t.Fatalf("expected page_size=200 in response, got %v", payload["page_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAccountsPageSizeAbove5000ClampedTo5000(t *testing.T) {
|
||||
router := newHTTPAdminHarness(t, `{"accounts":[{"email":"u@example.com","password":"pwd"}]}`, &testingDSMock{})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, adminReq(http.MethodGet, "/accounts?page=1&page_size=9999", nil))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if ps, _ := payload["page_size"].(float64); ps != 5000 {
|
||||
t.Fatalf("expected page_size clamped to 5000, got %v", payload["page_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAccountMetadataPreservesCredentials(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[{"email":"u@example.com","name":"old name","remark":"old remark","password":"secret"}]
|
||||
}`)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Put("/admin/accounts/{identifier}", h.updateAccount)
|
||||
|
||||
body := []byte(`{"name":"new name","remark":"new remark"}`)
|
||||
req := httptest.NewRequest(http.MethodPut, "/admin/accounts/u@example.com", strings.NewReader(string(body)))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.Accounts) != 1 {
|
||||
t.Fatalf("unexpected accounts after update: %#v", snap.Accounts)
|
||||
}
|
||||
acc := snap.Accounts[0]
|
||||
if acc.Email != "u@example.com" {
|
||||
t.Fatalf("identifier changed unexpectedly: %#v", acc)
|
||||
}
|
||||
if acc.Name != "new name" || acc.Remark != "new remark" {
|
||||
t.Fatalf("metadata update did not persist: %#v", acc)
|
||||
}
|
||||
if acc.Password != "secret" {
|
||||
t.Fatalf("password should be preserved, got %#v", acc)
|
||||
}
|
||||
}
|
||||
134
internal/admin/handler_chat_history.go
Normal file
134
internal/admin/handler_chat_history.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/chathistory"
|
||||
)
|
||||
|
||||
func (h *Handler) getChatHistory(w http.ResponseWriter, r *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{
|
||||
"detail": err.Error(),
|
||||
"path": store.Path(),
|
||||
})
|
||||
return
|
||||
}
|
||||
etag := chathistory.ListETag(snapshot.Revision)
|
||||
w.Header().Set("ETag", etag)
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
if strings.TrimSpace(r.Header.Get("If-None-Match")) == etag {
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"version": snapshot.Version,
|
||||
"limit": snapshot.Limit,
|
||||
"revision": snapshot.Revision,
|
||||
"items": snapshot.Items,
|
||||
"path": store.Path(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) getChatHistoryItem(w http.ResponseWriter, r *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(chi.URLParam(r, "id"))
|
||||
if id == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "history id is required"})
|
||||
return
|
||||
}
|
||||
item, err := store.Get(id)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(strings.ToLower(err.Error()), "not found") {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
writeJSON(w, status, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
etag := chathistory.DetailETag(item.ID, item.Revision)
|
||||
w.Header().Set("ETag", etag)
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
if strings.TrimSpace(r.Header.Get("If-None-Match")) == etag {
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"item": item,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) clearChatHistory(w http.ResponseWriter, _ *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
if err := store.Clear(); err != nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": err.Error(), "path": store.Path()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteChatHistoryItem(w http.ResponseWriter, r *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(chi.URLParam(r, "id"))
|
||||
if id == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "history id is required"})
|
||||
return
|
||||
}
|
||||
if err := store.Delete(id); err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(strings.ToLower(err.Error()), "not found") {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
writeJSON(w, status, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true})
|
||||
}
|
||||
|
||||
func (h *Handler) updateChatHistorySettings(w http.ResponseWriter, r *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
var body struct {
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
snapshot, err := store.SetLimit(body.Limit)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"limit": snapshot.Limit,
|
||||
"revision": snapshot.Revision,
|
||||
"items": snapshot.Items,
|
||||
})
|
||||
}
|
||||
176
internal/admin/handler_chat_history_test.go
Normal file
176
internal/admin/handler_chat_history_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func newChatHistoryAdminHarness(t *testing.T) (*Handler, *chathistory.Store) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatalf("write config failed: %v", err)
|
||||
}
|
||||
t.Setenv("DS2API_CONFIG_PATH", configPath)
|
||||
t.Setenv("DS2API_ADMIN_KEY", "admin")
|
||||
t.Setenv("DS2API_CONFIG_JSON", "")
|
||||
store, err := config.LoadStoreWithError()
|
||||
if err != nil {
|
||||
t.Fatalf("load config store failed: %v", err)
|
||||
}
|
||||
historyStore := chathistory.New(filepath.Join(dir, "chat_history.json"))
|
||||
return &Handler{Store: store, ChatHistory: historyStore}, historyStore
|
||||
}
|
||||
|
||||
func TestGetChatHistoryAndUpdateSettings(t *testing.T) {
|
||||
h, historyStore := newChatHistoryAdminHarness(t)
|
||||
entry, err := historyStore.Start(chathistory.StartParams{
|
||||
CallerID: "caller:test",
|
||||
AccountID: "user@example.com",
|
||||
Model: "deepseek-chat",
|
||||
UserInput: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start history failed: %v", err)
|
||||
}
|
||||
if _, err := historyStore.Update(entry.ID, chathistory.UpdateParams{
|
||||
Status: "success",
|
||||
Content: "world",
|
||||
Completed: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("update history failed: %v", err)
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/chat-history", nil)
|
||||
req.Header.Set("Authorization", "Bearer admin")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("decode payload failed: %v", err)
|
||||
}
|
||||
items, _ := payload["items"].([]any)
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("expected one history item, got %#v", payload)
|
||||
}
|
||||
if rec.Header().Get("ETag") == "" {
|
||||
t.Fatalf("expected list etag header")
|
||||
}
|
||||
|
||||
notModifiedReq := httptest.NewRequest(http.MethodGet, "/chat-history", nil)
|
||||
notModifiedReq.Header.Set("Authorization", "Bearer admin")
|
||||
notModifiedReq.Header.Set("If-None-Match", rec.Header().Get("ETag"))
|
||||
notModifiedRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(notModifiedRec, notModifiedReq)
|
||||
if notModifiedRec.Code != http.StatusNotModified {
|
||||
t.Fatalf("expected 304, got %d body=%s", notModifiedRec.Code, notModifiedRec.Body.String())
|
||||
}
|
||||
|
||||
itemReq := httptest.NewRequest(http.MethodGet, "/chat-history/"+entry.ID, nil)
|
||||
itemReq.Header.Set("Authorization", "Bearer admin")
|
||||
itemRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(itemRec, itemReq)
|
||||
if itemRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected item 200, got %d body=%s", itemRec.Code, itemRec.Body.String())
|
||||
}
|
||||
if itemRec.Header().Get("ETag") == "" {
|
||||
t.Fatalf("expected detail etag header")
|
||||
}
|
||||
|
||||
updateReq := httptest.NewRequest(http.MethodPut, "/chat-history/settings", bytes.NewReader([]byte(`{"limit":10}`)))
|
||||
updateReq.Header.Set("Authorization", "Bearer admin")
|
||||
updateRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(updateRec, updateReq)
|
||||
if updateRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 from settings update, got %d body=%s", updateRec.Code, updateRec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if snapshot.Limit != 10 {
|
||||
t.Fatalf("expected limit=10, got %d", snapshot.Limit)
|
||||
}
|
||||
|
||||
disableReq := httptest.NewRequest(http.MethodPut, "/chat-history/settings", bytes.NewReader([]byte(`{"limit":0}`)))
|
||||
disableReq.Header.Set("Authorization", "Bearer admin")
|
||||
disableRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(disableRec, disableReq)
|
||||
if disableRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 from disable update, got %d body=%s", disableRec.Code, disableRec.Body.String())
|
||||
}
|
||||
snapshot, err = historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot after disable failed: %v", err)
|
||||
}
|
||||
if snapshot.Limit != chathistory.DisabledLimit {
|
||||
t.Fatalf("expected limit=0, got %d", snapshot.Limit)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected history preserved when disabled, got %d", len(snapshot.Items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAndClearChatHistory(t *testing.T) {
|
||||
h, historyStore := newChatHistoryAdminHarness(t)
|
||||
entryA, err := historyStore.Start(chathistory.StartParams{UserInput: "a"})
|
||||
if err != nil {
|
||||
t.Fatalf("start A failed: %v", err)
|
||||
}
|
||||
if _, err := historyStore.Start(chathistory.StartParams{UserInput: "b"}); err != nil {
|
||||
t.Fatalf("start B failed: %v", err)
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, "/chat-history/"+entryA.ID, nil)
|
||||
deleteReq.Header.Set("Authorization", "Bearer admin")
|
||||
deleteRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(deleteRec, deleteReq)
|
||||
if deleteRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected delete 200, got %d body=%s", deleteRec.Code, deleteRec.Body.String())
|
||||
}
|
||||
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one item after delete, got %d", len(snapshot.Items))
|
||||
}
|
||||
|
||||
clearReq := httptest.NewRequest(http.MethodDelete, "/chat-history", nil)
|
||||
clearReq.Header.Set("Authorization", "Bearer admin")
|
||||
clearRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(clearRec, clearReq)
|
||||
if clearRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected clear 200, got %d body=%s", clearRec.Code, clearRec.Body.String())
|
||||
}
|
||||
|
||||
snapshot, err = historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 0 {
|
||||
t.Fatalf("expected empty items after clear, got %d", len(snapshot.Items))
|
||||
}
|
||||
}
|
||||
@@ -53,25 +53,12 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
next.Accounts = normalizeAndDedupeAccounts(next.Accounts)
|
||||
next.VercelSyncHash = c.VercelSyncHash
|
||||
next.VercelSyncTime = c.VercelSyncTime
|
||||
importedKeys = len(next.Keys)
|
||||
importedKeys = len(next.APIKeys)
|
||||
importedAccounts = len(next.Accounts)
|
||||
} else {
|
||||
existingKeys := map[string]struct{}{}
|
||||
for _, k := range next.Keys {
|
||||
existingKeys[k] = struct{}{}
|
||||
}
|
||||
for _, k := range incoming.Keys {
|
||||
key := strings.TrimSpace(k)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := existingKeys[key]; ok {
|
||||
continue
|
||||
}
|
||||
existingKeys[key] = struct{}{}
|
||||
next.Keys = append(next.Keys, key)
|
||||
importedKeys++
|
||||
}
|
||||
var changed int
|
||||
next.APIKeys, changed = mergeAPIKeysPreferStructured(next.APIKeys, incoming.APIKeys)
|
||||
importedKeys += changed
|
||||
|
||||
existingAccounts := map[string]struct{}{}
|
||||
for _, acc := range next.Accounts {
|
||||
|
||||
@@ -11,6 +11,7 @@ func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
safe := map[string]any{
|
||||
"keys": snap.Keys,
|
||||
"api_keys": snap.APIKeys,
|
||||
"accounts": []map[string]any{},
|
||||
"proxies": []map[string]any{},
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
@@ -37,6 +38,8 @@ func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
}
|
||||
accounts = append(accounts, map[string]any{
|
||||
"identifier": acc.Identifier(),
|
||||
"name": acc.Name,
|
||||
"remark": acc.Remark,
|
||||
"email": acc.Email,
|
||||
"mobile": acc.Mobile,
|
||||
"proxy_id": acc.ProxyID,
|
||||
|
||||
@@ -19,7 +19,9 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
old := h.Store.Snapshot()
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
if keys, ok := toStringSlice(req["keys"]); ok {
|
||||
if apiKeys, ok := toAPIKeys(req["api_keys"]); ok {
|
||||
c.APIKeys = apiKeys
|
||||
} else if keys, ok := toStringSlice(req["keys"]); ok {
|
||||
c.Keys = keys
|
||||
}
|
||||
if accountsRaw, ok := req["accounts"].([]any); ok {
|
||||
@@ -78,17 +80,19 @@ func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
key, _ := req["key"].(string)
|
||||
key = strings.TrimSpace(key)
|
||||
name := fieldString(req, "name")
|
||||
remark := fieldString(req, "remark")
|
||||
if key == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"})
|
||||
return
|
||||
}
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
for _, k := range c.Keys {
|
||||
if k == key {
|
||||
for _, item := range c.APIKeys {
|
||||
if item.Key == key {
|
||||
return fmt.Errorf("key 已存在")
|
||||
}
|
||||
}
|
||||
c.Keys = append(c.Keys, key)
|
||||
c.APIKeys = append(c.APIKeys, config.APIKey{Key: key, Name: name, Remark: remark})
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -98,12 +102,25 @@ func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
|
||||
key := chi.URLParam(r, "key")
|
||||
func (h *Handler) updateKey(w http.ResponseWriter, r *http.Request) {
|
||||
key := strings.TrimSpace(chi.URLParam(r, "key"))
|
||||
if key == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "key 不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
name, nameOK := fieldStringOptional(req, "name")
|
||||
remark, remarkOK := fieldStringOptional(req, "remark")
|
||||
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
idx := -1
|
||||
for i, k := range c.Keys {
|
||||
if k == key {
|
||||
for i, item := range c.APIKeys {
|
||||
if item.Key == key {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
@@ -111,7 +128,35 @@ func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("key 不存在")
|
||||
}
|
||||
c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...)
|
||||
if nameOK {
|
||||
c.APIKeys[idx].Name = name
|
||||
}
|
||||
if remarkOK {
|
||||
c.APIKeys[idx].Remark = remark
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
|
||||
key := chi.URLParam(r, "key")
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
idx := -1
|
||||
for i, item := range c.APIKeys {
|
||||
if item.Key == key {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("key 不存在")
|
||||
}
|
||||
c.APIKeys = append(c.APIKeys[:idx], c.APIKeys[idx+1:]...)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -129,20 +174,23 @@ func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
importedKeys, importedAccounts := 0, 0
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
if apiKeys, ok := toAPIKeys(req["api_keys"]); ok {
|
||||
var changed int
|
||||
c.APIKeys, changed = mergeAPIKeysPreferStructured(c.APIKeys, apiKeys)
|
||||
importedKeys += changed
|
||||
}
|
||||
if keys, ok := req["keys"].([]any); ok {
|
||||
existing := map[string]bool{}
|
||||
for _, k := range c.Keys {
|
||||
existing[k] = true
|
||||
}
|
||||
legacy := make([]config.APIKey, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
key := strings.TrimSpace(fmt.Sprintf("%v", k))
|
||||
if key == "" || existing[key] {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
c.Keys = append(c.Keys, key)
|
||||
existing[key] = true
|
||||
importedKeys++
|
||||
legacy = append(legacy, config.APIKey{Key: key})
|
||||
}
|
||||
var changed int
|
||||
c.APIKeys, changed = mergeAPIKeysPreferStructured(c.APIKeys, legacy)
|
||||
importedKeys += changed
|
||||
}
|
||||
if accounts, ok := req["accounts"].([]any); ok {
|
||||
existing := map[string]bool{}
|
||||
|
||||
76
internal/admin/handler_keys_test.go
Normal file
76
internal/admin/handler_keys_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func TestKeyEndpointsPreserveStructuredMetadata(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"api_keys":[{"key":"k1","name":"primary","remark":"prod"}]
|
||||
}`)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Post("/admin/keys", h.addKey)
|
||||
r.Put("/admin/keys/{key}", h.updateKey)
|
||||
r.Delete("/admin/keys/{key}", h.deleteKey)
|
||||
|
||||
addBody := []byte(`{"key":"k2","name":"secondary","remark":"staging"}`)
|
||||
addReq := httptest.NewRequest(http.MethodPost, "/admin/keys", bytes.NewReader(addBody))
|
||||
addRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(addRec, addReq)
|
||||
if addRec.Code != http.StatusOK {
|
||||
t.Fatalf("add status=%d body=%s", addRec.Code, addRec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after add: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("existing metadata was lost after add: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Name != "secondary" || snap.APIKeys[1].Remark != "staging" {
|
||||
t.Fatalf("new metadata was lost after add: %#v", snap.APIKeys[1])
|
||||
}
|
||||
|
||||
updateBody := map[string]any{
|
||||
"name": "primary-updated",
|
||||
"remark": "prod-updated",
|
||||
}
|
||||
updateBytes, _ := json.Marshal(updateBody)
|
||||
updateReq := httptest.NewRequest(http.MethodPut, "/admin/keys/k1", bytes.NewReader(updateBytes))
|
||||
updateRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(updateRec, updateReq)
|
||||
if updateRec.Code != http.StatusOK {
|
||||
t.Fatalf("update status=%d body=%s", updateRec.Code, updateRec.Body.String())
|
||||
}
|
||||
|
||||
snap = h.Store.Snapshot()
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after update: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Key != "k1" || snap.APIKeys[0].Name != "primary-updated" || snap.APIKeys[0].Remark != "prod-updated" {
|
||||
t.Fatalf("metadata update did not persist: %#v", snap.APIKeys[0])
|
||||
}
|
||||
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, "/admin/keys/k1", nil)
|
||||
deleteRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(deleteRec, deleteReq)
|
||||
if deleteRec.Code != http.StatusOK {
|
||||
t.Fatalf("delete status=%d body=%s", deleteRec.Code, deleteRec.Body.String())
|
||||
}
|
||||
|
||||
snap = h.Store.Snapshot()
|
||||
if len(snap.APIKeys) != 1 || snap.APIKeys[0].Key != "k2" {
|
||||
t.Fatalf("unexpected api keys after delete: %#v", snap.APIKeys)
|
||||
}
|
||||
if len(snap.Keys) != 1 || snap.Keys[0] != "k2" {
|
||||
t.Fatalf("unexpected legacy keys after delete: %#v", snap.Keys)
|
||||
}
|
||||
}
|
||||
@@ -21,16 +21,17 @@ func boolFrom(v any) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.CompatConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, *config.AutoDeleteConfig, map[string]string, map[string]string, error) {
|
||||
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.CompatConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, *config.AutoDeleteConfig, *config.HistorySplitConfig, map[string]string, map[string]string, error) {
|
||||
var (
|
||||
adminCfg *config.AdminConfig
|
||||
runtimeCfg *config.RuntimeConfig
|
||||
compatCfg *config.CompatConfig
|
||||
respCfg *config.ResponsesConfig
|
||||
embCfg *config.EmbeddingsConfig
|
||||
autoDeleteCfg *config.AutoDeleteConfig
|
||||
claudeMap map[string]string
|
||||
aliasMap map[string]string
|
||||
adminCfg *config.AdminConfig
|
||||
runtimeCfg *config.RuntimeConfig
|
||||
compatCfg *config.CompatConfig
|
||||
respCfg *config.ResponsesConfig
|
||||
embCfg *config.EmbeddingsConfig
|
||||
autoDeleteCfg *config.AutoDeleteConfig
|
||||
historySplitCfg *config.HistorySplitConfig
|
||||
claudeMap map[string]string
|
||||
aliasMap map[string]string
|
||||
)
|
||||
|
||||
if raw, ok := req["admin"].(map[string]any); ok {
|
||||
@@ -38,7 +39,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["jwt_expire_hours"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("admin.jwt_expire_hours", n, 1, 720, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.JWTExpireHours = n
|
||||
}
|
||||
@@ -50,33 +51,33 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["account_max_inflight"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("runtime.account_max_inflight", n, 1, 256, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.AccountMaxInflight = n
|
||||
}
|
||||
if v, exists := raw["account_max_queue"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("runtime.account_max_queue", n, 1, 200000, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.AccountMaxQueue = n
|
||||
}
|
||||
if v, exists := raw["global_max_inflight"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("runtime.global_max_inflight", n, 1, 200000, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.GlobalMaxInflight = n
|
||||
}
|
||||
if v, exists := raw["token_refresh_interval_hours"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("runtime.token_refresh_interval_hours", n, 1, 720, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.TokenRefreshIntervalHours = n
|
||||
}
|
||||
if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
|
||||
}
|
||||
runtimeCfg = cfg
|
||||
}
|
||||
@@ -99,7 +100,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["store_ttl_seconds"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("responses.store_ttl_seconds", n, 30, 86400, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.StoreTTLSeconds = n
|
||||
}
|
||||
@@ -111,7 +112,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["provider"]; exists {
|
||||
p := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if err := config.ValidateTrimmedString("embeddings.provider", p, false); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.Provider = p
|
||||
}
|
||||
@@ -147,7 +148,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["mode"]; exists {
|
||||
mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
if err := config.ValidateAutoDeleteMode(mode); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
if mode == "" {
|
||||
mode = "none"
|
||||
@@ -160,5 +161,24 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
autoDeleteCfg = cfg
|
||||
}
|
||||
|
||||
return adminCfg, runtimeCfg, compatCfg, respCfg, embCfg, autoDeleteCfg, claudeMap, aliasMap, nil
|
||||
if raw, ok := req["history_split"].(map[string]any); ok {
|
||||
cfg := &config.HistorySplitConfig{}
|
||||
if v, exists := raw["enabled"]; exists {
|
||||
b := boolFrom(v)
|
||||
cfg.Enabled = &b
|
||||
}
|
||||
if v, exists := raw["trigger_after_turns"]; exists {
|
||||
n := intFrom(v)
|
||||
if err := config.ValidateIntRange("history_split.trigger_after_turns", n, 1, 1000, true); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
cfg.TriggerAfterTurns = &n
|
||||
}
|
||||
if err := config.ValidateHistorySplitConfig(*cfg); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
historySplitCfg = cfg
|
||||
}
|
||||
|
||||
return adminCfg, runtimeCfg, compatCfg, respCfg, embCfg, autoDeleteCfg, historySplitCfg, claudeMap, aliasMap, nil
|
||||
}
|
||||
|
||||
@@ -26,10 +26,14 @@ func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
|
||||
"global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended),
|
||||
"token_refresh_interval_hours": h.Store.RuntimeTokenRefreshIntervalHours(),
|
||||
},
|
||||
"compat": snap.Compat,
|
||||
"responses": snap.Responses,
|
||||
"embeddings": snap.Embeddings,
|
||||
"auto_delete": snap.AutoDelete,
|
||||
"compat": snap.Compat,
|
||||
"responses": snap.Responses,
|
||||
"embeddings": snap.Embeddings,
|
||||
"auto_delete": snap.AutoDelete,
|
||||
"history_split": map[string]any{
|
||||
"enabled": h.Store.HistorySplitEnabled(),
|
||||
"trigger_after_turns": h.Store.HistorySplitTriggerAfterTurns(),
|
||||
},
|
||||
"claude_mapping": settingsClaudeMapping(snap),
|
||||
"model_aliases": snap.ModelAliases,
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
|
||||
@@ -47,6 +47,25 @@ func TestGetSettingsIncludesTokenRefreshInterval(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSettingsIncludesHistorySplitDefaults(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/settings", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
h.getSettings(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var body map[string]any
|
||||
_ = json.Unmarshal(rec.Body.Bytes(), &body)
|
||||
historySplit, _ := body["history_split"].(map[string]any)
|
||||
if got := boolFrom(historySplit["enabled"]); !got {
|
||||
t.Fatalf("expected history_split.enabled=true, body=%v", body)
|
||||
}
|
||||
if got := intFrom(historySplit["trigger_after_turns"]); got != 1 {
|
||||
t.Fatalf("expected history_split.trigger_after_turns=1, got %d body=%v", got, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsValidation(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
payload := map[string]any{
|
||||
@@ -154,6 +173,30 @@ func TestUpdateSettingsWithoutRuntimeSkipsMergedRuntimeValidation(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsHistorySplit(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
payload := map[string]any{
|
||||
"history_split": map[string]any{
|
||||
"enabled": false,
|
||||
"trigger_after_turns": 3,
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateSettings(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
snap := h.Store.Snapshot()
|
||||
if snap.HistorySplit.Enabled == nil || *snap.HistorySplit.Enabled {
|
||||
t.Fatalf("expected history_split.enabled=false, got %#v", snap.HistorySplit.Enabled)
|
||||
}
|
||||
if snap.HistorySplit.TriggerAfterTurns == nil || *snap.HistorySplit.TriggerAfterTurns != 3 {
|
||||
t.Fatalf("expected history_split.trigger_after_turns=3, got %#v", snap.HistorySplit.TriggerAfterTurns)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsAutoDeleteMode(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"],"auto_delete":{"sessions":true}}`)
|
||||
|
||||
@@ -234,6 +277,75 @@ func TestUpdateSettingsHotReloadTokenRefreshInterval(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConfigPreservesStructuredAPIKeysWhenBothFieldsPresent(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["legacy"],
|
||||
"api_keys":[{"key":"legacy","name":"primary","remark":"prod"}],
|
||||
"accounts":[]
|
||||
}`)
|
||||
|
||||
payload := map[string]any{
|
||||
"keys": []any{"legacy", "new-key"},
|
||||
"api_keys": []any{
|
||||
map[string]any{"key": "legacy", "name": "primary-updated", "remark": "prod-updated"},
|
||||
map[string]any{"key": "new-key", "name": "secondary", "remark": "staging"},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPost, "/admin/config", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateConfig(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.Keys) != 2 || snap.Keys[0] != "legacy" || snap.Keys[1] != "new-key" {
|
||||
t.Fatalf("unexpected keys after config update: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after config update: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary-updated" || snap.APIKeys[0].Remark != "prod-updated" {
|
||||
t.Fatalf("structured metadata for existing key was not preserved: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Name != "secondary" || snap.APIKeys[1].Remark != "staging" {
|
||||
t.Fatalf("structured metadata for new key was not preserved: %#v", snap.APIKeys[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConfigLegacyKeysPreserveStructuredMetadata(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"api_keys":[{"key":"legacy","name":"primary","remark":"prod"}],
|
||||
"accounts":[]
|
||||
}`)
|
||||
|
||||
payload := map[string]any{
|
||||
"keys": []any{"legacy", "new-key"},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPost, "/admin/config", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateConfig(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.Keys) != 2 || snap.Keys[0] != "legacy" || snap.Keys[1] != "new-key" {
|
||||
t.Fatalf("unexpected keys after legacy config update: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after legacy config update: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("existing structured metadata was lost: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Key != "new-key" || snap.APIKeys[1].Name != "" || snap.APIKeys[1].Remark != "" {
|
||||
t.Fatalf("new legacy key should remain metadata-free: %#v", snap.APIKeys[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsPasswordInvalidatesOldJWT(t *testing.T) {
|
||||
hash := authn.HashAdminPassword("old-password")
|
||||
h := newAdminTestHandler(t, `{"admin":{"password_hash":"`+hash+`"}}`)
|
||||
@@ -315,6 +427,113 @@ func TestConfigImportMergeAndReplace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigImportMergePreservesStructuredAPIKeys(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"api_keys":[{"key":"k1","name":"primary","remark":"prod"}]
|
||||
}`)
|
||||
|
||||
merge := map[string]any{
|
||||
"mode": "merge",
|
||||
"config": map[string]any{
|
||||
"api_keys": []any{
|
||||
map[string]any{"key": "k1", "name": "should-not-overwrite", "remark": "ignored"},
|
||||
map[string]any{"key": "k2", "name": "secondary", "remark": "staging"},
|
||||
},
|
||||
},
|
||||
}
|
||||
mergeBytes, _ := json.Marshal(merge)
|
||||
mergeReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(mergeBytes))
|
||||
mergeRec := httptest.NewRecorder()
|
||||
h.configImport(mergeRec, mergeReq)
|
||||
if mergeRec.Code != http.StatusOK {
|
||||
t.Fatalf("merge status=%d body=%s", mergeRec.Code, mergeRec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after structured merge: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("existing structured metadata was overwritten: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Name != "secondary" || snap.APIKeys[1].Remark != "staging" {
|
||||
t.Fatalf("new structured metadata was lost: %#v", snap.APIKeys[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigImportMergeUpgradesLegacyAPIKeys(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["legacy"],
|
||||
"accounts":[]
|
||||
}`)
|
||||
|
||||
merge := map[string]any{
|
||||
"mode": "merge",
|
||||
"config": map[string]any{
|
||||
"api_keys": []any{
|
||||
map[string]any{"key": "legacy", "name": "primary", "remark": "prod"},
|
||||
map[string]any{"key": "new-key", "name": "secondary", "remark": "staging"},
|
||||
},
|
||||
},
|
||||
}
|
||||
mergeBytes, _ := json.Marshal(merge)
|
||||
mergeReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(mergeBytes))
|
||||
mergeRec := httptest.NewRecorder()
|
||||
h.configImport(mergeRec, mergeReq)
|
||||
if mergeRec.Code != http.StatusOK {
|
||||
t.Fatalf("merge status=%d body=%s", mergeRec.Code, mergeRec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.Keys) != 2 || snap.Keys[0] != "legacy" || snap.Keys[1] != "new-key" {
|
||||
t.Fatalf("unexpected keys after legacy import merge: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after legacy import merge: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("legacy key metadata was not upgraded: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Name != "secondary" || snap.APIKeys[1].Remark != "staging" {
|
||||
t.Fatalf("new structured metadata was not preserved: %#v", snap.APIKeys[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchImportUpgradesLegacyAPIKeys(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["legacy"],
|
||||
"accounts":[]
|
||||
}`)
|
||||
|
||||
payload := map[string]any{
|
||||
"keys": []any{"legacy", "new-key"},
|
||||
"api_keys": []any{
|
||||
map[string]any{"key": "legacy", "name": "primary", "remark": "prod"},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPost, "/admin/import", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.batchImport(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
snap := h.Store.Snapshot()
|
||||
if len(snap.Keys) != 2 || snap.Keys[0] != "legacy" || snap.Keys[1] != "new-key" {
|
||||
t.Fatalf("unexpected keys after batch import: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after batch import: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("legacy key metadata was not upgraded: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Name != "" || snap.APIKeys[1].Remark != "" {
|
||||
t.Fatalf("new batch-imported key should stay metadata-free: %#v", snap.APIKeys[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigImportAppliesTokenRefreshInterval(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
adminCfg, runtimeCfg, compatCfg, responsesCfg, embeddingsCfg, autoDeleteCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
|
||||
adminCfg, runtimeCfg, compatCfg, responsesCfg, embeddingsCfg, autoDeleteCfg, historySplitCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
@@ -67,6 +67,14 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
c.AutoDelete.Mode = autoDeleteCfg.Mode
|
||||
c.AutoDelete.Sessions = autoDeleteCfg.Sessions
|
||||
}
|
||||
if historySplitCfg != nil {
|
||||
if historySplitCfg.Enabled != nil {
|
||||
c.HistorySplit.Enabled = historySplitCfg.Enabled
|
||||
}
|
||||
if historySplitCfg.TriggerAfterTurns != nil {
|
||||
c.HistorySplit.TriggerAfterTurns = historySplitCfg.TriggerAfterTurns
|
||||
}
|
||||
}
|
||||
if claudeMap != nil {
|
||||
c.ClaudeMapping = claudeMap
|
||||
c.ClaudeModelMap = nil
|
||||
|
||||
@@ -62,6 +62,8 @@ func toAccount(m map[string]any) config.Account {
|
||||
email := fieldString(m, "email")
|
||||
mobile := config.NormalizeMobileForStorage(fieldString(m, "mobile"))
|
||||
return config.Account{
|
||||
Name: fieldString(m, "name"),
|
||||
Remark: fieldString(m, "remark"),
|
||||
Email: email,
|
||||
Mobile: mobile,
|
||||
Password: fieldString(m, "password"),
|
||||
@@ -69,6 +71,116 @@ func toAccount(m map[string]any) config.Account {
|
||||
}
|
||||
}
|
||||
|
||||
func toAPIKeys(v any) ([]config.APIKey, bool) {
|
||||
arr, ok := v.([]any)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
out := make([]config.APIKey, 0, len(arr))
|
||||
seen := map[string]struct{}{}
|
||||
for _, item := range arr {
|
||||
switch x := item.(type) {
|
||||
case map[string]any:
|
||||
key := fieldString(x, "key")
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, config.APIKey{
|
||||
Key: key,
|
||||
Name: fieldString(x, "name"),
|
||||
Remark: fieldString(x, "remark"),
|
||||
})
|
||||
default:
|
||||
key := strings.TrimSpace(fmt.Sprintf("%v", item))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, config.APIKey{Key: key})
|
||||
}
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
func normalizeAPIKeyForStorage(item config.APIKey) config.APIKey {
|
||||
return config.APIKey{
|
||||
Key: strings.TrimSpace(item.Key),
|
||||
Name: strings.TrimSpace(item.Name),
|
||||
Remark: strings.TrimSpace(item.Remark),
|
||||
}
|
||||
}
|
||||
|
||||
func apiKeyHasMetadata(item config.APIKey) bool {
|
||||
return strings.TrimSpace(item.Name) != "" || strings.TrimSpace(item.Remark) != ""
|
||||
}
|
||||
|
||||
func mergeAPIKeysPreferStructured(existing, incoming []config.APIKey) ([]config.APIKey, int) {
|
||||
if len(existing) == 0 && len(incoming) == 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
merged := make([]config.APIKey, 0, len(existing)+len(incoming))
|
||||
index := make(map[string]int, len(existing)+len(incoming))
|
||||
for _, item := range existing {
|
||||
item = normalizeAPIKeyForStorage(item)
|
||||
if item.Key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := index[item.Key]; ok {
|
||||
continue
|
||||
}
|
||||
index[item.Key] = len(merged)
|
||||
merged = append(merged, item)
|
||||
}
|
||||
|
||||
imported := 0
|
||||
for _, item := range incoming {
|
||||
item = normalizeAPIKeyForStorage(item)
|
||||
if item.Key == "" {
|
||||
continue
|
||||
}
|
||||
if idx, ok := index[item.Key]; ok {
|
||||
keep := merged[idx]
|
||||
next := mergeAPIKeyRecord(keep, item)
|
||||
if next != keep {
|
||||
merged[idx] = next
|
||||
imported++
|
||||
}
|
||||
continue
|
||||
}
|
||||
index[item.Key] = len(merged)
|
||||
merged = append(merged, item)
|
||||
imported++
|
||||
}
|
||||
|
||||
if len(merged) == 0 {
|
||||
return nil, imported
|
||||
}
|
||||
return merged, imported
|
||||
}
|
||||
|
||||
func mergeAPIKeyRecord(existing, incoming config.APIKey) config.APIKey {
|
||||
existing = normalizeAPIKeyForStorage(existing)
|
||||
incoming = normalizeAPIKeyForStorage(incoming)
|
||||
if existing.Key == "" {
|
||||
return incoming
|
||||
}
|
||||
if apiKeyHasMetadata(existing) {
|
||||
return existing
|
||||
}
|
||||
if apiKeyHasMetadata(incoming) {
|
||||
return incoming
|
||||
}
|
||||
return existing
|
||||
}
|
||||
|
||||
func fieldString(m map[string]any, key string) string {
|
||||
v, ok := m[key]
|
||||
if !ok || v == nil {
|
||||
@@ -77,6 +189,14 @@ func fieldString(m map[string]any, key string) string {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
}
|
||||
|
||||
func fieldStringOptional(m map[string]any, key string) (string, bool) {
|
||||
v, ok := m[key]
|
||||
if !ok || v == nil {
|
||||
return "", false
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", v)), true
|
||||
}
|
||||
|
||||
func statusOr(v int, d int) int {
|
||||
if v == 0 {
|
||||
return d
|
||||
@@ -99,6 +219,8 @@ func accountMatchesIdentifier(acc config.Account, identifier string) bool {
|
||||
}
|
||||
|
||||
func normalizeAccountForStorage(acc config.Account) config.Account {
|
||||
acc.Name = strings.TrimSpace(acc.Name)
|
||||
acc.Remark = strings.TrimSpace(acc.Remark)
|
||||
acc.Email = strings.TrimSpace(acc.Email)
|
||||
acc.Mobile = config.NormalizeMobileForStorage(acc.Mobile)
|
||||
acc.ProxyID = strings.TrimSpace(acc.ProxyID)
|
||||
|
||||
766
internal/chathistory/store.go
Normal file
766
internal/chathistory/store.go
Normal file
@@ -0,0 +1,766 @@
|
||||
package chathistory
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
FileVersion = 2
|
||||
DisabledLimit = 0
|
||||
DefaultLimit = 20
|
||||
MaxLimit = 50
|
||||
defaultPreviewAt = 160
|
||||
)
|
||||
|
||||
var allowedLimits = map[int]struct{}{
|
||||
DisabledLimit: {},
|
||||
10: {},
|
||||
20: {},
|
||||
50: {},
|
||||
}
|
||||
|
||||
var ErrDisabled = errors.New("chat history disabled")
|
||||
|
||||
type Entry struct {
|
||||
ID string `json:"id"`
|
||||
Revision int64 `json:"revision"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
CompletedAt int64 `json:"completed_at,omitempty"`
|
||||
Status string `json:"status"`
|
||||
CallerID string `json:"caller_id,omitempty"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserInput string `json:"user_input,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
FinalPrompt string `json:"final_prompt,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
StatusCode int `json:"status_code,omitempty"`
|
||||
ElapsedMs int64 `json:"elapsed_ms,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Usage map[string]any `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type SummaryEntry struct {
|
||||
ID string `json:"id"`
|
||||
Revision int64 `json:"revision"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
CompletedAt int64 `json:"completed_at,omitempty"`
|
||||
Status string `json:"status"`
|
||||
CallerID string `json:"caller_id,omitempty"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserInput string `json:"user_input,omitempty"`
|
||||
Preview string `json:"preview,omitempty"`
|
||||
StatusCode int `json:"status_code,omitempty"`
|
||||
ElapsedMs int64 `json:"elapsed_ms,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
DetailRevision int64 `json:"detail_revision"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
Version int `json:"version"`
|
||||
Limit int `json:"limit"`
|
||||
Revision int64 `json:"revision"`
|
||||
Items []SummaryEntry `json:"items"`
|
||||
}
|
||||
|
||||
type StartParams struct {
|
||||
CallerID string
|
||||
AccountID string
|
||||
Model string
|
||||
Stream bool
|
||||
UserInput string
|
||||
Messages []Message
|
||||
FinalPrompt string
|
||||
}
|
||||
|
||||
type UpdateParams struct {
|
||||
Status string
|
||||
ReasoningContent string
|
||||
Content string
|
||||
Error string
|
||||
StatusCode int
|
||||
ElapsedMs int64
|
||||
FinishReason string
|
||||
Usage map[string]any
|
||||
Completed bool
|
||||
}
|
||||
|
||||
type detailEnvelope struct {
|
||||
Version int `json:"version"`
|
||||
Item Entry `json:"item"`
|
||||
}
|
||||
|
||||
type legacyFile struct {
|
||||
Version int `json:"version"`
|
||||
Limit int `json:"limit"`
|
||||
Items []Entry `json:"items"`
|
||||
}
|
||||
|
||||
type legacyProbe struct {
|
||||
Items []map[string]json.RawMessage `json:"items"`
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
detailDir string
|
||||
state File
|
||||
details map[string]Entry
|
||||
dirty map[string]struct{}
|
||||
deleted map[string]struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func New(path string) *Store {
|
||||
s := &Store{
|
||||
path: strings.TrimSpace(path),
|
||||
detailDir: strings.TrimSpace(path) + ".d",
|
||||
state: File{
|
||||
Version: FileVersion,
|
||||
Limit: DefaultLimit,
|
||||
Revision: 0,
|
||||
Items: []SummaryEntry{},
|
||||
},
|
||||
details: map[string]Entry{},
|
||||
dirty: map[string]struct{}{},
|
||||
deleted: map[string]struct{}{},
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.err = s.loadLocked()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Store) Path() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.path
|
||||
}
|
||||
|
||||
func (s *Store) DetailDir() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.detailDir
|
||||
}
|
||||
|
||||
func (s *Store) Err() error {
|
||||
if s == nil {
|
||||
return errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s *Store) Snapshot() (File, error) {
|
||||
if s == nil {
|
||||
return File{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return File{}, s.err
|
||||
}
|
||||
return cloneFile(s.state), nil
|
||||
}
|
||||
|
||||
func (s *Store) Enabled() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return false
|
||||
}
|
||||
return s.state.Limit != DisabledLimit
|
||||
}
|
||||
|
||||
func (s *Store) Get(id string) (Entry, error) {
|
||||
if s == nil {
|
||||
return Entry{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return Entry{}, s.err
|
||||
}
|
||||
item, ok := s.details[strings.TrimSpace(id)]
|
||||
if !ok {
|
||||
return Entry{}, errors.New("chat history entry not found")
|
||||
}
|
||||
return cloneEntry(item), nil
|
||||
}
|
||||
|
||||
func (s *Store) Start(params StartParams) (Entry, error) {
|
||||
if s == nil {
|
||||
return Entry{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return Entry{}, s.err
|
||||
}
|
||||
if s.state.Limit == DisabledLimit {
|
||||
return Entry{}, ErrDisabled
|
||||
}
|
||||
now := time.Now().UnixMilli()
|
||||
revision := s.nextRevisionLocked()
|
||||
entry := Entry{
|
||||
ID: "chat_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
Revision: revision,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Status: "streaming",
|
||||
CallerID: strings.TrimSpace(params.CallerID),
|
||||
AccountID: strings.TrimSpace(params.AccountID),
|
||||
Model: strings.TrimSpace(params.Model),
|
||||
Stream: params.Stream,
|
||||
UserInput: strings.TrimSpace(params.UserInput),
|
||||
Messages: cloneMessages(params.Messages),
|
||||
FinalPrompt: strings.TrimSpace(params.FinalPrompt),
|
||||
}
|
||||
s.details[entry.ID] = entry
|
||||
s.markDetailDirtyLocked(entry.ID)
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return cloneEntry(entry), err
|
||||
}
|
||||
return cloneEntry(entry), nil
|
||||
}
|
||||
|
||||
func (s *Store) Update(id string, params UpdateParams) (Entry, error) {
|
||||
if s == nil {
|
||||
return Entry{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return Entry{}, s.err
|
||||
}
|
||||
target := strings.TrimSpace(id)
|
||||
if target == "" {
|
||||
return Entry{}, errors.New("history id is required")
|
||||
}
|
||||
item, ok := s.details[target]
|
||||
if !ok {
|
||||
return Entry{}, errors.New("chat history entry not found")
|
||||
}
|
||||
now := time.Now().UnixMilli()
|
||||
item.Revision = s.nextRevisionLocked()
|
||||
item.UpdatedAt = now
|
||||
if params.Status != "" {
|
||||
item.Status = params.Status
|
||||
}
|
||||
item.ReasoningContent = params.ReasoningContent
|
||||
item.Content = params.Content
|
||||
item.Error = strings.TrimSpace(params.Error)
|
||||
item.StatusCode = params.StatusCode
|
||||
item.ElapsedMs = params.ElapsedMs
|
||||
item.FinishReason = strings.TrimSpace(params.FinishReason)
|
||||
if params.Usage != nil {
|
||||
item.Usage = cloneMap(params.Usage)
|
||||
}
|
||||
if params.Completed {
|
||||
item.CompletedAt = now
|
||||
}
|
||||
s.details[target] = item
|
||||
s.markDetailDirtyLocked(target)
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return Entry{}, err
|
||||
}
|
||||
return cloneEntry(item), nil
|
||||
}
|
||||
|
||||
func (s *Store) Delete(id string) error {
|
||||
if s == nil {
|
||||
return errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
target := strings.TrimSpace(id)
|
||||
if target == "" {
|
||||
return errors.New("history id is required")
|
||||
}
|
||||
if _, ok := s.details[target]; !ok {
|
||||
return errors.New("chat history entry not found")
|
||||
}
|
||||
s.markDetailDeletedLocked(target)
|
||||
delete(s.details, target)
|
||||
s.nextRevisionLocked()
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) Clear() error {
|
||||
if s == nil {
|
||||
return errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
for id := range s.details {
|
||||
s.markDetailDeletedLocked(id)
|
||||
}
|
||||
s.details = map[string]Entry{}
|
||||
s.nextRevisionLocked()
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) SetLimit(limit int) (File, error) {
|
||||
if s == nil {
|
||||
return File{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return File{}, s.err
|
||||
}
|
||||
if !isAllowedLimit(limit) {
|
||||
return File{}, fmt.Errorf("unsupported chat history limit: %d", limit)
|
||||
}
|
||||
s.state.Limit = limit
|
||||
s.nextRevisionLocked()
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
return cloneFile(s.state), nil
|
||||
}
|
||||
|
||||
func (s *Store) loadLocked() error {
|
||||
if strings.TrimSpace(s.path) == "" {
|
||||
return errors.New("chat history path is required")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(s.path), 0o755); err != nil && filepath.Dir(s.path) != "." {
|
||||
return fmt.Errorf("create chat history dir: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(s.detailDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create chat history detail dir: %w", err)
|
||||
}
|
||||
|
||||
raw, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
if saveErr := s.saveLocked(); saveErr != nil {
|
||||
config.Logger.Warn("[chat_history] bootstrap write failed", "path", s.path, "error", saveErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read chat history index: %w", err)
|
||||
}
|
||||
|
||||
legacy, legacyOK, legacyErr := parseLegacy(raw)
|
||||
if legacyErr != nil {
|
||||
return legacyErr
|
||||
}
|
||||
if legacyOK {
|
||||
s.loadLegacyLocked(legacy)
|
||||
if err := s.saveLocked(); err != nil {
|
||||
config.Logger.Warn("[chat_history] legacy migration writeback failed", "path", s.path, "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var state File
|
||||
if err := json.Unmarshal(raw, &state); err != nil {
|
||||
return fmt.Errorf("decode chat history index: %w", err)
|
||||
}
|
||||
if state.Version == 0 {
|
||||
state.Version = FileVersion
|
||||
}
|
||||
if !isAllowedLimit(state.Limit) {
|
||||
state.Limit = DefaultLimit
|
||||
}
|
||||
s.state = cloneFile(state)
|
||||
s.details = map[string]Entry{}
|
||||
for _, item := range state.Items {
|
||||
detail, err := readDetailFile(filepath.Join(s.detailDir, item.ID+".json"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.details[item.ID] = detail
|
||||
}
|
||||
s.rebuildIndexLocked()
|
||||
if saveErr := s.saveLocked(); saveErr != nil {
|
||||
config.Logger.Warn("[chat_history] index rewrite failed", "path", s.path, "error", saveErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) loadLegacyLocked(legacy legacyFile) {
|
||||
s.state.Version = FileVersion
|
||||
s.state.Limit = legacy.Limit
|
||||
if !isAllowedLimit(s.state.Limit) {
|
||||
s.state.Limit = DefaultLimit
|
||||
}
|
||||
s.details = map[string]Entry{}
|
||||
s.dirty = map[string]struct{}{}
|
||||
s.deleted = map[string]struct{}{}
|
||||
maxRevision := int64(0)
|
||||
for _, item := range legacy.Items {
|
||||
if strings.TrimSpace(item.ID) == "" {
|
||||
continue
|
||||
}
|
||||
item.Messages = cloneMessages(item.Messages)
|
||||
if item.Revision == 0 {
|
||||
if item.UpdatedAt > 0 {
|
||||
item.Revision = item.UpdatedAt
|
||||
} else {
|
||||
item.Revision = time.Now().UnixNano()
|
||||
}
|
||||
}
|
||||
if item.Revision > maxRevision {
|
||||
maxRevision = item.Revision
|
||||
}
|
||||
s.details[item.ID] = item
|
||||
s.markDetailDirtyLocked(item.ID)
|
||||
}
|
||||
s.state.Revision = maxRevision
|
||||
s.rebuildIndexLocked()
|
||||
}
|
||||
|
||||
func (s *Store) saveLocked() error {
|
||||
s.state.Version = FileVersion
|
||||
if !isAllowedLimit(s.state.Limit) {
|
||||
s.state.Limit = DefaultLimit
|
||||
}
|
||||
s.rebuildIndexLocked()
|
||||
|
||||
if err := os.MkdirAll(s.detailDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create chat history detail dir: %w", err)
|
||||
}
|
||||
for _, id := range sortedDetailIDs(s.deleted) {
|
||||
path := filepath.Join(s.detailDir, id+".json")
|
||||
if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("remove stale chat history detail: %w", err)
|
||||
}
|
||||
}
|
||||
for _, id := range sortedDetailIDs(s.dirty) {
|
||||
item, ok := s.details[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(s.detailDir, id+".json")
|
||||
payload, err := json.MarshalIndent(detailEnvelope{
|
||||
Version: FileVersion,
|
||||
Item: item,
|
||||
}, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode chat history detail: %w", err)
|
||||
}
|
||||
if err := writeFileAtomic(path, append(payload, '\n')); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
payload, err := json.MarshalIndent(s.state, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode chat history index: %w", err)
|
||||
}
|
||||
if err := writeFileAtomic(s.path, append(payload, '\n')); err != nil {
|
||||
return err
|
||||
}
|
||||
s.clearPendingDetailChangesLocked()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) rebuildIndexLocked() {
|
||||
summaries := make([]SummaryEntry, 0, len(s.details))
|
||||
for _, item := range s.details {
|
||||
summaries = append(summaries, summaryFromEntry(item))
|
||||
}
|
||||
sort.Slice(summaries, func(i, j int) bool {
|
||||
if summaries[i].UpdatedAt == summaries[j].UpdatedAt {
|
||||
return summaries[i].CreatedAt > summaries[j].CreatedAt
|
||||
}
|
||||
return summaries[i].UpdatedAt > summaries[j].UpdatedAt
|
||||
})
|
||||
if s.state.Limit < DisabledLimit || !isAllowedLimit(s.state.Limit) {
|
||||
s.state.Limit = DefaultLimit
|
||||
}
|
||||
if s.state.Limit == DisabledLimit {
|
||||
s.state.Items = summaries
|
||||
return
|
||||
}
|
||||
if len(summaries) > s.state.Limit {
|
||||
keep := make(map[string]struct{}, s.state.Limit)
|
||||
for _, item := range summaries[:s.state.Limit] {
|
||||
keep[item.ID] = struct{}{}
|
||||
}
|
||||
for id := range s.details {
|
||||
if _, ok := keep[id]; !ok {
|
||||
s.markDetailDeletedLocked(id)
|
||||
delete(s.details, id)
|
||||
}
|
||||
}
|
||||
summaries = summaries[:s.state.Limit]
|
||||
}
|
||||
s.state.Items = summaries
|
||||
}
|
||||
|
||||
func (s *Store) nextRevisionLocked() int64 {
|
||||
next := time.Now().UnixNano()
|
||||
if next <= s.state.Revision {
|
||||
next = s.state.Revision + 1
|
||||
}
|
||||
s.state.Revision = next
|
||||
return next
|
||||
}
|
||||
|
||||
func summaryFromEntry(item Entry) SummaryEntry {
|
||||
return SummaryEntry{
|
||||
ID: item.ID,
|
||||
Revision: item.Revision,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
CompletedAt: item.CompletedAt,
|
||||
Status: item.Status,
|
||||
CallerID: item.CallerID,
|
||||
AccountID: item.AccountID,
|
||||
Model: item.Model,
|
||||
Stream: item.Stream,
|
||||
UserInput: item.UserInput,
|
||||
Preview: buildPreview(item),
|
||||
StatusCode: item.StatusCode,
|
||||
ElapsedMs: item.ElapsedMs,
|
||||
FinishReason: item.FinishReason,
|
||||
DetailRevision: item.Revision,
|
||||
}
|
||||
}
|
||||
|
||||
func buildPreview(item Entry) string {
|
||||
candidate := strings.TrimSpace(item.Content)
|
||||
if candidate == "" {
|
||||
candidate = strings.TrimSpace(item.ReasoningContent)
|
||||
}
|
||||
if candidate == "" {
|
||||
candidate = strings.TrimSpace(item.Error)
|
||||
}
|
||||
if candidate == "" {
|
||||
candidate = strings.TrimSpace(item.UserInput)
|
||||
}
|
||||
if len(candidate) > defaultPreviewAt {
|
||||
return candidate[:defaultPreviewAt] + "..."
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
|
||||
func readDetailFile(path string) (Entry, error) {
|
||||
raw, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return Entry{}, fmt.Errorf("read chat history detail: %w", err)
|
||||
}
|
||||
var env detailEnvelope
|
||||
if err := json.Unmarshal(raw, &env); err != nil {
|
||||
return Entry{}, fmt.Errorf("decode chat history detail: %w", err)
|
||||
}
|
||||
return cloneEntry(env.Item), nil
|
||||
}
|
||||
|
||||
func parseLegacy(raw []byte) (legacyFile, bool, error) {
|
||||
var legacy legacyFile
|
||||
if err := json.Unmarshal(raw, &legacy); err != nil {
|
||||
return legacyFile{}, false, nil
|
||||
}
|
||||
if len(legacy.Items) == 0 {
|
||||
return legacy, false, nil
|
||||
}
|
||||
var probe legacyProbe
|
||||
if err := json.Unmarshal(raw, &probe); err == nil {
|
||||
for _, item := range probe.Items {
|
||||
if _, ok := item["detail_revision"]; ok {
|
||||
return legacy, false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return legacy, true, nil
|
||||
}
|
||||
|
||||
func writeFileAtomic(path string, body []byte) error {
|
||||
dir := filepath.Dir(path)
|
||||
if dir == "" {
|
||||
dir = "."
|
||||
}
|
||||
if dir != "." {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("create chat history dir: %w", err)
|
||||
}
|
||||
}
|
||||
tmpFile, err := os.CreateTemp(dir, ".chat-history-*.tmp")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp chat history: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
cleanup := func() error {
|
||||
if err := os.Remove(tmpPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("remove temp chat history: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
withCleanup := func(primary error, closeErr error) error {
|
||||
errs := []error{primary}
|
||||
if closeErr != nil {
|
||||
errs = append(errs, fmt.Errorf("close temp chat history: %w", closeErr))
|
||||
}
|
||||
if cleanupErr := cleanup(); cleanupErr != nil {
|
||||
errs = append(errs, cleanupErr)
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
if _, err := tmpFile.Write(body); err != nil {
|
||||
return withCleanup(fmt.Errorf("write temp chat history: %w", err), tmpFile.Close())
|
||||
}
|
||||
if err := tmpFile.Sync(); err != nil {
|
||||
return withCleanup(fmt.Errorf("sync temp chat history: %w", err), tmpFile.Close())
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
if cleanupErr := cleanup(); cleanupErr != nil {
|
||||
return errors.Join(fmt.Errorf("close temp chat history: %w", err), cleanupErr)
|
||||
}
|
||||
return fmt.Errorf("close temp chat history: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
if cleanupErr := cleanup(); cleanupErr != nil {
|
||||
return errors.Join(fmt.Errorf("promote temp chat history: %w", err), cleanupErr)
|
||||
}
|
||||
return fmt.Errorf("promote temp chat history: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ListETag(revision int64) string {
|
||||
return fmt.Sprintf(`W/"chat-history-list-%d"`, revision)
|
||||
}
|
||||
|
||||
func DetailETag(id string, revision int64) string {
|
||||
return fmt.Sprintf(`W/"chat-history-detail-%s-%d"`, strings.TrimSpace(id), revision)
|
||||
}
|
||||
|
||||
func isAllowedLimit(limit int) bool {
|
||||
_, ok := allowedLimits[limit]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *Store) markDetailDirtyLocked(id string) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
if s.dirty == nil {
|
||||
s.dirty = map[string]struct{}{}
|
||||
}
|
||||
if s.deleted == nil {
|
||||
s.deleted = map[string]struct{}{}
|
||||
}
|
||||
s.dirty[id] = struct{}{}
|
||||
delete(s.deleted, id)
|
||||
}
|
||||
|
||||
func (s *Store) markDetailDeletedLocked(id string) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
if s.dirty == nil {
|
||||
s.dirty = map[string]struct{}{}
|
||||
}
|
||||
if s.deleted == nil {
|
||||
s.deleted = map[string]struct{}{}
|
||||
}
|
||||
s.deleted[id] = struct{}{}
|
||||
delete(s.dirty, id)
|
||||
}
|
||||
|
||||
func (s *Store) clearPendingDetailChangesLocked() {
|
||||
s.dirty = map[string]struct{}{}
|
||||
s.deleted = map[string]struct{}{}
|
||||
}
|
||||
|
||||
func sortedDetailIDs(ids map[string]struct{}) []string {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(ids))
|
||||
for id := range ids {
|
||||
out = append(out, id)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneFile(in File) File {
|
||||
out := File{
|
||||
Version: in.Version,
|
||||
Limit: in.Limit,
|
||||
Revision: in.Revision,
|
||||
Items: make([]SummaryEntry, len(in.Items)),
|
||||
}
|
||||
copy(out.Items, in.Items)
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneEntry(item Entry) Entry {
|
||||
item.Usage = cloneMap(item.Usage)
|
||||
item.Messages = cloneMessages(item.Messages)
|
||||
return item
|
||||
}
|
||||
|
||||
func cloneMap(in map[string]any) map[string]any {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneMessages(messages []Message) []Message {
|
||||
if len(messages) == 0 {
|
||||
return []Message{}
|
||||
}
|
||||
out := make([]Message, len(messages))
|
||||
copy(out, messages)
|
||||
return out
|
||||
}
|
||||
483
internal/chathistory/store_test.go
Normal file
483
internal/chathistory/store_test.go
Normal file
@@ -0,0 +1,483 @@
|
||||
package chathistory
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func blockDetailDir(t *testing.T, detailDir string) func() {
|
||||
t.Helper()
|
||||
blockedDir := detailDir + ".blocked"
|
||||
if err := os.RemoveAll(blockedDir); err != nil {
|
||||
t.Fatalf("remove blocked detail dir failed: %v", err)
|
||||
}
|
||||
if err := os.Rename(detailDir, blockedDir); err != nil {
|
||||
t.Fatalf("move detail dir aside failed: %v", err)
|
||||
}
|
||||
if err := os.RemoveAll(detailDir); err != nil {
|
||||
t.Fatalf("remove blocked detail path failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(detailDir, []byte("blocked"), 0o644); err != nil {
|
||||
t.Fatalf("write blocked detail path failed: %v", err)
|
||||
}
|
||||
var once sync.Once
|
||||
return func() {
|
||||
t.Helper()
|
||||
once.Do(func() {
|
||||
if err := os.RemoveAll(detailDir); err != nil {
|
||||
t.Fatalf("remove blocking detail path failed: %v", err)
|
||||
}
|
||||
if err := os.Rename(blockedDir, detailDir); err != nil {
|
||||
t.Fatalf("restore detail dir failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreCreatesAndPersistsEntries(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
started, err := store.Start(StartParams{
|
||||
CallerID: "caller:abc",
|
||||
AccountID: "user@example.com",
|
||||
Model: "deepseek-chat",
|
||||
Stream: true,
|
||||
UserInput: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start entry failed: %v", err)
|
||||
}
|
||||
|
||||
updated, err := store.Update(started.ID, UpdateParams{
|
||||
Status: "success",
|
||||
ReasoningContent: "thinking",
|
||||
Content: "answer",
|
||||
StatusCode: 200,
|
||||
ElapsedMs: 321,
|
||||
FinishReason: "stop",
|
||||
Usage: map[string]any{"total_tokens": 9},
|
||||
Completed: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("update entry failed: %v", err)
|
||||
}
|
||||
if updated.Status != "success" || updated.Content != "answer" {
|
||||
t.Fatalf("unexpected updated entry: %#v", updated)
|
||||
}
|
||||
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if snapshot.Limit != DefaultLimit {
|
||||
t.Fatalf("unexpected default limit: %d", snapshot.Limit)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one item, got %d", len(snapshot.Items))
|
||||
}
|
||||
if snapshot.Items[0].CompletedAt == 0 {
|
||||
t.Fatalf("expected completed_at to be populated")
|
||||
}
|
||||
if snapshot.Items[0].Preview != "answer" {
|
||||
t.Fatalf("expected summary preview=answer, got %#v", snapshot.Items[0])
|
||||
}
|
||||
|
||||
reloaded := New(path)
|
||||
reloadedSnapshot, err := reloaded.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("reload snapshot failed: %v", err)
|
||||
}
|
||||
if len(reloadedSnapshot.Items) != 1 {
|
||||
t.Fatalf("unexpected reloaded summaries: %#v", reloadedSnapshot.Items)
|
||||
}
|
||||
full, err := reloaded.Get(started.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get detail failed: %v", err)
|
||||
}
|
||||
if full.Content != "answer" {
|
||||
t.Fatalf("expected detail content=answer, got %#v", full)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreTrimsToConfiguredLimit(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
if _, err := store.SetLimit(10); err != nil {
|
||||
t.Fatalf("set limit failed: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
entry, err := store.Start(StartParams{Model: "deepseek-chat", UserInput: "msg"})
|
||||
if err != nil {
|
||||
t.Fatalf("start %d failed: %v", i, err)
|
||||
}
|
||||
if _, err := store.Update(entry.ID, UpdateParams{Status: "success", Content: "ok", Completed: true}); err != nil {
|
||||
t.Fatalf("update %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 10 {
|
||||
t.Fatalf("expected 10 items, got %d", len(snapshot.Items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreDeleteClearAndLimitValidation(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
entry, err := store.Start(StartParams{UserInput: "hello"})
|
||||
if err != nil {
|
||||
t.Fatalf("start failed: %v", err)
|
||||
}
|
||||
if err := store.Delete(entry.ID); err != nil {
|
||||
t.Fatalf("delete failed: %v", err)
|
||||
}
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 0 {
|
||||
t.Fatalf("expected empty items after delete, got %d", len(snapshot.Items))
|
||||
}
|
||||
if _, err := store.SetLimit(999); err == nil {
|
||||
t.Fatalf("expected invalid limit error")
|
||||
}
|
||||
if err := store.Clear(); err != nil {
|
||||
t.Fatalf("clear failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreDisablePreservesHistoryAndBlocksNewEntries(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
entry, err := store.Start(StartParams{UserInput: "hello"})
|
||||
if err != nil {
|
||||
t.Fatalf("start failed: %v", err)
|
||||
}
|
||||
if _, err := store.Update(entry.ID, UpdateParams{Status: "success", Content: "world", Completed: true}); err != nil {
|
||||
t.Fatalf("update failed: %v", err)
|
||||
}
|
||||
|
||||
snapshot, err := store.SetLimit(DisabledLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("disable failed: %v", err)
|
||||
}
|
||||
if snapshot.Limit != DisabledLimit {
|
||||
t.Fatalf("expected disabled limit, got %d", snapshot.Limit)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected disabled mode to preserve summaries, got %d", len(snapshot.Items))
|
||||
}
|
||||
if store.Enabled() {
|
||||
t.Fatalf("expected store to report disabled")
|
||||
}
|
||||
if _, err := store.Start(StartParams{UserInput: "later"}); err != ErrDisabled {
|
||||
t.Fatalf("expected ErrDisabled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreConcurrentUpdatesKeepSplitFilesValid(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 8; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
entry, err := store.Start(StartParams{
|
||||
CallerID: "caller:test",
|
||||
Model: "deepseek-chat",
|
||||
UserInput: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("start failed: %v", err)
|
||||
return
|
||||
}
|
||||
_, err = store.Update(entry.ID, UpdateParams{
|
||||
Status: "success",
|
||||
Content: "answer",
|
||||
ElapsedMs: int64(idx),
|
||||
Completed: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("update failed: %v", err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 8 {
|
||||
t.Fatalf("expected 8 items, got %d", len(snapshot.Items))
|
||||
}
|
||||
|
||||
raw, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read index failed: %v", err)
|
||||
}
|
||||
var persisted File
|
||||
if err := json.Unmarshal(raw, &persisted); err != nil {
|
||||
t.Fatalf("persisted index is invalid json: %v", err)
|
||||
}
|
||||
if len(persisted.Items) != 8 {
|
||||
t.Fatalf("expected persisted items=8, got %d", len(persisted.Items))
|
||||
}
|
||||
|
||||
detailFiles, err := os.ReadDir(path + ".d")
|
||||
if err != nil {
|
||||
t.Fatalf("read detail dir failed: %v", err)
|
||||
}
|
||||
if len(detailFiles) != 8 {
|
||||
t.Fatalf("expected 8 detail files, got %d", len(detailFiles))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreAutoMigratesLegacyMonolith(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
legacy := legacyFile{
|
||||
Version: 1,
|
||||
Limit: 20,
|
||||
Items: []Entry{{
|
||||
ID: "chat_legacy",
|
||||
CreatedAt: 1,
|
||||
UpdatedAt: 2,
|
||||
Status: "success",
|
||||
UserInput: "hello",
|
||||
Content: "world",
|
||||
ReasoningContent: "thinking",
|
||||
}},
|
||||
}
|
||||
body, _ := json.MarshalIndent(legacy, "", " ")
|
||||
if err := os.WriteFile(path, body, 0o644); err != nil {
|
||||
t.Fatalf("write legacy file failed: %v", err)
|
||||
}
|
||||
|
||||
store := New(path)
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("expected legacy migration success, got %v", err)
|
||||
}
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one migrated summary, got %#v", snapshot.Items)
|
||||
}
|
||||
full, err := store.Get("chat_legacy")
|
||||
if err != nil {
|
||||
t.Fatalf("get migrated detail failed: %v", err)
|
||||
}
|
||||
if full.Content != "world" {
|
||||
t.Fatalf("expected migrated detail content preserved, got %#v", full)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreAutoMigratesMetadataOnlyLegacyMonolith(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
legacy := legacyFile{
|
||||
Version: 1,
|
||||
Limit: 20,
|
||||
Items: []Entry{{
|
||||
ID: "chat_metadata_only",
|
||||
Revision: 0,
|
||||
CreatedAt: 1,
|
||||
UpdatedAt: 2,
|
||||
Status: "error",
|
||||
CallerID: "caller:test",
|
||||
AccountID: "acct:test",
|
||||
Model: "deepseek-chat",
|
||||
Stream: true,
|
||||
UserInput: "hello",
|
||||
Error: "boom",
|
||||
StatusCode: 500,
|
||||
ElapsedMs: 12,
|
||||
FinishReason: "error",
|
||||
}},
|
||||
}
|
||||
body, _ := json.MarshalIndent(legacy, "", " ")
|
||||
if err := os.WriteFile(path, body, 0o644); err != nil {
|
||||
t.Fatalf("write legacy file failed: %v", err)
|
||||
}
|
||||
|
||||
store := New(path)
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("expected legacy metadata-only migration success, got %v", err)
|
||||
}
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one migrated summary, got %#v", snapshot.Items)
|
||||
}
|
||||
full, err := store.Get("chat_metadata_only")
|
||||
if err != nil {
|
||||
t.Fatalf("get migrated detail failed: %v", err)
|
||||
}
|
||||
if full.Error != "boom" || full.UserInput != "hello" {
|
||||
t.Fatalf("expected metadata-only legacy fields preserved, got %#v", full)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(store.DetailDir(), "chat_metadata_only.json")); err != nil {
|
||||
t.Fatalf("expected migrated detail file to exist: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreLegacyMigrationBestEffortWhenRewriteFails(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
longID := "chat_" + strings.Repeat("x", 320)
|
||||
legacy := legacyFile{
|
||||
Version: 1,
|
||||
Limit: 20,
|
||||
Items: []Entry{{
|
||||
ID: longID,
|
||||
CreatedAt: 1,
|
||||
UpdatedAt: 2,
|
||||
Status: "success",
|
||||
UserInput: "hello",
|
||||
Content: "world",
|
||||
}},
|
||||
}
|
||||
body, err := json.MarshalIndent(legacy, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("marshal legacy file failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(path, body, 0o644); err != nil {
|
||||
t.Fatalf("write legacy file failed: %v", err)
|
||||
}
|
||||
|
||||
store := New(path)
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("expected store to stay usable after migration writeback failure, got %v", err)
|
||||
}
|
||||
if !store.Enabled() {
|
||||
t.Fatal("expected store to remain enabled after best-effort migration")
|
||||
}
|
||||
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 || snapshot.Items[0].ID != longID {
|
||||
t.Fatalf("unexpected snapshot after best-effort migration: %#v", snapshot.Items)
|
||||
}
|
||||
full, err := store.Get(longID)
|
||||
if err != nil {
|
||||
t.Fatalf("get migrated detail failed: %v", err)
|
||||
}
|
||||
if full.Content != "world" {
|
||||
t.Fatalf("expected migrated content to stay in memory, got %#v", full)
|
||||
}
|
||||
if _, statErr := os.Stat(filepath.Join(store.DetailDir(), longID+".json")); statErr == nil {
|
||||
t.Fatal("expected detail write to fail for overlong legacy id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreTransientPersistenceFailureDoesNotLatch(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
first, err := store.Start(StartParams{UserInput: "first"})
|
||||
if err != nil {
|
||||
t.Fatalf("start first failed: %v", err)
|
||||
}
|
||||
restore := blockDetailDir(t, store.DetailDir())
|
||||
t.Cleanup(restore)
|
||||
|
||||
blocked, err := store.Start(StartParams{UserInput: "blocked"})
|
||||
if err == nil {
|
||||
t.Fatalf("expected start failure while detail dir is blocked")
|
||||
}
|
||||
if blocked.ID == "" {
|
||||
t.Fatalf("expected in-memory entry from failed start")
|
||||
}
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("transient start failure should not latch store error: %v", err)
|
||||
}
|
||||
if _, err := store.Update(first.ID, UpdateParams{Status: "success", Content: "one", Completed: true}); err == nil {
|
||||
t.Fatalf("expected update failure while detail dir is blocked")
|
||||
}
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("transient update failure should not latch store error: %v", err)
|
||||
}
|
||||
|
||||
restore()
|
||||
|
||||
if _, err := store.Update(blocked.ID, UpdateParams{Status: "success", Content: "two", Completed: true}); err != nil {
|
||||
t.Fatalf("update after restore failed: %v", err)
|
||||
}
|
||||
if _, err := store.Start(StartParams{UserInput: "later"}); err != nil {
|
||||
t.Fatalf("start after restore failed: %v", err)
|
||||
}
|
||||
full, err := store.Get(blocked.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get restored entry failed: %v", err)
|
||||
}
|
||||
if full.Content != "two" || full.Status != "success" {
|
||||
t.Fatalf("expected restored entry persisted, got %#v", full)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreWritesOnlyChangedDetailFiles(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
first, err := store.Start(StartParams{UserInput: "one"})
|
||||
if err != nil {
|
||||
t.Fatalf("start first failed: %v", err)
|
||||
}
|
||||
if _, err := store.Update(first.ID, UpdateParams{Status: "success", Content: "first", Completed: true}); err != nil {
|
||||
t.Fatalf("update first failed: %v", err)
|
||||
}
|
||||
second, err := store.Start(StartParams{UserInput: "two"})
|
||||
if err != nil {
|
||||
t.Fatalf("start second failed: %v", err)
|
||||
}
|
||||
if _, err := store.Update(second.ID, UpdateParams{Status: "success", Content: "second", Completed: true}); err != nil {
|
||||
t.Fatalf("update second failed: %v", err)
|
||||
}
|
||||
|
||||
firstPath := filepath.Join(store.DetailDir(), first.ID+".json")
|
||||
secondPath := filepath.Join(store.DetailDir(), second.ID+".json")
|
||||
beforeFirst, err := os.ReadFile(firstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read first detail before update failed: %v", err)
|
||||
}
|
||||
beforeSecond, err := os.ReadFile(secondPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read second detail before update failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := store.Update(first.ID, UpdateParams{Status: "success", Content: "first-updated", Completed: true}); err != nil {
|
||||
t.Fatalf("update first again failed: %v", err)
|
||||
}
|
||||
|
||||
afterFirst, err := os.ReadFile(firstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read first detail after update failed: %v", err)
|
||||
}
|
||||
afterSecond, err := os.ReadFile(secondPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read second detail after update failed: %v", err)
|
||||
}
|
||||
|
||||
if bytes.Equal(beforeFirst, afterFirst) {
|
||||
t.Fatalf("expected first detail file to change after update")
|
||||
}
|
||||
if !bytes.Equal(beforeSecond, afterSecond) {
|
||||
t.Fatalf("expected untouched detail file to remain byte-identical")
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,10 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"ds2api/internal/toolcall"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/sse"
|
||||
@@ -65,55 +63,6 @@ func TestGoCompatSSEFixtures(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoCompatToolcallFixtures(t *testing.T) {
|
||||
files, err := filepath.Glob(compatPath("fixtures", "toolcalls", "*.json"))
|
||||
if err != nil {
|
||||
t.Fatalf("glob toolcall fixtures failed: %v", err)
|
||||
}
|
||||
if len(files) == 0 {
|
||||
t.Fatal("no toolcall fixtures found")
|
||||
}
|
||||
for _, fixturePath := range files {
|
||||
name := trimExt(filepath.Base(fixturePath))
|
||||
expectedPath := compatPath("expected", "toolcalls_"+name+".json")
|
||||
|
||||
var fixture struct {
|
||||
Text string `json:"text"`
|
||||
ToolNames []string `json:"tool_names"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
mustLoadJSON(t, fixturePath, &fixture)
|
||||
|
||||
var expected struct {
|
||||
Calls []toolcall.ParsedToolCall `json:"calls"`
|
||||
SawToolCallSyntax bool `json:"sawToolCallSyntax"`
|
||||
RejectedByPolicy bool `json:"rejectedByPolicy"`
|
||||
RejectedToolNames []string `json:"rejectedToolNames"`
|
||||
}
|
||||
mustLoadJSON(t, expectedPath, &expected)
|
||||
|
||||
var got toolcall.ToolCallParseResult
|
||||
switch strings.ToLower(strings.TrimSpace(fixture.Mode)) {
|
||||
case "standalone":
|
||||
got = toolcall.ParseStandaloneToolCallsDetailed(fixture.Text, fixture.ToolNames)
|
||||
default:
|
||||
got = toolcall.ParseToolCallsDetailed(fixture.Text, fixture.ToolNames)
|
||||
}
|
||||
if got.Calls == nil {
|
||||
got.Calls = []toolcall.ParsedToolCall{}
|
||||
}
|
||||
if got.RejectedToolNames == nil {
|
||||
got.RejectedToolNames = []string{}
|
||||
}
|
||||
if !reflect.DeepEqual(got.Calls, expected.Calls) ||
|
||||
got.SawToolCallSyntax != expected.SawToolCallSyntax ||
|
||||
got.RejectedByPolicy != expected.RejectedByPolicy ||
|
||||
!reflect.DeepEqual(got.RejectedToolNames, expected.RejectedToolNames) {
|
||||
t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoCompatTokenFixtures(t *testing.T) {
|
||||
var fixture struct {
|
||||
Cases []struct {
|
||||
|
||||
@@ -17,6 +17,9 @@ func (c Config) MarshalJSON() ([]byte, error) {
|
||||
if len(c.Keys) > 0 {
|
||||
m["keys"] = c.Keys
|
||||
}
|
||||
if len(c.APIKeys) > 0 {
|
||||
m["api_keys"] = c.APIKeys
|
||||
}
|
||||
if len(c.Accounts) > 0 {
|
||||
m["accounts"] = c.Accounts
|
||||
}
|
||||
@@ -48,6 +51,9 @@ func (c Config) MarshalJSON() ([]byte, error) {
|
||||
m["embeddings"] = c.Embeddings
|
||||
}
|
||||
m["auto_delete"] = c.AutoDelete
|
||||
if c.HistorySplit.Enabled != nil || c.HistorySplit.TriggerAfterTurns != nil {
|
||||
m["history_split"] = c.HistorySplit
|
||||
}
|
||||
if c.VercelSyncHash != "" {
|
||||
m["_vercel_sync_hash"] = c.VercelSyncHash
|
||||
}
|
||||
@@ -69,6 +75,10 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
if err := json.Unmarshal(v, &c.Keys); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "api_keys":
|
||||
if err := json.Unmarshal(v, &c.APIKeys); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "accounts":
|
||||
if err := json.Unmarshal(v, &c.Accounts); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
@@ -115,6 +125,10 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
if err := json.Unmarshal(v, &c.AutoDelete); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "history_split":
|
||||
if err := json.Unmarshal(v, &c.HistorySplit); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "_vercel_sync_hash":
|
||||
if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
@@ -130,12 +144,14 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
c.NormalizeCredentials()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Config) Clone() Config {
|
||||
clone := Config{
|
||||
Keys: slices.Clone(c.Keys),
|
||||
APIKeys: slices.Clone(c.APIKeys),
|
||||
Accounts: slices.Clone(c.Accounts),
|
||||
Proxies: slices.Clone(c.Proxies),
|
||||
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
|
||||
@@ -147,9 +163,13 @@ func (c Config) Clone() Config {
|
||||
WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
|
||||
StripReferenceMarkers: cloneBoolPtr(c.Compat.StripReferenceMarkers),
|
||||
},
|
||||
Responses: c.Responses,
|
||||
Embeddings: c.Embeddings,
|
||||
AutoDelete: c.AutoDelete,
|
||||
Responses: c.Responses,
|
||||
Embeddings: c.Embeddings,
|
||||
AutoDelete: c.AutoDelete,
|
||||
HistorySplit: HistorySplitConfig{
|
||||
Enabled: cloneBoolPtr(c.HistorySplit.Enabled),
|
||||
TriggerAfterTurns: cloneIntPtr(c.HistorySplit.TriggerAfterTurns),
|
||||
},
|
||||
VercelSyncHash: c.VercelSyncHash,
|
||||
VercelSyncTime: c.VercelSyncTime,
|
||||
AdditionalFields: map[string]any{},
|
||||
@@ -179,6 +199,14 @@ func cloneBoolPtr(in *bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
func cloneIntPtr(in *int) *int {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
v := *in
|
||||
return &v
|
||||
}
|
||||
|
||||
func parseConfigString(raw string) (Config, error) {
|
||||
var cfg Config
|
||||
candidates := []string{raw}
|
||||
|
||||
@@ -8,24 +8,28 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Keys []string `json:"keys,omitempty"`
|
||||
Accounts []Account `json:"accounts,omitempty"`
|
||||
Proxies []Proxy `json:"proxies,omitempty"`
|
||||
ClaudeMapping map[string]string `json:"claude_mapping,omitempty"`
|
||||
ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"`
|
||||
ModelAliases map[string]string `json:"model_aliases,omitempty"`
|
||||
Admin AdminConfig `json:"admin,omitempty"`
|
||||
Runtime RuntimeConfig `json:"runtime,omitempty"`
|
||||
Compat CompatConfig `json:"compat,omitempty"`
|
||||
Responses ResponsesConfig `json:"responses,omitempty"`
|
||||
Embeddings EmbeddingsConfig `json:"embeddings,omitempty"`
|
||||
AutoDelete AutoDeleteConfig `json:"auto_delete"`
|
||||
VercelSyncHash string `json:"_vercel_sync_hash,omitempty"`
|
||||
VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"`
|
||||
AdditionalFields map[string]any `json:"-"`
|
||||
Keys []string `json:"keys,omitempty"`
|
||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
||||
Accounts []Account `json:"accounts,omitempty"`
|
||||
Proxies []Proxy `json:"proxies,omitempty"`
|
||||
ClaudeMapping map[string]string `json:"claude_mapping,omitempty"`
|
||||
ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"`
|
||||
ModelAliases map[string]string `json:"model_aliases,omitempty"`
|
||||
Admin AdminConfig `json:"admin,omitempty"`
|
||||
Runtime RuntimeConfig `json:"runtime,omitempty"`
|
||||
Compat CompatConfig `json:"compat,omitempty"`
|
||||
Responses ResponsesConfig `json:"responses,omitempty"`
|
||||
Embeddings EmbeddingsConfig `json:"embeddings,omitempty"`
|
||||
AutoDelete AutoDeleteConfig `json:"auto_delete"`
|
||||
HistorySplit HistorySplitConfig `json:"history_split"`
|
||||
VercelSyncHash string `json:"_vercel_sync_hash,omitempty"`
|
||||
VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"`
|
||||
AdditionalFields map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Remark string `json:"remark,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Mobile string `json:"mobile,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
@@ -33,6 +37,12 @@ type Account struct {
|
||||
ProxyID string `json:"proxy_id,omitempty"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Remark string `json:"remark,omitempty"`
|
||||
}
|
||||
|
||||
type Proxy struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -73,6 +83,25 @@ func (c *Config) ClearAccountTokens() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) NormalizeCredentials() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
normalizedAPIKeys := normalizeAPIKeys(c.APIKeys)
|
||||
if len(normalizedAPIKeys) > 0 {
|
||||
c.APIKeys = normalizedAPIKeys
|
||||
c.Keys = apiKeysToStrings(c.APIKeys)
|
||||
} else {
|
||||
c.Keys = normalizeKeys(c.Keys)
|
||||
c.APIKeys = apiKeysFromStrings(c.Keys, nil)
|
||||
}
|
||||
|
||||
for i := range c.Accounts {
|
||||
c.Accounts[i].Name = strings.TrimSpace(c.Accounts[i].Name)
|
||||
c.Accounts[i].Remark = strings.TrimSpace(c.Accounts[i].Remark)
|
||||
}
|
||||
}
|
||||
|
||||
// DropInvalidAccounts removes accounts that cannot be addressed by admin APIs
|
||||
// (no email and no normalizable mobile). This prevents legacy token-only
|
||||
// records from becoming orphaned empty entries after token stripping.
|
||||
@@ -120,3 +149,8 @@ type AutoDeleteConfig struct {
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Sessions bool `json:"sessions,omitempty"`
|
||||
}
|
||||
|
||||
type HistorySplitConfig struct {
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
TriggerAfterTurns *int `json:"trigger_after_turns,omitempty"`
|
||||
}
|
||||
|
||||
@@ -154,6 +154,10 @@ func TestConfigJSONRoundtrip(t *testing.T) {
|
||||
AutoDelete: AutoDeleteConfig{
|
||||
Mode: "single",
|
||||
},
|
||||
HistorySplit: HistorySplitConfig{
|
||||
Enabled: &trueVal,
|
||||
TriggerAfterTurns: func() *int { v := 2; return &v }(),
|
||||
},
|
||||
Runtime: RuntimeConfig{
|
||||
TokenRefreshIntervalHours: 12,
|
||||
},
|
||||
@@ -193,6 +197,12 @@ func TestConfigJSONRoundtrip(t *testing.T) {
|
||||
if decoded.AutoDelete.Mode != "single" {
|
||||
t.Fatalf("unexpected auto delete mode: %#v", decoded.AutoDelete.Mode)
|
||||
}
|
||||
if decoded.HistorySplit.Enabled == nil || !*decoded.HistorySplit.Enabled {
|
||||
t.Fatalf("unexpected history split enabled: %#v", decoded.HistorySplit.Enabled)
|
||||
}
|
||||
if decoded.HistorySplit.TriggerAfterTurns == nil || *decoded.HistorySplit.TriggerAfterTurns != 2 {
|
||||
t.Fatalf("unexpected history split trigger_after_turns: %#v", decoded.HistorySplit.TriggerAfterTurns)
|
||||
}
|
||||
if decoded.Compat.WideInputStrictOutput == nil || !*decoded.Compat.WideInputStrictOutput {
|
||||
t.Fatalf("unexpected compat wide_input_strict_output: %#v", decoded.Compat.WideInputStrictOutput)
|
||||
}
|
||||
@@ -249,6 +259,8 @@ func TestConfigUnmarshalJSONPreservesUnknownFields(t *testing.T) {
|
||||
|
||||
func TestConfigCloneIsDeepCopy(t *testing.T) {
|
||||
falseVal := false
|
||||
trueVal := true
|
||||
turns := 2
|
||||
cfg := Config{
|
||||
Keys: []string{"key1"},
|
||||
Accounts: []Account{{Email: "user@test.com", Token: "token"}},
|
||||
@@ -258,6 +270,10 @@ func TestConfigCloneIsDeepCopy(t *testing.T) {
|
||||
Compat: CompatConfig{
|
||||
StripReferenceMarkers: &falseVal,
|
||||
},
|
||||
HistorySplit: HistorySplitConfig{
|
||||
Enabled: &trueVal,
|
||||
TriggerAfterTurns: &turns,
|
||||
},
|
||||
AdditionalFields: map[string]any{"custom": "value"},
|
||||
}
|
||||
|
||||
@@ -270,6 +286,12 @@ func TestConfigCloneIsDeepCopy(t *testing.T) {
|
||||
if cfg.Compat.StripReferenceMarkers != nil {
|
||||
*cfg.Compat.StripReferenceMarkers = true
|
||||
}
|
||||
if cfg.HistorySplit.Enabled != nil {
|
||||
*cfg.HistorySplit.Enabled = false
|
||||
}
|
||||
if cfg.HistorySplit.TriggerAfterTurns != nil {
|
||||
*cfg.HistorySplit.TriggerAfterTurns = 5
|
||||
}
|
||||
|
||||
// Cloned should not be affected
|
||||
if cloned.Keys[0] != "key1" {
|
||||
@@ -284,6 +306,12 @@ func TestConfigCloneIsDeepCopy(t *testing.T) {
|
||||
if cloned.Compat.StripReferenceMarkers == nil || *cloned.Compat.StripReferenceMarkers {
|
||||
t.Fatalf("clone compat was affected: %#v", cloned.Compat.StripReferenceMarkers)
|
||||
}
|
||||
if cloned.HistorySplit.Enabled == nil || !*cloned.HistorySplit.Enabled {
|
||||
t.Fatalf("clone history split enabled was affected: %#v", cloned.HistorySplit.Enabled)
|
||||
}
|
||||
if cloned.HistorySplit.TriggerAfterTurns == nil || *cloned.HistorySplit.TriggerAfterTurns != 2 {
|
||||
t.Fatalf("clone history split trigger was affected: %#v", cloned.HistorySplit.TriggerAfterTurns)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigCloneNilMaps(t *testing.T) {
|
||||
@@ -529,6 +557,101 @@ func TestStoreUpdate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreUpdateReconcilesAPIKeyMutations(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
"api_keys":[{"key":"k1","name":"primary","remark":"prod"}],
|
||||
"accounts":[]
|
||||
}`)
|
||||
store := LoadStore()
|
||||
|
||||
if err := store.Update(func(cfg *Config) error {
|
||||
cfg.APIKeys = append(cfg.APIKeys, APIKey{Key: "k2", Name: "secondary", Remark: "staging"})
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("add api key failed: %v", err)
|
||||
}
|
||||
|
||||
snap := store.Snapshot()
|
||||
if len(snap.Keys) != 2 || snap.Keys[0] != "k1" || snap.Keys[1] != "k2" {
|
||||
t.Fatalf("unexpected keys after api key add: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys length after add: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("metadata for existing key was lost: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Name != "secondary" || snap.APIKeys[1].Remark != "staging" {
|
||||
t.Fatalf("metadata for new key was lost: %#v", snap.APIKeys[1])
|
||||
}
|
||||
|
||||
if err := store.Update(func(cfg *Config) error {
|
||||
cfg.APIKeys = append([]APIKey(nil), cfg.APIKeys[1:]...)
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("delete api key failed: %v", err)
|
||||
}
|
||||
|
||||
snap = store.Snapshot()
|
||||
if len(snap.Keys) != 1 || snap.Keys[0] != "k2" {
|
||||
t.Fatalf("unexpected keys after api key delete: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 1 || snap.APIKeys[0].Key != "k2" {
|
||||
t.Fatalf("unexpected api keys after delete: %#v", snap.APIKeys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreUpdateReconcilesLegacyKeyMutations(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
"api_keys":[{"key":"k1","name":"primary","remark":"prod"}],
|
||||
"accounts":[]
|
||||
}`)
|
||||
store := LoadStore()
|
||||
|
||||
if err := store.Update(func(cfg *Config) error {
|
||||
cfg.Keys = append(cfg.Keys, "k2")
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("legacy key update failed: %v", err)
|
||||
}
|
||||
|
||||
snap := store.Snapshot()
|
||||
if len(snap.Keys) != 2 || snap.Keys[0] != "k1" || snap.Keys[1] != "k2" {
|
||||
t.Fatalf("unexpected keys after legacy update: %#v", snap.Keys)
|
||||
}
|
||||
if len(snap.APIKeys) != 2 {
|
||||
t.Fatalf("unexpected api keys after legacy update: %#v", snap.APIKeys)
|
||||
}
|
||||
if snap.APIKeys[0].Name != "primary" || snap.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("metadata for preserved key was lost: %#v", snap.APIKeys[0])
|
||||
}
|
||||
if snap.APIKeys[1].Key != "k2" || snap.APIKeys[1].Name != "" || snap.APIKeys[1].Remark != "" {
|
||||
t.Fatalf("new legacy key should stay metadata-free: %#v", snap.APIKeys[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCredentialsPrefersStructuredAPIKeys(t *testing.T) {
|
||||
cfg := Config{
|
||||
Keys: []string{"legacy-key"},
|
||||
APIKeys: []APIKey{
|
||||
{Key: "structured-key", Name: "primary", Remark: "prod"},
|
||||
},
|
||||
}
|
||||
cfg.NormalizeCredentials()
|
||||
|
||||
if len(cfg.Keys) != 1 || cfg.Keys[0] != "structured-key" {
|
||||
t.Fatalf("unexpected normalized keys: %#v", cfg.Keys)
|
||||
}
|
||||
if len(cfg.APIKeys) != 1 {
|
||||
t.Fatalf("unexpected normalized api keys: %#v", cfg.APIKeys)
|
||||
}
|
||||
if cfg.APIKeys[0].Key != "structured-key" || cfg.APIKeys[0].Name != "primary" || cfg.APIKeys[0].Remark != "prod" {
|
||||
t.Fatalf("unexpected structured api key metadata: %#v", cfg.APIKeys[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreClaudeMapping(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`)
|
||||
store := LoadStore()
|
||||
|
||||
158
internal/config/credentials.go
Normal file
158
internal/config/credentials.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *Config) ReconcileCredentials(base Config) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
currKeys := normalizeKeys(c.Keys)
|
||||
currAPIKeys := normalizeAPIKeys(c.APIKeys)
|
||||
baseKeys := normalizeKeys(base.Keys)
|
||||
baseAPIKeys := normalizeAPIKeys(base.APIKeys)
|
||||
|
||||
keysChanged := !slices.Equal(currKeys, baseKeys)
|
||||
apiKeysChanged := !equalAPIKeys(currAPIKeys, baseAPIKeys)
|
||||
|
||||
if keysChanged && !apiKeysChanged {
|
||||
c.APIKeys = apiKeysFromStrings(currKeys, apiKeyMap(baseAPIKeys))
|
||||
} else {
|
||||
c.APIKeys = currAPIKeys
|
||||
}
|
||||
c.Keys = apiKeysToStrings(c.APIKeys)
|
||||
}
|
||||
|
||||
func normalizeKeys(keys []string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(keys))
|
||||
seen := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, key)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeAPIKeys(items []APIKey) []APIKey {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]APIKey, 0, len(items))
|
||||
seen := make(map[string]struct{}, len(items))
|
||||
for _, item := range items {
|
||||
key := strings.TrimSpace(item.Key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, APIKey{
|
||||
Key: key,
|
||||
Name: strings.TrimSpace(item.Name),
|
||||
Remark: strings.TrimSpace(item.Remark),
|
||||
})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func apiKeysFromStrings(keys []string, meta map[string]APIKey) []APIKey {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]APIKey, 0, len(keys))
|
||||
seen := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
if item, ok := meta[key]; ok {
|
||||
out = append(out, APIKey{
|
||||
Key: key,
|
||||
Name: strings.TrimSpace(item.Name),
|
||||
Remark: strings.TrimSpace(item.Remark),
|
||||
})
|
||||
continue
|
||||
}
|
||||
out = append(out, APIKey{Key: key})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func apiKeysToStrings(items []APIKey) []string {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
key := strings.TrimSpace(item.Key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func apiKeyMap(items []APIKey) map[string]APIKey {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]APIKey, len(items))
|
||||
for _, item := range items {
|
||||
key := strings.TrimSpace(item.Key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := out[key]; ok {
|
||||
continue
|
||||
}
|
||||
out[key] = APIKey{
|
||||
Key: key,
|
||||
Name: strings.TrimSpace(item.Name),
|
||||
Remark: strings.TrimSpace(item.Remark),
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func equalAPIKeys(a, b []APIKey) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
return slices.EqualFunc(a, b, func(x, y APIKey) bool {
|
||||
return strings.TrimSpace(x.Key) == strings.TrimSpace(y.Key) &&
|
||||
strings.TrimSpace(x.Name) == strings.TrimSpace(y.Name) &&
|
||||
strings.TrimSpace(x.Remark) == strings.TrimSpace(y.Remark)
|
||||
})
|
||||
}
|
||||
@@ -37,6 +37,10 @@ func RawStreamSampleRoot() string {
|
||||
return ResolvePath("DS2API_RAW_STREAM_SAMPLE_ROOT", "tests/raw_stream_samples")
|
||||
}
|
||||
|
||||
func ChatHistoryPath() string {
|
||||
return ResolvePath("DS2API_CHAT_HISTORY_PATH", "data/chat_history.json")
|
||||
}
|
||||
|
||||
func StaticAdminDir() string {
|
||||
return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin")
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ func LoadStoreWithError() (*Store, error) {
|
||||
|
||||
func loadStore() (*Store, error) {
|
||||
cfg, fromEnv, err := loadConfig()
|
||||
cfg.NormalizeCredentials()
|
||||
if validateErr := ValidateConfig(cfg); validateErr != nil {
|
||||
err = errors.Join(err, validateErr)
|
||||
}
|
||||
@@ -112,6 +113,7 @@ func loadConfigFromFile(path string) (Config, error) {
|
||||
if err := json.Unmarshal(content, &cfg); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
cfg.NormalizeCredentials()
|
||||
cfg.DropInvalidAccounts()
|
||||
if strings.Contains(string(content), `"test_status"`) && !IsVercel() {
|
||||
if b, err := json.MarshalIndent(cfg, "", " "); err == nil {
|
||||
@@ -207,6 +209,7 @@ func (s *Store) UpdateAccountToken(identifier, token string) error {
|
||||
func (s *Store) Replace(cfg Config) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cfg.NormalizeCredentials()
|
||||
s.cfg = cfg.Clone()
|
||||
s.rebuildIndexes()
|
||||
return s.saveLocked()
|
||||
@@ -215,10 +218,13 @@ func (s *Store) Replace(cfg Config) error {
|
||||
func (s *Store) Update(mutator func(*Config) error) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cfg := s.cfg.Clone()
|
||||
base := s.cfg.Clone()
|
||||
cfg := base.Clone()
|
||||
if err := mutator(&cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.ReconcileCredentials(base)
|
||||
cfg.NormalizeCredentials()
|
||||
s.cfg = cfg
|
||||
s.rebuildIndexes()
|
||||
return s.saveLocked()
|
||||
|
||||
@@ -174,3 +174,21 @@ func (s *Store) RuntimeTokenRefreshIntervalHours() int {
|
||||
func (s *Store) AutoDeleteSessions() bool {
|
||||
return s.AutoDeleteMode() != "none"
|
||||
}
|
||||
|
||||
func (s *Store) HistorySplitEnabled() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.HistorySplit.Enabled == nil {
|
||||
return true
|
||||
}
|
||||
return *s.cfg.HistorySplit.Enabled
|
||||
}
|
||||
|
||||
func (s *Store) HistorySplitTriggerAfterTurns() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.HistorySplit.TriggerAfterTurns == nil || *s.cfg.HistorySplit.TriggerAfterTurns <= 0 {
|
||||
return 1
|
||||
}
|
||||
return *s.cfg.HistorySplit.TriggerAfterTurns
|
||||
}
|
||||
|
||||
27
internal/config/store_accessors_test.go
Normal file
27
internal/config/store_accessors_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestStoreHistorySplitAccessors(t *testing.T) {
|
||||
store := &Store{cfg: Config{}}
|
||||
if !store.HistorySplitEnabled() {
|
||||
t.Fatal("expected history split enabled by default")
|
||||
}
|
||||
if got := store.HistorySplitTriggerAfterTurns(); got != 1 {
|
||||
t.Fatalf("default history split trigger_after_turns=%d want=1", got)
|
||||
}
|
||||
|
||||
enabled := false
|
||||
turns := 3
|
||||
store.cfg.HistorySplit = HistorySplitConfig{
|
||||
Enabled: &enabled,
|
||||
TriggerAfterTurns: &turns,
|
||||
}
|
||||
|
||||
if store.HistorySplitEnabled() {
|
||||
t.Fatal("expected history split disabled after override")
|
||||
}
|
||||
if got := store.HistorySplitTriggerAfterTurns(); got != 3 {
|
||||
t.Fatalf("history split trigger_after_turns=%d want=3", got)
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,9 @@ func ValidateConfig(c Config) error {
|
||||
if err := ValidateAutoDeleteConfig(c.AutoDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateHistorySplitConfig(c.HistorySplit); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateAccountProxyReferences(c.Accounts, c.Proxies); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -111,6 +114,15 @@ func ValidateAutoDeleteConfig(autoDelete AutoDeleteConfig) error {
|
||||
return ValidateAutoDeleteMode(autoDelete.Mode)
|
||||
}
|
||||
|
||||
func ValidateHistorySplitConfig(historySplit HistorySplitConfig) error {
|
||||
if historySplit.TriggerAfterTurns != nil {
|
||||
if err := ValidateIntRange("history_split.trigger_after_turns", *historySplit.TriggerAfterTurns, 1, 1000, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateIntRange(name string, value, min, max int, required bool) error {
|
||||
if value == 0 && !required {
|
||||
return nil
|
||||
|
||||
@@ -39,6 +39,13 @@ func TestValidateConfigRejectsInvalidValues(t *testing.T) {
|
||||
cfg: Config{AutoDelete: AutoDeleteConfig{Mode: "maybe"}},
|
||||
want: "auto_delete.mode",
|
||||
},
|
||||
{
|
||||
name: "history split",
|
||||
cfg: Config{HistorySplit: HistorySplitConfig{
|
||||
TriggerAfterTurns: intPtr(0),
|
||||
}},
|
||||
want: "history_split.trigger_after_turns",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -59,3 +66,5 @@ func TestValidateConfigAcceptsLegacyAutoDeleteSessions(t *testing.T) {
|
||||
t.Fatalf("expected legacy auto_delete.sessions config to remain valid, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func intPtr(v int) *int { return &v }
|
||||
|
||||
@@ -91,17 +91,25 @@ func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAtte
|
||||
}
|
||||
|
||||
func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
|
||||
return c.GetPowForTarget(ctx, a, DeepSeekCompletionTargetPath, maxAttempts)
|
||||
}
|
||||
|
||||
func (c *Client) GetPowForTarget(ctx context.Context, a *auth.RequestAuth, targetPath string, maxAttempts int) (string, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
targetPath = strings.TrimSpace(targetPath)
|
||||
if targetPath == "" {
|
||||
targetPath = DeepSeekCompletionTargetPath
|
||||
}
|
||||
clients := c.requestClientsForAuth(ctx, a)
|
||||
attempts := 0
|
||||
refreshed := false
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, clients.fallback, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
|
||||
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, clients.fallback, DeepSeekCreatePowURL, headers, map[string]any{"target_path": targetPath})
|
||||
if err != nil {
|
||||
config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID)
|
||||
config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID, "target_path", targetPath)
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
@@ -117,7 +125,7 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
|
||||
}
|
||||
return BuildPowHeader(challenge, answer)
|
||||
}
|
||||
config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "use_config_token", a.UseConfigToken, "account", a.AccountID, "target_path", targetPath)
|
||||
if a.UseConfigToken {
|
||||
if !refreshed && shouldAttemptRefresh(status, code, bizCode, msg, bizMsg) {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
|
||||
@@ -51,6 +51,7 @@ func (c *Client) streamPost(ctx context.Context, doer trans.Doer, url string, he
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
headers = c.jsonHeaders(headers)
|
||||
clients := c.requestClientsFromContext(ctx)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
|
||||
188
internal/deepseek/client_file_status.go
Normal file
188
internal/deepseek/client_file_status.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
fileReadyPollAttempts = 60
|
||||
fileReadyPollInterval = time.Second
|
||||
fileReadyPollTimeout = 65 * time.Second
|
||||
)
|
||||
|
||||
var fileReadySleep = time.Sleep
|
||||
|
||||
func (c *Client) waitForUploadedFile(ctx context.Context, a *auth.RequestAuth, result *UploadFileResult) error {
|
||||
if result == nil || strings.TrimSpace(result.ID) == "" {
|
||||
return nil
|
||||
}
|
||||
if isReadyUploadFileStatus(result.Status) {
|
||||
return nil
|
||||
}
|
||||
|
||||
pollCtx, cancel := context.WithTimeout(ctx, fileReadyPollTimeout)
|
||||
defer cancel()
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < fileReadyPollAttempts; attempt++ {
|
||||
if err := pollCtx.Err(); err != nil {
|
||||
if lastErr != nil {
|
||||
return fmt.Errorf("waiting for file %s to become ready: %w", result.ID, lastErr)
|
||||
}
|
||||
return fmt.Errorf("waiting for file %s to become ready: %w", result.ID, err)
|
||||
}
|
||||
|
||||
fetched, err := c.fetchUploadedFile(pollCtx, a, result.ID)
|
||||
if err == nil && fetched != nil {
|
||||
mergeUploadFileResults(result, fetched)
|
||||
if isReadyUploadFileStatus(result.Status) {
|
||||
return nil
|
||||
}
|
||||
lastErr = fmt.Errorf("status=%s", strings.TrimSpace(result.Status))
|
||||
} else if err != nil {
|
||||
lastErr = err
|
||||
config.Logger.Debug("[upload_file] waiting for file readiness", "file_id", result.ID, "attempt", attempt+1, "error", err)
|
||||
}
|
||||
|
||||
if attempt < fileReadyPollAttempts-1 {
|
||||
fileReadySleep(fileReadyPollInterval)
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr == nil {
|
||||
lastErr = fmt.Errorf("status=%s", strings.TrimSpace(result.Status))
|
||||
}
|
||||
return fmt.Errorf("file %s did not become ready: %w", result.ID, lastErr)
|
||||
}
|
||||
|
||||
func (c *Client) fetchUploadedFile(ctx context.Context, a *auth.RequestAuth, fileID string) (*UploadFileResult, error) {
|
||||
fileID = strings.TrimSpace(fileID)
|
||||
if fileID == "" {
|
||||
return nil, errors.New("file id is required")
|
||||
}
|
||||
clients := c.requestClientsForAuth(ctx, a)
|
||||
reqURL := DeepSeekFetchFilesURL + "?file_ids=" + url.QueryEscape(fileID)
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
|
||||
resp, status, err := c.getJSONWithStatus(ctx, clients.regular, reqURL, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
|
||||
if status != http.StatusOK || code != 0 || bizCode != 0 {
|
||||
if strings.TrimSpace(bizMsg) != "" {
|
||||
msg = bizMsg
|
||||
}
|
||||
if msg == "" {
|
||||
msg = http.StatusText(status)
|
||||
}
|
||||
return nil, fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
|
||||
}
|
||||
|
||||
result := extractFetchedUploadFileResult(resp, fileID)
|
||||
if result == nil || strings.TrimSpace(result.ID) == "" {
|
||||
return nil, errors.New("fetch files succeeded without matching file data")
|
||||
}
|
||||
result.Raw = resp
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func extractFetchedUploadFileResult(resp map[string]any, targetID string) *UploadFileResult {
|
||||
targetID = strings.TrimSpace(targetID)
|
||||
if resp == nil || targetID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var walk func(any) *UploadFileResult
|
||||
walk = func(v any) *UploadFileResult {
|
||||
switch x := v.(type) {
|
||||
case map[string]any:
|
||||
if result := buildUploadFileResultFromMap(x, targetID); result != nil {
|
||||
return result
|
||||
}
|
||||
for _, nested := range x {
|
||||
if result := walk(nested); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for _, item := range x {
|
||||
if result := walk(item); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if result := walk(resp); result != nil {
|
||||
return result
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildUploadFileResultFromMap(m map[string]any, targetID string) *UploadFileResult {
|
||||
fileID := strings.TrimSpace(firstNonEmptyString(m, "id", "file_id"))
|
||||
if fileID == "" || !strings.EqualFold(fileID, targetID) {
|
||||
return nil
|
||||
}
|
||||
result := &UploadFileResult{
|
||||
ID: fileID,
|
||||
Filename: firstNonEmptyString(m, "name", "filename", "file_name"),
|
||||
Status: firstNonEmptyString(m, "status", "file_status"),
|
||||
Purpose: firstNonEmptyString(m, "purpose"),
|
||||
IsImage: firstBool(m, "is_image", "isImage"),
|
||||
Bytes: firstPositiveInt64(m, "bytes", "size", "file_size"),
|
||||
}
|
||||
if result.Status == "" {
|
||||
result.Status = "uploaded"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func mergeUploadFileResults(dst, src *UploadFileResult) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(src.ID) != "" {
|
||||
dst.ID = strings.TrimSpace(src.ID)
|
||||
}
|
||||
if strings.TrimSpace(src.Filename) != "" {
|
||||
dst.Filename = strings.TrimSpace(src.Filename)
|
||||
}
|
||||
if src.Bytes > 0 {
|
||||
dst.Bytes = src.Bytes
|
||||
}
|
||||
if strings.TrimSpace(src.Status) != "" {
|
||||
dst.Status = strings.TrimSpace(src.Status)
|
||||
}
|
||||
if strings.TrimSpace(src.Purpose) != "" {
|
||||
dst.Purpose = strings.TrimSpace(src.Purpose)
|
||||
}
|
||||
dst.IsImage = src.IsImage
|
||||
if len(src.Raw) > 0 {
|
||||
dst.Raw = src.Raw
|
||||
}
|
||||
if src.RawHeaders != nil {
|
||||
dst.RawHeaders = src.RawHeaders.Clone()
|
||||
}
|
||||
}
|
||||
|
||||
func isReadyUploadFileStatus(status string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(status)) {
|
||||
case "processed", "ready", "done", "available", "success", "completed", "finished":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -35,6 +35,12 @@ func preview(b []byte) string {
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *Client) jsonHeaders(headers map[string]string) map[string]string {
|
||||
out := cloneStringMap(headers)
|
||||
out["Content-Type"] = "application/json"
|
||||
return out
|
||||
}
|
||||
|
||||
func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
|
||||
@@ -27,6 +27,7 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, fallba
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
headers = c.jsonHeaders(headers)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
|
||||
282
internal/deepseek/client_upload.go
Normal file
282
internal/deepseek/client_upload.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
trans "ds2api/internal/deepseek/transport"
|
||||
)
|
||||
|
||||
type UploadFileRequest struct {
|
||||
Filename string
|
||||
ContentType string
|
||||
Purpose string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
type UploadFileResult struct {
|
||||
ID string
|
||||
Filename string
|
||||
Bytes int64
|
||||
Status string
|
||||
Purpose string
|
||||
AccountID string
|
||||
IsImage bool
|
||||
Raw map[string]any
|
||||
RawHeaders http.Header
|
||||
}
|
||||
|
||||
func (c *Client) UploadFile(ctx context.Context, a *auth.RequestAuth, req UploadFileRequest, maxAttempts int) (*UploadFileResult, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
if len(req.Data) == 0 {
|
||||
return nil, errors.New("file is required")
|
||||
}
|
||||
filename := strings.TrimSpace(req.Filename)
|
||||
if filename == "" {
|
||||
filename = "upload.bin"
|
||||
}
|
||||
contentType := strings.TrimSpace(req.ContentType)
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
purpose := strings.TrimSpace(req.Purpose)
|
||||
body, contentTypeHeader, err := buildUploadMultipartBody(filename, contentType, req.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
capturePayload := map[string]any{
|
||||
"filename": filename,
|
||||
"content_type": contentType,
|
||||
"purpose": purpose,
|
||||
"bytes": len(req.Data),
|
||||
}
|
||||
captureSession := c.capture.Start("deepseek_upload_file", DeepSeekUploadFileURL, a.AccountID, capturePayload)
|
||||
attempts := 0
|
||||
refreshed := false
|
||||
powHeader := ""
|
||||
for attempts < maxAttempts {
|
||||
clients := c.requestClientsForAuth(ctx, a)
|
||||
if strings.TrimSpace(powHeader) == "" {
|
||||
powHeader, err = c.GetPowForTarget(ctx, a, DeepSeekUploadTargetPath, maxAttempts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clients = c.requestClientsForAuth(ctx, a)
|
||||
}
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
headers["Content-Type"] = contentTypeHeader
|
||||
headers["x-ds-pow-response"] = powHeader
|
||||
headers["x-file-size"] = strconv.Itoa(len(req.Data))
|
||||
headers["x-thinking-enabled"] = "1"
|
||||
resp, err := c.doUpload(ctx, clients.regular, clients.fallback, DeepSeekUploadFileURL, headers, body)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[upload_file] request error", "error", err, "account", a.AccountID, "filename", filename)
|
||||
powHeader = ""
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
if captureSession != nil {
|
||||
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
|
||||
}
|
||||
payloadBytes, readErr := readResponseBody(resp)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
powHeader = ""
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
parsed := map[string]any{}
|
||||
if len(payloadBytes) > 0 {
|
||||
if err := json.Unmarshal(payloadBytes, &parsed); err != nil {
|
||||
config.Logger.Warn("[upload_file] json parse failed", "status", resp.StatusCode, "preview", preview(payloadBytes))
|
||||
}
|
||||
}
|
||||
code, bizCode, msg, bizMsg := extractResponseStatus(parsed)
|
||||
if resp.StatusCode == http.StatusOK && code == 0 && bizCode == 0 {
|
||||
result := extractUploadFileResult(parsed)
|
||||
result.Raw = parsed
|
||||
result.RawHeaders = resp.Header.Clone()
|
||||
if result.Filename == "" {
|
||||
result.Filename = filename
|
||||
}
|
||||
if result.Bytes == 0 {
|
||||
result.Bytes = int64(len(req.Data))
|
||||
}
|
||||
if result.Purpose == "" {
|
||||
result.Purpose = purpose
|
||||
}
|
||||
if result.AccountID == "" {
|
||||
result.AccountID = a.AccountID
|
||||
}
|
||||
if result.ID == "" {
|
||||
return nil, errors.New("upload file succeeded without file id")
|
||||
}
|
||||
if err := c.waitForUploadedFile(ctx, a, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
config.Logger.Warn("[upload_file] failed", "status", resp.StatusCode, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "account", a.AccountID, "filename", filename)
|
||||
powHeader = ""
|
||||
if a.UseConfigToken {
|
||||
if !refreshed && shouldAttemptRefresh(resp.StatusCode, code, bizCode, msg, bizMsg) {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
refreshed = true
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c.Auth.SwitchAccount(ctx, a) {
|
||||
refreshed = false
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
}
|
||||
attempts++
|
||||
}
|
||||
return nil, errors.New("upload file failed")
|
||||
}
|
||||
|
||||
func buildUploadMultipartBody(filename, contentType string, data []byte) ([]byte, string, error) {
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
partHeader := textproto.MIMEHeader{}
|
||||
partHeader.Set("Content-Disposition", fmt.Sprintf(`form-data; name="file"; filename=%q`, escapeMultipartFilename(filename)))
|
||||
partHeader.Set("Content-Type", contentType)
|
||||
part, err := writer.CreatePart(partHeader)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return buf.Bytes(), writer.FormDataContentType(), nil
|
||||
}
|
||||
|
||||
func escapeMultipartFilename(filename string) string {
|
||||
filename = filepath.Base(strings.TrimSpace(filename))
|
||||
filename = strings.ReplaceAll(filename, `\`, "_")
|
||||
filename = strings.ReplaceAll(filename, `"`, "_")
|
||||
if filename == "." || filename == "" {
|
||||
return "upload.bin"
|
||||
}
|
||||
return filename
|
||||
}
|
||||
|
||||
func (c *Client) doUpload(ctx context.Context, doer trans.Doer, fallback trans.Doer, url string, headers map[string]string, body []byte) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := doer.Do(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
config.Logger.Warn("[deepseek] fingerprint upload request failed, fallback to std transport", "url", url, "error", err)
|
||||
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if reqErr != nil {
|
||||
return nil, reqErr
|
||||
}
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
}
|
||||
return fallback.Do(req2)
|
||||
}
|
||||
|
||||
func extractUploadFileResult(resp map[string]any) *UploadFileResult {
|
||||
result := &UploadFileResult{Status: "uploaded"}
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
searchMaps := []map[string]any{resp, data, bizData}
|
||||
for _, parent := range []map[string]any{resp, data, bizData} {
|
||||
if parent == nil {
|
||||
continue
|
||||
}
|
||||
for _, key := range []string{"file", "biz_data", "data"} {
|
||||
if nested, ok := parent[key].(map[string]any); ok {
|
||||
searchMaps = append(searchMaps, nested)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, m := range searchMaps {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if result.ID == "" {
|
||||
result.ID = firstNonEmptyString(m, "id", "file_id")
|
||||
}
|
||||
if result.Filename == "" {
|
||||
result.Filename = firstNonEmptyString(m, "name", "filename", "file_name")
|
||||
}
|
||||
if result.Status == "uploaded" {
|
||||
if status := firstNonEmptyString(m, "status", "file_status"); status != "" {
|
||||
result.Status = status
|
||||
}
|
||||
}
|
||||
if !result.IsImage {
|
||||
result.IsImage = firstBool(m, "is_image", "isImage")
|
||||
}
|
||||
if result.Purpose == "" {
|
||||
result.Purpose = firstNonEmptyString(m, "purpose")
|
||||
}
|
||||
if result.AccountID == "" {
|
||||
result.AccountID = firstNonEmptyString(m, "account_id", "accountId", "owner_account_id", "ownerAccountId")
|
||||
}
|
||||
if result.Bytes == 0 {
|
||||
result.Bytes = firstPositiveInt64(m, "bytes", "size", "file_size")
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func firstBool(m map[string]any, keys ...string) bool {
|
||||
for _, key := range keys {
|
||||
switch v := m[key].(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "true", "1", "yes", "y":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func firstNonEmptyString(m map[string]any, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if v, _ := m[key].(string); strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func firstPositiveInt64(m map[string]any, keys ...string) int64 {
|
||||
for _, key := range keys {
|
||||
if v := toInt64(m[key], 0); v > 0 {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
216
internal/deepseek/client_upload_test.go
Normal file
216
internal/deepseek/client_upload_test.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
powpkg "ds2api/pow"
|
||||
)
|
||||
|
||||
func TestBuildUploadMultipartBodyOmitsPurposeAndIncludesFilePart(t *testing.T) {
|
||||
body, contentType, err := buildUploadMultipartBody(`../demo.txt`, "text/plain", []byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatalf("buildUploadMultipartBody error: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(contentType, "multipart/form-data; boundary=") {
|
||||
t.Fatalf("unexpected content type: %q", contentType)
|
||||
}
|
||||
payload := string(body)
|
||||
if strings.Contains(payload, `name="purpose"`) || strings.Contains(payload, "assistants") {
|
||||
t.Fatalf("expected purpose to be omitted from payload: %q", payload)
|
||||
}
|
||||
if !strings.Contains(payload, `name="file"; filename="demo.txt"`) {
|
||||
t.Fatalf("expected sanitized filename in payload: %q", payload)
|
||||
}
|
||||
if !strings.Contains(payload, "Content-Type: text/plain") {
|
||||
t.Fatalf("expected file content type in payload: %q", payload)
|
||||
}
|
||||
if !strings.Contains(payload, "hello") {
|
||||
t.Fatalf("expected file content in payload: %q", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractUploadFileResultSupportsNestedShapes(t *testing.T) {
|
||||
got := extractUploadFileResult(map[string]any{
|
||||
"data": map[string]any{
|
||||
"biz_data": map[string]any{
|
||||
"file": map[string]any{
|
||||
"file_id": "file_123",
|
||||
"file_name": "report.pdf",
|
||||
"file_size": 99,
|
||||
"status": "processed",
|
||||
"purpose": "assistants",
|
||||
"is_image": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if got.ID != "file_123" {
|
||||
t.Fatalf("expected id file_123, got %#v", got)
|
||||
}
|
||||
if got.Filename != "report.pdf" {
|
||||
t.Fatalf("expected filename report.pdf, got %#v", got)
|
||||
}
|
||||
if got.Bytes != 99 {
|
||||
t.Fatalf("expected bytes 99, got %#v", got)
|
||||
}
|
||||
if got.Status != "processed" {
|
||||
t.Fatalf("expected status processed, got %#v", got)
|
||||
}
|
||||
if got.Purpose != "assistants" {
|
||||
t.Fatalf("expected purpose assistants, got %#v", got)
|
||||
}
|
||||
if !got.IsImage {
|
||||
t.Fatalf("expected image flag true, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadFileUsesUploadTargetPowAndMultipartHeaders(t *testing.T) {
|
||||
challengeHash := powpkg.DeepSeekHashV1([]byte(powpkg.BuildPrefix("salt", 1712345678) + "42"))
|
||||
powResponse := `{"code":0,"msg":"ok","data":{"biz_code":0,"biz_data":{"challenge":{"algorithm":"DeepSeekHashV1","challenge":"` + hex.EncodeToString(challengeHash[:]) + `","salt":"salt","expire_at":1712345678,"difficulty":1000,"signature":"sig","target_path":"` + DeepSeekUploadTargetPath + `"}}}}`
|
||||
uploadResponse := `{"code":0,"msg":"ok","data":{"biz_code":0,"biz_data":{"file":{"file_id":"file_789","filename":"demo.txt","bytes":5,"status":"processed","purpose":"assistants","is_image":false}}}}`
|
||||
var seenPow string
|
||||
var seenTargetPath string
|
||||
var seenContentType string
|
||||
var seenFileSize string
|
||||
var seenBody string
|
||||
call := 0
|
||||
client := &Client{
|
||||
regular: doerFunc(func(req *http.Request) (*http.Response, error) {
|
||||
call++
|
||||
bodyBytes, _ := io.ReadAll(req.Body)
|
||||
switch call {
|
||||
case 1:
|
||||
seenTargetPath = string(bodyBytes)
|
||||
return &http.Response{StatusCode: http.StatusOK, Header: make(http.Header), Body: io.NopCloser(strings.NewReader(powResponse)), Request: req}, nil
|
||||
case 2:
|
||||
seenPow = req.Header.Get("x-ds-pow-response")
|
||||
seenContentType = req.Header.Get("Content-Type")
|
||||
seenFileSize = req.Header.Get("x-file-size")
|
||||
seenBody = string(bodyBytes)
|
||||
return &http.Response{StatusCode: http.StatusOK, Header: make(http.Header), Body: io.NopCloser(strings.NewReader(uploadResponse)), Request: req}, nil
|
||||
default:
|
||||
t.Fatalf("unexpected request count %d", call)
|
||||
return nil, nil
|
||||
}
|
||||
}),
|
||||
fallback: &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, nil
|
||||
})},
|
||||
maxRetries: 1,
|
||||
}
|
||||
result, err := client.UploadFile(context.Background(), &auth.RequestAuth{DeepSeekToken: "token", TriedAccounts: map[string]bool{}}, UploadFileRequest{
|
||||
Filename: "demo.txt",
|
||||
ContentType: "text/plain",
|
||||
Purpose: "assistants",
|
||||
Data: []byte("hello"),
|
||||
}, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("UploadFile error: %v", err)
|
||||
}
|
||||
if result.ID != "file_789" {
|
||||
t.Fatalf("expected uploaded file id file_789, got %#v", result)
|
||||
}
|
||||
if !strings.Contains(seenTargetPath, `"target_path":"`+DeepSeekUploadTargetPath+`"`) {
|
||||
t.Fatalf("expected upload target_path in pow request, got %q", seenTargetPath)
|
||||
}
|
||||
if strings.TrimSpace(seenPow) == "" {
|
||||
t.Fatal("expected x-ds-pow-response header")
|
||||
}
|
||||
rawPow, err := base64.StdEncoding.DecodeString(seenPow)
|
||||
if err != nil {
|
||||
t.Fatalf("decode pow header failed: %v", err)
|
||||
}
|
||||
var powHeader map[string]any
|
||||
if err := json.Unmarshal(rawPow, &powHeader); err != nil {
|
||||
t.Fatalf("unmarshal pow header failed: %v", err)
|
||||
}
|
||||
if powHeader["target_path"] != DeepSeekUploadTargetPath {
|
||||
t.Fatalf("expected pow target_path %q, got %#v", DeepSeekUploadTargetPath, powHeader["target_path"])
|
||||
}
|
||||
if seenFileSize != "5" {
|
||||
t.Fatalf("expected x-file-size=5, got %q", seenFileSize)
|
||||
}
|
||||
if !strings.HasPrefix(seenContentType, "multipart/form-data; boundary=") {
|
||||
t.Fatalf("expected multipart content type, got %q", seenContentType)
|
||||
}
|
||||
if !strings.Contains(seenBody, `name="file"; filename="demo.txt"`) {
|
||||
t.Fatalf("expected file part in upload body: %q", seenBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadFileWaitsForProcessedFetchFiles(t *testing.T) {
|
||||
oldSleep := fileReadySleep
|
||||
fileReadySleep = func(time.Duration) {}
|
||||
defer func() { fileReadySleep = oldSleep }()
|
||||
|
||||
challengeHash := powpkg.DeepSeekHashV1([]byte(powpkg.BuildPrefix("salt", 1712345678) + "42"))
|
||||
powResponse := `{"code":0,"msg":"ok","data":{"biz_code":0,"biz_data":{"challenge":{"algorithm":"DeepSeekHashV1","challenge":"` + hex.EncodeToString(challengeHash[:]) + `","salt":"salt","expire_at":1712345678,"difficulty":1000,"signature":"sig","target_path":"` + DeepSeekUploadTargetPath + `"}}}}`
|
||||
uploadResponse := `{"code":0,"msg":"ok","data":{"biz_code":0,"biz_data":{"file":{"file_id":"file_789","filename":"demo.txt","bytes":5,"status":"PENDING","purpose":"assistants","is_image":false}}}}`
|
||||
pendingFetchResponse := `{"code":0,"msg":"ok","data":{"biz_code":0,"biz_data":{"files":[{"file_id":"file_789","filename":"demo.txt","bytes":5,"status":"PENDING","purpose":"assistants","is_image":false}]}}}`
|
||||
processedFetchResponse := `{"code":0,"msg":"ok","data":{"biz_code":0,"biz_data":{"files":[{"file_id":"file_789","filename":"demo.txt","bytes":5,"status":"processed","purpose":"assistants","is_image":true}]}}}`
|
||||
|
||||
var call int
|
||||
client := &Client{
|
||||
regular: doerFunc(func(req *http.Request) (*http.Response, error) {
|
||||
call++
|
||||
switch call {
|
||||
case 1:
|
||||
bodyBytes, _ := io.ReadAll(req.Body)
|
||||
if !strings.Contains(string(bodyBytes), `"target_path":"`+DeepSeekUploadTargetPath+`"`) {
|
||||
t.Fatalf("expected pow target path request, got %s", string(bodyBytes))
|
||||
}
|
||||
return &http.Response{StatusCode: http.StatusOK, Header: make(http.Header), Body: io.NopCloser(strings.NewReader(powResponse)), Request: req}, nil
|
||||
case 2:
|
||||
return &http.Response{StatusCode: http.StatusOK, Header: make(http.Header), Body: io.NopCloser(strings.NewReader(uploadResponse)), Request: req}, nil
|
||||
case 3, 4:
|
||||
if req.Method != http.MethodGet {
|
||||
t.Fatalf("expected GET fetch request, got %s", req.Method)
|
||||
}
|
||||
if req.URL.Path != "/api/v0/file/fetch_files" {
|
||||
t.Fatalf("expected fetch files path /api/v0/file/fetch_files, got %q", req.URL.Path)
|
||||
}
|
||||
if got := req.URL.Query().Get("file_ids"); got != "file_789" {
|
||||
t.Fatalf("expected file_ids=file_789, got %q", got)
|
||||
}
|
||||
respBody := pendingFetchResponse
|
||||
if call == 4 {
|
||||
respBody = processedFetchResponse
|
||||
}
|
||||
return &http.Response{StatusCode: http.StatusOK, Header: make(http.Header), Body: io.NopCloser(strings.NewReader(respBody)), Request: req}, nil
|
||||
default:
|
||||
t.Fatalf("unexpected request count %d", call)
|
||||
return nil, nil
|
||||
}
|
||||
}),
|
||||
fallback: &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { return nil, nil })},
|
||||
maxRetries: 1,
|
||||
}
|
||||
|
||||
result, err := client.UploadFile(context.Background(), &auth.RequestAuth{DeepSeekToken: "token", TriedAccounts: map[string]bool{}}, UploadFileRequest{
|
||||
Filename: "demo.txt",
|
||||
ContentType: "text/plain",
|
||||
Purpose: "assistants",
|
||||
Data: []byte("hello"),
|
||||
}, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("UploadFile error: %v", err)
|
||||
}
|
||||
if result.ID != "file_789" {
|
||||
t.Fatalf("expected uploaded file id file_789, got %#v", result)
|
||||
}
|
||||
if result.Status != "processed" {
|
||||
t.Fatalf("expected final status processed, got %#v", result.Status)
|
||||
}
|
||||
if call != 4 {
|
||||
t.Fatalf("expected 4 requests, got %d", call)
|
||||
}
|
||||
}
|
||||
@@ -12,9 +12,13 @@ const (
|
||||
DeepSeekCreatePowURL = "https://chat.deepseek.com/api/v0/chat/create_pow_challenge"
|
||||
DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion"
|
||||
DeepSeekContinueURL = "https://chat.deepseek.com/api/v0/chat/continue"
|
||||
DeepSeekUploadFileURL = "https://chat.deepseek.com/api/v0/file/upload_file"
|
||||
DeepSeekFetchFilesURL = "https://chat.deepseek.com/api/v0/file/fetch_files"
|
||||
DeepSeekFetchSessionURL = "https://chat.deepseek.com/api/v0/chat_session/fetch_page"
|
||||
DeepSeekDeleteSessionURL = "https://chat.deepseek.com/api/v0/chat_session/delete"
|
||||
DeepSeekDeleteAllSessionsURL = "https://chat.deepseek.com/api/v0/chat_session/delete_all"
|
||||
DeepSeekCompletionTargetPath = "/api/v0/chat/completion"
|
||||
DeepSeekUploadTargetPath = "/api/v0/file/upload_file"
|
||||
)
|
||||
|
||||
var defaultBaseHeaders = map[string]string{
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
"Host": "chat.deepseek.com",
|
||||
"User-Agent": "DeepSeek/1.8.0 Android/35",
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"x-client-platform": "android",
|
||||
"x-client-version": "1.8.0",
|
||||
"x-client-locale": "zh_CN",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user