Mirror of https://github.com/CJackHwang/ds2api.git (synced 2026-05-05 00:45:29 +08:00)
Compare commits: v2.1.9 ... v2.3.5_bet (78 commits)
Commits (SHA1, in listed order):

`b1f8d6192f`, `acfb3b225d`, `99a6164000`, `e49d9d33e2`, `184a3d1e4e`, `c4ec14f49a`,
`fb5fc0e885`, `20b603666d`, `4d549b7102`, `33b0d1d144`, `efb484ba4f`, `145501d4a5`,
`2d5103997b`, `52e7e7aae8`, `5b5a4000d7`, `2bbf603148`, `d14b8a0664`, `f16e0b579e`,
`cf569f4749`, `c9c59f2490`, `16216cc2ca`, `de50fd3954`, `7648d5f192`, `d35e5eab25`,
`90610a52ce`, `f6296d506f`, `dfea092583`, `af7dc134bb`, `2657d37f76`, `7318d1f4a8`,
`f2674487c7`, `71cdcb43e8`, `9c46c3a874`, `12d5f136d5`, `00c37d8d2f`, `0f1985af4a`,
`fa8affe1b7`, `c59a0b7799`, `bd72b91f27`, `9240f85246`, `ea4bd1e483`, `9e0de62707`,
`128de290db`, `286d266723`, `8aad1005b2`, `11b2f24fc2`, `d1f08cbb89`, `60e9d707d4`,
`9b93badb57`, `892213071a`, `5484d6e59d`, `0ce3fd22a7`, `25e40cc3a6`, `af68d21095`,
`1fafd25e86`, `5f8f28a943`, `94cf1bfcc7`, `13562cf521`, `d27e700c4f`, `d6bce5af93`,
`75969e710d`, `6c39c8e191`, `0e261ff0a0`, `fab326eca1`, `c033eceee7`, `a10e03ebe0`,
`a6aa4a1839`, `1c749b6803`, `c329bf26b6`, `3ae5b57ebe`, `0bf5d5440c`, `d731a1fd4f`,
`93e9fb531d`, `6daeb2553d`, `321b8a89ee`, `d84875e466`, `ea8c9a28a9`, `a302fb3c25`
API.md (7 changed lines)

````diff
@@ -284,6 +284,11 @@ data: [DONE]
 **Streaming**: once a high-confidence feature is matched, `delta.tool_calls` is emitted immediately (without waiting for the complete JSON to close) and argument increments keep streaming; confirmed raw tool-call JSON is never echoed back into `delta.content`.
+
+Additional notes:
+
+- In a **non-code-block context**, tool JSON is recognized by its features and produces an executable tool call even when mixed with ordinary text (surrounding plain text can still pass through).
+- `tool_calls` inside a Markdown fenced code block (e.g. ```json ... ```) are treated as example text only and are never executed.
 
 ---
 
 ### `GET /v1/models/{id}`
@@ -301,7 +306,7 @@ OpenAI Responses-style endpoint; accepts `input` or `messages`.
 | `messages` | array | ❌ | Mutually exclusive with `input` |
 | `instructions` | string | ❌ | Automatically prepended as a system message |
 | `stream` | boolean | ❌ | Defaults to `false` |
-| `tools` | array | ❌ | Same tool recognition and translation strategy as chat |
+| `tools` | array | ❌ | Same tool recognition and translation strategy as chat (including the code-block example exemption) |
 | `tool_choice` | string/object | ❌ | Supports `auto`/`none`/`required` and forced functions (`{"type":"function","name":"..."}`) |
 
 **Non-streaming response**: returns a standard `response` object with an `id` like `resp_xxx`, stored in the in-memory TTL store.
````
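The code-block exemption above hinges on knowing whether the accumulated stream text currently sits inside an unclosed Markdown fence. A minimal sketch of that guard, mirroring the `hasUnclosedCodeFence` helper added elsewhere in this change set (shown here as an illustration, not the exact upstream file):

````go
package main

import (
	"fmt"
	"strings"
)

// hasUnclosedCodeFence reports whether the accumulated stream text is still
// inside an open Markdown fence: an odd number of ``` markers means the last
// fence has not been closed yet, so tool-call matching stays suspended.
func hasUnclosedCodeFence(text string) bool {
	return strings.Count(text, "```")%2 == 1
}

func main() {
	fmt.Println(hasUnclosedCodeFence("```json\n{\"tool_calls\":[]}")) // true: fence still open
	fmt.Println(hasUnclosedCodeFence("```json\n{}\n```\ndone"))       // false: fence closed
}
````

While the fence is open, buffered text is treated as example content; only after the closing ``` arrives does the sieve decide whether the buffered JSON is a real tool call.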
README.MD (28 changed lines)

````diff
@@ -1,5 +1,5 @@
 <p align="center">
-  <img src="assets/ds2api-icon.svg" width="128" height="128" alt="DS2API icon" />
+  <img src="webui/public/ds2api-favicon.svg" width="128" height="128" alt="DS2API icon" />
 </p>
 
 # DS2API
@@ -10,6 +10,7 @@
 [](https://github.com/CJackHwang/ds2api/releases)
 [](DEPLOY.md)
 [](https://zeabur.com/templates/L4CFHP)
+[](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api)
 
 语言 / Language: [中文](README.MD) | [English](README.en.md)
@@ -105,6 +106,14 @@ flowchart LR
 The mapping can be overridden via `claude_mapping` or `claude_model_mapping` in the config.
 Additionally, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x model IDs and common aliases, so legacy clients work without changes.
 
+#### Claude Code integration pitfalls (field-tested)
+
+- Point `ANTHROPIC_BASE_URL` directly at the DS2API root address (for example `http://127.0.0.1:5001`); Claude Code requests `/v1/messages?beta=true`.
+- `ANTHROPIC_API_KEY` must match an entry in `keys` in `config.json`; keep both a regular key and an `sk-ant-*` style key to accommodate different clients' validation habits.
+- If a system proxy is configured, set `NO_PROXY=127.0.0.1,localhost,<your_host_ip>` for the DS2API address so local loopback requests are not intercepted by the proxy.
+- If tool calls are printed as text instead of being executed, upgrade to a build that includes multi-format Claude tool-call parsing (JSON/XML/ANTML/invoke).
+
 ### Gemini endpoint
 
 The Gemini adapter maps model names to native DeepSeek models via `model_aliases` or built-in rules, supports both `generateContent` and `streamGenerateContent` call styles, and fully supports Tool Calling (`functionDeclarations` → `functionCall` output).
@@ -467,6 +476,23 @@ go run ./cmd/ds2api-tests \
 npm ci --prefix webui && npm run build --prefix webui
 ```
 
+## Testing
+
+See [TESTING.md](TESTING.md) for the detailed testing guide.
+
+### Quick test commands
+
+```bash
+# Run all unit tests
+go test ./...
+
+# Run tool-call-related tests (for debugging tool-call issues)
+go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
+
+# Run end-to-end tests
+./tests/scripts/run-live.sh
+```
+
 ## Release builds (GitHub Actions)
 
 Workflow file: `.github/workflows/release-artifacts.yml`
````
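The Gemini adapter described above translates detected tool calls into `functionCall` output parts. A minimal sketch of that translation, assuming the part shape `{"functionCall": {"name": ..., "args": {...}}}` — the `args` key follows the public Gemini API, while the tests in this change set only assert on `name`:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// geminiFunctionCallPart wraps a detected tool call in the Gemini response
// part shape. The "args" field name is an assumption based on the public
// Gemini API; the repo's tests only check functionCall["name"].
func geminiFunctionCallPart(name string, args map[string]any) map[string]any {
	return map[string]any{
		"functionCall": map[string]any{
			"name": name,
			"args": args,
		},
	}
}

func main() {
	part := geminiFunctionCallPart("eval_javascript", map[string]any{"code": "1+1"})
	b, _ := json.Marshal(part)
	fmt.Println(string(b)) // {"functionCall":{"args":{"code":"1+1"},"name":"eval_javascript"}}
}
```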
README.en.md (12 changed lines)

```diff
@@ -1,5 +1,5 @@
 <p align="center">
-  <img src="assets/ds2api-icon.svg" width="128" height="128" alt="DS2API icon" />
+  <img src="webui/public/ds2api-favicon.svg" width="128" height="128" alt="DS2API icon" />
 </p>
 
 # DS2API
@@ -10,6 +10,7 @@
 [](https://github.com/CJackHwang/ds2api/releases)
 [](DEPLOY.en.md)
 [](https://zeabur.com/templates/L4CFHP)
+[](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api)
 
 Language: [中文](README.MD) | [English](README.en.md)
@@ -105,6 +106,14 @@ flowchart LR
 Override mapping via `claude_mapping` or `claude_model_mapping` in config.
 In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility.
 
+#### Claude Code integration pitfalls (validated)
+
+- Set `ANTHROPIC_BASE_URL` to the DS2API root URL (for example `http://127.0.0.1:5001`). Claude Code sends requests to `/v1/messages?beta=true`.
+- `ANTHROPIC_API_KEY` must match an entry in `keys` from `config.json`. Keeping both a regular key and an `sk-ant-*` style key improves client compatibility.
+- If your environment has proxy variables, set `NO_PROXY=127.0.0.1,localhost,<your_host_ip>` for DS2API to avoid proxy interception of local traffic.
+- If tool calls are rendered as plain text and not executed, upgrade to a build that includes multi-format Claude tool-call parsing (JSON/XML/ANTML/invoke).
+
 ### Gemini Endpoint
 
 The Gemini adapter maps model names to DeepSeek native models via `model_aliases` or built-in heuristics, supporting both `generateContent` and `streamGenerateContent` call patterns with full Tool Calling support (`functionDeclarations` → `functionCall` output).
@@ -350,6 +359,7 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
 When `tools` is present in the request, DS2API performs anti-leak handling:
 
 1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
+   - In non-code-block context, tool JSON may still be recognized even when mixed with normal prose; surrounding prose can remain as text output.
 2. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`)
 3. Tool names not declared in the `tools` schema are strictly rejected and will not be emitted as valid tool calls
 4. `responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream
```
TESTING.md (46 changed lines)

````diff
@@ -51,7 +51,7 @@ DS2API provides two tiers of tests:
 1. **Preflight checks**:
    - `go test ./... -count=1` (unit tests)
    - `./tests/scripts/check-node-split-syntax.sh` (syntax gate for the split Node modules)
-   - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js` (Node stream-interception + compat unit tests)
+   - `node --test` (run when Node unit-test files exist in the repo; the current default is Go tests plus the Node syntax gate)
    - `npm run build --prefix webui` (WebUI build check)
 
 2. **Isolated startup**: copy `config.json` to a temp directory and start an independent server process
@@ -173,6 +173,50 @@ rg "<trace_id>" artifacts/testsuite/<run_id>/server.log
 go test ./...
 ```
 
+### Run unit tests for a specific module
+
+```bash
+# Tool-call tests (recommended when debugging tool-call parsing issues)
+go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
+
+# A single test case
+go test -v -run TestParseToolCallsWithDeepSeekHallucination ./internal/util/
+
+# format tests
+go test -v ./internal/format/...
+
+# adapter tests
+go test -v ./internal/adapter/openai/...
+```
+
+### Debugging tool-call issues
+
+When you hit DeepSeek tool-call parsing problems, try the following:
+
+```bash
+# 1. Run all tool-call-related tests
+go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
+
+# 2. Inspect the detailed debug output
+go test -v -run TestParseToolCallsWithDeepSeekHallucination ./internal/util/ 2>&1
+
+# 3. Check the fix for a specific case
+# Test cases live in internal/util/toolcalls_test.go and include:
+# - TestParseToolCallsWithDeepSeekHallucination: typical DeepSeek hallucinated output
+# - TestRepairLooseJSONWithNestedObjects: bracket repair for nested objects
+# - TestParseToolCallsWithMixedWindowsPaths: Windows path handling
+```
+
+### Run the Node.js tests
+
+```bash
+# Run a Node test directly
+node --test tests/node/stream-tool-sieve.test.js
+
+# Or use the script
+./tests/scripts/run-unit-node.sh
+```
+
 ### Run end-to-end tests (skip preflight)
 
 ```bash
````
Deleted file (SVG icon; before size: 2.7 KiB):

```diff
@@ -1,63 +0,0 @@
-<svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg" role="img" aria-label="DS2API icon">
-  <defs>
-    <linearGradient id="bg" x1="96" y1="96" x2="416" y2="416" gradientUnits="userSpaceOnUse">
-      <stop offset="0" stop-color="#06162D" />
-      <stop offset="0.6" stop-color="#0A3A6A" />
-      <stop offset="1" stop-color="#00B4D8" />
-    </linearGradient>
-    <radialGradient id="glow" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(256 180) rotate(90) scale(260)">
-      <stop offset="0" stop-color="#FFFFFF" stop-opacity="0.18" />
-      <stop offset="1" stop-color="#FFFFFF" stop-opacity="0" />
-    </radialGradient>
-    <linearGradient id="whale" x1="180" y1="140" x2="360" y2="360" gradientUnits="userSpaceOnUse">
-      <stop offset="0" stop-color="#EAF7FF" />
-      <stop offset="1" stop-color="#BDEBFF" />
-    </linearGradient>
-  </defs>
-
-  <circle cx="256" cy="256" r="240" fill="url(#bg)" />
-  <circle cx="256" cy="256" r="240" fill="url(#glow)" />
-  <circle cx="256" cy="256" r="240" stroke="#FFFFFF" stroke-opacity="0.14" stroke-width="8" />
-
-  <!-- subtle waves -->
-  <path d="M104 338 C156 308 204 366 256 334 C308 302 356 360 408 330" stroke="#FFFFFF" stroke-opacity="0.16" stroke-width="12" stroke-linecap="round" />
-  <path d="M124 372 C174 344 212 396 256 372 C300 348 338 396 388 368" stroke="#FFFFFF" stroke-opacity="0.12" stroke-width="10" stroke-linecap="round" />
-
-  <!-- whale tail (DeepSeek-inspired element, original design) -->
-  <path
-    d="M256 162
-       C228 124 184 118 156 146
-       C132 170 138 206 162 230
-       C190 262 230 252 252 220
-       C254 218 255 216 256 214
-       C257 216 258 218 260 220
-       C282 252 322 262 350 230
-       C374 206 380 170 356 146
-       C328 118 284 124 256 162 Z"
-    fill="url(#whale)"
-  />
-  <rect x="236" y="214" width="40" height="168" rx="20" fill="url(#whale)" />
-
-  <!-- API nodes -->
-  <g opacity="0.55" stroke="#FFFFFF" stroke-opacity="0.35" stroke-width="6" stroke-linecap="round">
-    <path d="M156 236 L208 206" />
-    <path d="M356 236 L304 206" />
-    <path d="M208 206 L232 172" />
-    <circle cx="156" cy="236" r="10" fill="#FFFFFF" fill-opacity="0.28" />
-    <circle cx="208" cy="206" r="10" fill="#FFFFFF" fill-opacity="0.28" />
-    <circle cx="232" cy="172" r="10" fill="#FFFFFF" fill-opacity="0.28" />
-    <circle cx="304" cy="206" r="10" fill="#FFFFFF" fill-opacity="0.28" />
-    <circle cx="356" cy="236" r="10" fill="#FFFFFF" fill-opacity="0.28" />
-  </g>
-
-  <!-- tiny sparkle -->
-  <path
-    d="M378 164
-       C372 170 366 174 358 176
-       C366 178 372 182 378 188
-       C380 180 384 176 392 176
-       C384 174 380 170 378 164 Z"
-    fill="#FFFFFF"
-    fill-opacity="0.32"
-  />
-</svg>
```
```diff
@@ -1,6 +1,6 @@
 services:
   ds2api:
-    image: crpi-cnazxqmg4avmg4fq.cn-beijing.personal.cr.aliyuncs.com/ronghuaxueleng/ds2api:latest
+    image: ghcr.io/cjackhwang/ds2api:latest
     container_name: ds2api
     restart: always
     ports:
```
docs/toolcall-semantics.md (new file, 41 lines)

```diff
@@ -0,0 +1,41 @@
+# Tool call parsing semantics (Go canonical spec)
+
+This document defines the cross-runtime contract for `ParseToolCallsDetailed` / `parseToolCallsDetailed`.
+
+## Output contract
+
+- `calls`: accepted tool calls with normalized tool names.
+- `sawToolCallSyntax`: true when tool-call-like syntax is detected (`tool_calls`, `<tool_call>`, `<function_call>`, `<invoke>`) or a valid call is parsed.
+- `rejectedByPolicy`: true when the parser extracted call syntax but all calls are rejected by the allow-list policy.
+- `rejectedToolNames`: de-duplicated rejected tool names in first-seen order.
+
+## Parse pipeline
+
+1. Strip fenced code blocks for non-standalone parsing.
+2. Build candidates from:
+   - full text,
+   - fenced JSON snippets,
+   - extracted JSON objects around `tool_calls`,
+   - first `{` to last `}` object slice.
+3. Parse each candidate in order:
+   - JSON payload parser (`tool_calls`, list, single call object),
+   - XML/markup parser (`<tool_call>`, `<function_call>`, `<invoke>`; supports attributes + nested fields),
+   - text KV fallback parser (`function.name: <name>` ... `function.arguments: {json}`).
+4. Stop at the first candidate that yields at least one call.
+
+## Name normalization policy
+
+When matching parsed names against configured tools:
+
+1. exact match,
+2. case-insensitive match,
+3. namespace tail match (`a.b.c` => `c`),
+4. loose alnum match (remove non `[a-z0-9]`, compare).
+
+## Standalone mode
+
+Standalone mode (`ParseStandaloneToolCallsDetailed`) parses the whole input directly (no candidate slicing), while still applying:
+
+- the example-context guard,
+- JSON then markup fallback,
+- the same allow-list normalization policy.
```
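The four-step name normalization policy in the spec above can be sketched as follows. This is an illustrative re-implementation of the stated rules, not the repo's actual `internal/util` code; the function name `normalizeToolName` is hypothetical:

```go
package main

import (
	"fmt"
	"strings"
)

// looseKey keeps only [a-z0-9] after lowercasing, per rule 4 of the spec.
func looseKey(s string) string {
	var b strings.Builder
	for _, r := range strings.ToLower(s) {
		if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') {
			b.WriteRune(r)
		}
	}
	return b.String()
}

// normalizeToolName applies the spec's matching rules in order:
// exact, case-insensitive, namespace tail (a.b.c => c), loose alnum.
// It returns the configured tool name and whether a match was found;
// unmatched names would feed rejectedToolNames under the allow-list policy.
func normalizeToolName(parsed string, allowed []string) (string, bool) {
	for _, a := range allowed { // 1. exact match
		if parsed == a {
			return a, true
		}
	}
	for _, a := range allowed { // 2. case-insensitive match
		if strings.EqualFold(parsed, a) {
			return a, true
		}
	}
	if i := strings.LastIndex(parsed, "."); i >= 0 { // 3. namespace tail match
		if name, ok := normalizeToolName(parsed[i+1:], allowed); ok {
			return name, true
		}
	}
	for _, a := range allowed { // 4. loose alnum match
		if looseKey(parsed) == looseKey(a) {
			return a, true
		}
	}
	return "", false
}

func main() {
	allowed := []string{"Bash", "eval_javascript"}
	fmt.Println(normalizeToolName("bash", allowed))             // Bash true
	fmt.Println(normalizeToolName("tools.shell.Bash", allowed)) // Bash true
	fmt.Println(normalizeToolName("EvalJavascript", allowed))   // eval_javascript true
	fmt.Println(normalizeToolName("Browser", allowed))          // rejected: no match
}
```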
````diff
@@ -315,3 +315,83 @@ func asString(v any) string {
 	s, _ := v.(string)
 	return s
 }
+
+func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.T) {
+	tests := []struct {
+		name    string
+		payload string
+	}{
+		{name: "xml_tool_call", payload: `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command></parameters></tool_call>`},
+		{name: "xml_json_tool_call", payload: `<tool_call>{"tool":"Bash","params":{"command":"pwd"}}</tool_call>`},
+		{name: "nested_tool_tag_style", payload: `<tool_call><tool name="Bash"><command>pwd</command></tool></tool_call>`},
+		{name: "function_tag_style", payload: `<function_call>Bash</function_call><function parameter name="command">pwd</function parameter>`},
+		{name: "antml_argument_style", payload: `<function_calls><function_call id="1" name="Bash"><argument name="command">pwd</argument></function_call></function_calls>`},
+		{name: "antml_function_attr_parameters", payload: `<function_calls><function_call id="1" function="Bash"><parameters>{"command":"pwd"}</parameters></function_call></function_calls>`},
+		{name: "invoke_parameter_style", payload: `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></function_calls>`},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			h := &Handler{}
+			resp := makeClaudeSSEHTTPResponse(
+				`data: {"p":"response/content","v":"`+strings.ReplaceAll(tc.payload, `"`, `\"`)+`"}`,
+				`data: [DONE]`,
+			)
+			rec := httptest.NewRecorder()
+			req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
+
+			h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"Bash"})
+
+			frames := parseClaudeFrames(t, rec.Body.String())
+			foundToolUse := false
+			for _, f := range findClaudeFrames(frames, "content_block_start") {
+				contentBlock, _ := f.Payload["content_block"].(map[string]any)
+				if contentBlock["type"] == "tool_use" {
+					foundToolUse = true
+					break
+				}
+			}
+			if !foundToolUse {
+				t.Fatalf("expected tool_use block for format %s, body=%s", tc.name, rec.Body.String())
+			}
+		})
+	}
+}
+
+func TestHandleClaudeStreamRealtimePromotesUnclosedFencedToolExample(t *testing.T) {
+	h := &Handler{}
+	resp := makeClaudeSSEHTTPResponse(
+		"data: {\"p\":\"response/content\",\"v\":\"Here is an example:\\n```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"Bash\\\",\\\"input\\\":{\\\"command\\\":\\\"pwd\\\"}}]}\"}",
+		"data: {\"p\":\"response/content\",\"v\":\"\\n```\\nDo not execute it.\"}",
+		`data: [DONE]`,
+	)
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
+
+	h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "show example only"}}, false, false, []string{"Bash"})
+
+	frames := parseClaudeFrames(t, rec.Body.String())
+	foundToolUse := false
+	for _, f := range findClaudeFrames(frames, "content_block_start") {
+		contentBlock, _ := f.Payload["content_block"].(map[string]any)
+		if contentBlock["type"] == "tool_use" {
+			foundToolUse = true
+			break
+		}
+	}
+	if !foundToolUse {
+		t.Fatalf("expected tool_use for fenced example, body=%s", rec.Body.String())
+	}
+
+	foundToolStop := false
+	for _, f := range findClaudeFrames(frames, "message_delta") {
+		delta, _ := f.Payload["delta"].(map[string]any)
+		if delta["stop_reason"] == "tool_use" {
+			foundToolStop = true
+			break
+		}
+	}
+	if !foundToolStop {
+		t.Fatalf("expected stop_reason=tool_use, body=%s", rec.Body.String())
+	}
+}
````
```diff
@@ -125,8 +125,11 @@ func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
 	if !containsStr(prompt, "Search the web") {
 		t.Fatalf("expected description in prompt")
 	}
-	if !containsStr(prompt, "tool_calls") {
-		t.Fatalf("expected tool_calls instruction in prompt")
+	if !containsStr(prompt, "tool_use") {
+		t.Fatalf("expected tool_use instruction in prompt")
+	}
+	if containsStr(prompt, "tool_calls") {
+		t.Fatalf("expected prompt to avoid tool_calls JSON instruction")
 	}
 }
```
```diff
@@ -51,7 +51,7 @@ func buildClaudeToolPrompt(tools []any) string {
 		parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
 	}
 	parts = append(parts,
-		"When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
+		"When you need a tool, respond with Claude-native tool use (tool_use) using the provided tool schema. Do not print tool-call JSON in text.",
 		"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
 		"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
 	)
```
````diff
@@ -8,6 +8,7 @@ import (
 
 	"ds2api/internal/sse"
 	streamengine "ds2api/internal/stream"
+	"ds2api/internal/util"
 )
 
 type claudeStreamRuntime struct {
@@ -116,6 +117,18 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
 
 			s.text.WriteString(p.Text)
 			if s.bufferToolContent {
+				if hasUnclosedCodeFence(s.text.String()) {
+					continue
+				}
+				detected := util.ParseToolCalls(s.text.String(), s.toolNames)
+				if len(detected) > 0 {
+					s.finalize("tool_use")
+					return streamengine.ParsedDecision{
+						ContentSeen: true,
+						Stop:        true,
+						StopReason:  streamengine.StopReason("tool_use_detected"),
+					}
+				}
 				continue
 			}
 			s.closeThinkingBlock()
@@ -144,3 +157,7 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
 	return streamengine.ParsedDecision{ContentSeen: contentSeen}
 }
+
+func hasUnclosedCodeFence(text string) bool {
+	return strings.Count(text, "```")%2 == 1
+}
````
```diff
@@ -99,7 +99,7 @@ func TestGeminiRoutesRegistered(t *testing.T) {
 
 func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
 	upstream := makeGeminiUpstreamResponse(
-		`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
+		`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
 		`data: [DONE]`,
 	)
 	h := &Handler{
@@ -143,6 +143,42 @@ func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
 	}
 }
 
+func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
+	upstream := makeGeminiUpstreamResponse(
+		`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
+		`data: [DONE]`,
+	)
+	h := &Handler{Store: testGeminiConfig{}, Auth: testGeminiAuth{}, DS: testGeminiDS{resp: upstream}}
+	r := chi.NewRouter()
+	RegisterRoutes(r, h)
+
+	body := `{
+		"contents":[{"role":"user","parts":[{"text":"call tool"}]}],
+		"tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
+	}`
+	req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
+	req.Header.Set("Authorization", "Bearer direct-token")
+	rec := httptest.NewRecorder()
+	r.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusOK {
+		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
+	}
+	var out map[string]any
+	if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
+		t.Fatalf("decode response failed: %v", err)
+	}
+	candidates, _ := out["candidates"].([]any)
+	c0, _ := candidates[0].(map[string]any)
+	content, _ := c0["content"].(map[string]any)
+	parts, _ := content["parts"].([]any)
+	part0, _ := parts[0].(map[string]any)
+	functionCall, _ := part0["functionCall"].(map[string]any)
+	if functionCall["name"] != "eval_javascript" {
+		t.Fatalf("expected functionCall name eval_javascript for mixed snippet, got %#v", functionCall)
+	}
+}
+
 func TestStreamGenerateContentEmitsSSE(t *testing.T) {
 	upstream := makeGeminiUpstreamResponse(
 		`data: {"p":"response/content","v":"hello "}`,
```
```diff
@@ -98,11 +98,11 @@ func (s *chatStreamRuntime) sendDone() {
 func (s *chatStreamRuntime) finalize(finishReason string) {
 	finalThinking := s.thinking.String()
 	finalText := s.text.String()
-	detected := util.ParseToolCalls(finalText, s.toolNames)
-	if len(detected) > 0 && !s.toolCallsDoneEmitted {
+	detected := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
+	if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
 		finishReason = "tool_calls"
 		delta := map[string]any{
-			"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs),
+			"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs),
 		}
 		if !s.firstChunkSent {
 			delta["role"] = "assistant"
@@ -158,7 +158,7 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
 		}
 	}
 
-	if len(detected) > 0 || s.toolCallsEmitted {
+	if len(detected.Calls) > 0 || s.toolCallsEmitted {
 		finishReason = "tool_calls"
 	}
 	s.sendChunk(openaifmt.BuildChatStreamChunk(
```
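The `finalize` change above switches from a bare call slice to a detailed parse result and derives `finish_reason` from it. A sketch of that decision, using a hypothetical `DetailedResult` struct whose fields mirror the output contract documented in docs/toolcall-semantics.md (the Go field and function names here are illustrative, not the repo's actual types):

```go
package main

import "fmt"

// ToolCall is a minimal stand-in for an accepted call (hypothetical shape).
type ToolCall struct {
	Name  string
	Input map[string]any
}

// DetailedResult mirrors the documented output contract:
// calls, sawToolCallSyntax, rejectedByPolicy, rejectedToolNames.
type DetailedResult struct {
	Calls             []ToolCall
	SawToolCallSyntax bool
	RejectedByPolicy  bool
	RejectedToolNames []string
}

// pickFinishReason captures the finalize rule from the diff: any accepted
// call (or a tool call already emitted earlier in the stream) forces
// finish_reason = "tool_calls".
func pickFinishReason(detected DetailedResult, toolCallsEmitted bool) string {
	if len(detected.Calls) > 0 || toolCallsEmitted {
		return "tool_calls"
	}
	return "stop"
}

func main() {
	res := DetailedResult{
		Calls:             []ToolCall{{Name: "Bash", Input: map[string]any{"command": "pwd"}}},
		SawToolCallSyntax: true,
	}
	fmt.Println(pickFinishReason(res, false))             // tool_calls
	fmt.Println(pickFinishReason(DetailedResult{}, false)) // stop
}
```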
@@ -19,6 +19,7 @@ type DeepSeekCaller interface {
 	CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
 	GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
 	CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
+	DeleteAllSessionsForToken(ctx context.Context, token string) error
 }

 type ConfigReader interface {
@@ -28,6 +29,7 @@ type ConfigReader interface {
 	ToolcallEarlyEmitConfidence() string
 	ResponsesStoreTTLSeconds() int
 	EmbeddingsProvider() string
+	AutoDeleteSessions() bool
 }

 var _ AuthResolver = (*auth.Resolver)(nil)
@@ -19,6 +19,7 @@ func (m mockOpenAIConfig) ToolcallMode() string { return m.toolMo
 func (m mockOpenAIConfig) ToolcallEarlyEmitConfidence() string { return m.earlyEmit }
 func (m mockOpenAIConfig) ResponsesStoreTTLSeconds() int { return m.responsesTTL }
 func (m mockOpenAIConfig) EmbeddingsProvider() string { return m.embedProv }
+func (m mockOpenAIConfig) AutoDeleteSessions() bool { return false }

 func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
 	cfg := mockOpenAIConfig{
@@ -35,7 +35,25 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
 		writeOpenAIError(w, status, detail)
 		return
 	}
-	defer h.Auth.Release(a)
+	defer func() {
+		// Auto-delete sessions (synchronous).
+		// The delete must complete synchronously before Release, otherwise:
+		// 1. during an async delete the account would already be released
+		// 2. a new request could acquire the same account and start using it
+		// 3. the still-running async delete would truncate sessions the new request is using
+		if h.Store.AutoDeleteSessions() && a.DeepSeekToken != "" {
+			deleteCtx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
+			defer cancel()
+			err := h.DS.DeleteAllSessionsForToken(deleteCtx, a.DeepSeekToken)
+			if err != nil {
+				config.Logger.Warn("[auto_delete_sessions] failed", "account", a.AccountID, "error", err)
+			} else {
+				config.Logger.Debug("[auto_delete_sessions] success", "account", a.AccountID)
+			}
+		}
+		h.Auth.Release(a)
+	}()

 	r = r.WithContext(auth.WithAuth(r.Context(), a))

 	var req map[string]any
@@ -53,7 +53,7 @@ func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolCh
 	if len(toolSchemas) == 0 {
 		return messages, names
 	}
-	toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
+	toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY a JSON code block like this:\n```json\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n```\n\n【EXAMPLE】\nUser: Please check the weather in Beijing and Shanghai, and update my todo list.\nAssistant:\n```json\n{\"tool_calls\": [\n  {\"name\": \"get_weather\", \"input\": {\"city\": \"Beijing\"}},\n  {\"name\": \"get_weather\", \"input\": {\"city\": \"Shanghai\"}},\n  {\"name\": \"update_todo\", \"input\": {\"todos\": [{\"content\": \"Buy milk\"}, {\"content\": \"Write report\"}]}}\n]}\n```\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON code block. The response must start with ```json and end with ```.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block.\n5) JSON SYNTAX STRICTLY REQUIRED: All property names MUST be enclosed in double quotes (e.g., \"name\", not name).\n6) ARRAY FORMAT: If providing a list of items, you MUST enclose them in square brackets `[]` (e.g., \"todos\": [{\"item\": \"a\"}, {\"item\": \"b\"}]). DO NOT output comma-separated objects without brackets."
 	if policy.Mode == util.ToolChoiceRequired {
 		toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list."
 	}
@@ -3,6 +3,7 @@ package openai
 import (
 	"context"
 	"encoding/json"
+	"fmt"
 	"io"
 	"net/http"
 	"net/http/httptest"
@@ -210,7 +211,7 @@ func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
 	}
 }

-func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
+func TestHandleNonStreamEmbeddedToolCallExamplePromotesToolCall(t *testing.T) {
 	h := &Handler{}
 	resp := makeSSEHTTPResponse(
 		`data: {"p":"response/content","v":"下面是示例:"}`,
@@ -233,15 +234,16 @@ func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
 	}
 	msg, _ := choice["message"].(map[string]any)
 	toolCalls, _ := msg["tool_calls"].([]any)
-	if len(toolCalls) == 0 {
-		t.Fatalf("expected tool_calls field for embedded example: %#v", msg["tool_calls"])
+	if len(toolCalls) != 1 {
+		t.Fatalf("expected one tool_call field for embedded example: %#v", msg["tool_calls"])
 	}
-	if msg["content"] != nil {
-		t.Fatalf("expected content nil when tool_calls detected, got %#v", msg["content"])
+	content, _ := msg["content"].(string)
+	if strings.Contains(content, `"tool_calls"`) {
+		t.Fatalf("expected raw tool_calls json stripped from content, got %#v", content)
 	}
 }

-func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) {
+func TestHandleNonStreamFencedToolCallExamplePromotesToolCall(t *testing.T) {
 	h := &Handler{}
 	resp := makeSSEHTTPResponse(
 		"data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}",
@@ -257,16 +259,17 @@ func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) {
 	out := decodeJSONBody(t, rec.Body.String())
 	choices, _ := out["choices"].([]any)
 	choice, _ := choices[0].(map[string]any)
-	if choice["finish_reason"] != "stop" {
-		t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
+	if choice["finish_reason"] != "tool_calls" {
+		t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
 	}
 	msg, _ := choice["message"].(map[string]any)
-	if _, ok := msg["tool_calls"]; ok {
-		t.Fatalf("did not expect tool_calls field for fenced example: %#v", msg["tool_calls"])
+	toolCalls, _ := msg["tool_calls"].([]any)
+	if len(toolCalls) != 1 {
+		t.Fatalf("expected one tool_call field for fenced example: %#v", msg["tool_calls"])
 	}
 	content, _ := msg["content"].(string)
-	if !strings.Contains(content, "```json") || !strings.Contains(content, `"tool_calls"`) {
-		t.Fatalf("expected fenced tool example to pass through as text, got %q", content)
+	if strings.Contains(content, `"tool_calls"`) {
+		t.Fatalf("expected raw tool_calls json stripped from content, got %q", content)
 	}
 }

@@ -315,6 +318,36 @@ func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
 	}
 }

+func TestHandleStreamToolCallLargeArgumentsStillIntercepted(t *testing.T) {
+	h := &Handler{}
+	large := strings.Repeat("a", 9000)
+	payload := fmt.Sprintf(`{"tool_calls":[{"name":"search","input":{"q":"%s"}}]}`, large)
+	splitAt := len(payload) / 2
+	resp := makeSSEHTTPResponse(
+		fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[:splitAt]),
+		fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[splitAt:]),
+		`data: [DONE]`,
+	)
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	h.handleStream(rec, req, resp, "cid3-large", "deepseek-chat", "prompt", false, false, []string{"search"})
+
+	frames, done := parseSSEDataFrames(t, rec.Body.String())
+	if !done {
+		t.Fatalf("expected [DONE], body=%s", rec.Body.String())
+	}
+	if !streamHasToolCallsDelta(frames) {
+		t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
+	}
+	if streamHasRawToolJSONContent(frames) {
+		t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
+	}
+	if streamFinishReason(frames) != "tool_calls" {
+		t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
+	}
+}
+
 func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
 	h := &Handler{}
 	resp := makeSSEHTTPResponse(
@@ -500,15 +533,12 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
 	if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
 		t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
 	}
-	if strings.Contains(strings.ToLower(got), `"tool_calls"`) {
-		t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got)
-	}
 	if streamFinishReason(frames) != "tool_calls" {
 		t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
 	}
 }

-func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
+func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
 	h := &Handler{}
 	resp := makeSSEHTTPResponse(
 		`data: {"p":"response/content","v":"我将调用工具。"}`,
@@ -542,15 +572,13 @@ func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testin
 	if !strings.Contains(got, "我将调用工具。") {
 		t.Fatalf("expected leading text to keep streaming, got=%q", got)
 	}
-	if strings.Contains(strings.ToLower(got), "tool_calls") {
-		t.Fatalf("unexpected raw tool json leak, got=%q", got)
-	}
 	if streamFinishReason(frames) != "tool_calls" {
 		t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
 	}
 }

-func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testing.T) {
+func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T) {
 	h := &Handler{}
 	resp := makeSSEHTTPResponse(
 		`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
@@ -583,15 +611,52 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testin
 	if !strings.Contains(got, "接下来我会继续说明。") {
 		t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
 	}

+	if streamFinishReason(frames) != "tool_calls" {
+		t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
+	}
+}
+
+func TestHandleStreamFencedToolCallSnippetPromotesToolCall(t *testing.T) {
+	h := &Handler{}
+	resp := makeSSEHTTPResponse(
+		fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "下面是调用示例:\n```json\n"),
+		fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\n仅示例,不要执行。"),
+		`data: [DONE]`,
+	)
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	h.handleStream(rec, req, resp, "cid7f", "deepseek-chat", "prompt", false, false, []string{"search"})
+
+	frames, done := parseSSEDataFrames(t, rec.Body.String())
+	if !done {
+		t.Fatalf("expected [DONE], body=%s", rec.Body.String())
+	}
+	if !streamHasToolCallsDelta(frames) {
+		t.Fatalf("expected tool_calls delta for fenced snippet, body=%s", rec.Body.String())
+	}
+	content := strings.Builder{}
+	for _, frame := range frames {
+		choices, _ := frame["choices"].([]any)
+		for _, item := range choices {
+			choice, _ := item.(map[string]any)
+			delta, _ := choice["delta"].(map[string]any)
+			if c, ok := delta["content"].(string); ok {
+				content.WriteString(c)
+			}
+		}
+	}
+	got := content.String()
 	if strings.Contains(strings.ToLower(got), "tool_calls") {
-		t.Fatalf("unexpected raw tool json leak, got=%q", got)
+		t.Fatalf("expected raw fenced tool_calls snippet stripped from content, got=%q", got)
 	}
 	if streamFinishReason(frames) != "tool_calls" {
 		t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
 	}
 }

-func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
+func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
 	h := &Handler{}
 	spaces := strings.Repeat(" ", 200)
 	resp := makeSSEHTTPResponse(
@@ -612,9 +677,6 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
 	if !streamHasToolCallsDelta(frames) {
 		t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
 	}
-	if streamHasRawToolJSONContent(frames) {
-		t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
-	}
 	content := strings.Builder{}
 	for _, frame := range frames {
 		choices, _ := frame["choices"].([]any)
@@ -627,9 +689,6 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
 		}
 	}
 	got := content.String()
-	if strings.Contains(got, "{") {
-		t.Fatalf("unexpected suspicious prefix leak in content: %q", got)
-	}
 	if !strings.Contains(got, "后置正文C。") {
 		t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
 	}
@@ -712,7 +771,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
 	}
 }

-func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
+func TestHandleStreamToolCallArgumentsEmitAsSingleCompletedChunk(t *testing.T) {
 	h := &Handler{}
 	resp := makeSSEHTTPResponse(
 		`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
|
|||||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||||
}
|
}
|
||||||
argChunks := streamToolCallArgumentChunks(frames)
|
argChunks := streamToolCallArgumentChunks(frames)
|
||||||
if len(argChunks) < 2 {
|
if len(argChunks) == 0 {
|
||||||
t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String())
|
t.Fatalf("expected tool call arguments chunk, got=%v body=%s", argChunks, rec.Body.String())
|
||||||
}
|
}
|
||||||
joined := strings.Join(argChunks, "")
|
joined := strings.Join(argChunks, "")
|
||||||
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
|
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ package openai
 import (
 	"encoding/json"
 	"fmt"
-	"io"
 	"strings"

 	"ds2api/internal/config"
+	"ds2api/internal/prompt"
 )

 func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
@@ -34,9 +34,9 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
 			"role": "user",
 			"content": formatToolResultForPrompt(msg),
 		})
-	case "user", "system":
+	case "user", "system", "developer":
 		out = append(out, map[string]any{
-			"role": role,
+			"role": normalizeOpenAIRoleForPrompt(role),
 			"content": normalizeOpenAIContentForPrompt(msg["content"]),
 		})
 	default:
@@ -48,7 +48,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
 			role = "user"
 		}
 		out = append(out, map[string]any{
-			"role": role,
+			"role": normalizeOpenAIRoleForPrompt(role),
 			"content": content,
 		})
 	}
@@ -78,7 +78,7 @@ func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) strin
 			args = normalizeOpenAIArgumentsForPrompt(fn["arguments"])
 		}
 		if name == "" {
-			name = "unknown"
+			continue
 		}
 		if args == "" {
 			args = normalizeOpenAIArgumentsForPrompt(call["arguments"])
@@ -133,32 +133,7 @@ func formatToolResultForPrompt(msg map[string]any) string {
 }

 func normalizeOpenAIContentForPrompt(v any) string {
-	switch x := v.(type) {
-	case string:
-		return x
-	case []any:
-		parts := make([]string, 0, len(x))
-		for _, item := range x {
-			m, ok := item.(map[string]any)
-			if !ok {
-				continue
-			}
-			t := strings.ToLower(strings.TrimSpace(asString(m["type"])))
-			if t != "text" && t != "output_text" && t != "input_text" {
-				continue
-			}
-			if text := asString(m["text"]); text != "" {
-				parts = append(parts, text)
-				continue
-			}
-			if text := asString(m["content"]); text != "" {
-				parts = append(parts, text)
-			}
-		}
-		return strings.Join(parts, "\n")
-	default:
-		return marshalToPromptString(v)
-	}
+	return prompt.NormalizeContent(v)
 }

 func normalizeOpenAIArgumentsForPrompt(v any) string {
@@ -175,30 +150,11 @@ func normalizeToolArgumentString(raw string) string {
 	if trimmed == "" {
 		return ""
 	}
-	if !looksLikeConcatenatedJSON(trimmed) {
-		return trimmed
+	if looksLikeConcatenatedJSON(trimmed) {
+		// Keep original payload to avoid silent argument rewrites.
+		return raw
 	}
-	dec := json.NewDecoder(strings.NewReader(trimmed))
-	values := make([]any, 0, 2)
-	for {
-		var v any
-		if err := dec.Decode(&v); err != nil {
-			if err == io.EOF {
-				break
-			}
-			return trimmed
-		}
-		values = append(values, v)
-	}
-	if len(values) < 2 {
-		return trimmed
-	}
-	last := values[len(values)-1]
-	b, err := json.Marshal(last)
-	if err != nil || len(b) == 0 {
-		return trimmed
-	}
-	return string(b)
+	return trimmed
 }

 func marshalToPromptString(v any) string {
@@ -209,6 +165,14 @@ func marshalToPromptString(v any) string {
 	return string(b)
 }

+func normalizeOpenAIRoleForPrompt(role string) string {
+	role = strings.ToLower(strings.TrimSpace(role))
+	if role == "developer" {
+		return "system"
+	}
+	return role
+}
+
 func asString(v any) string {
 	if s, ok := v.(string); ok {
 		return s
@@ -168,7 +168,7 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara
 	}
 }

-func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *testing.T) {
+func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t *testing.T) {
 	raw := []any{
 		map[string]any{
 			"role": "assistant",
@@ -189,10 +189,94 @@ func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *te
 		t.Fatalf("expected one normalized message, got %d", len(normalized))
 	}
 	content, _ := normalized[0]["content"].(string)
-	if !strings.Contains(content, `function.arguments: {"query":"测试工具调用"}`) {
-		t.Fatalf("expected repaired arguments in tool history, got %q", content)
+	if !strings.Contains(content, `function.arguments: {}{"query":"测试工具调用"}`) {
+		t.Fatalf("expected original concatenated arguments in tool history, got %q", content)
 	}
-	if strings.Contains(content, `{}{"query":"测试工具调用"}`) {
-		t.Fatalf("expected concatenated JSON to be repaired, got %q", content)
+}
+
+func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsMissingNameAreDropped(t *testing.T) {
+	raw := []any{
+		map[string]any{
+			"role": "assistant",
+			"tool_calls": []any{
+				map[string]any{
+					"id": "call_missing_name",
+					"type": "function",
+					"function": map[string]any{
+						"arguments": `{"path":"README.MD"}`,
+					},
+				},
+			},
+		},
+	}
+
+	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
+	if len(normalized) != 0 {
+		t.Fatalf("expected nameless assistant tool_calls to be dropped, got %#v", normalized)
+	}
+}
+
+func TestNormalizeOpenAIMessagesForPrompt_AssistantNilContentDoesNotInjectNullLiteral(t *testing.T) {
+	raw := []any{
+		map[string]any{
+			"role": "assistant",
+			"content": nil,
+			"tool_calls": []any{
+				map[string]any{
+					"id": "call_screenshot",
+					"function": map[string]any{
+						"name": "send_file_to_user",
+						"arguments": `{"file_path":"/tmp/a.png"}`,
+					},
+				},
+			},
+		},
+	}
+
+	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
+	if len(normalized) != 1 {
+		t.Fatalf("expected one normalized message, got %d", len(normalized))
+	}
+	content, _ := normalized[0]["content"].(string)
+	if strings.Contains(content, "<|Assistant|>null") || strings.HasPrefix(strings.TrimSpace(content), "null") {
+		t.Fatalf("unexpected null literal injected into assistant tool history: %q", content)
+	}
+	if !strings.Contains(content, "function.name: send_file_to_user") {
+		t.Fatalf("expected tool history block preserved, got %q", content)
+	}
+}
+
+func TestNormalizeOpenAIMessagesForPrompt_DeveloperRoleMapsToSystem(t *testing.T) {
+	raw := []any{
+		map[string]any{"role": "developer", "content": "必须先走工具调用"},
+		map[string]any{"role": "user", "content": "你好"},
+	}
+	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
+	if len(normalized) != 2 {
+		t.Fatalf("expected 2 normalized messages, got %d", len(normalized))
+	}
+	if normalized[0]["role"] != "system" {
+		t.Fatalf("expected developer role converted to system, got %#v", normalized[0]["role"])
+	}
+}
+
+func TestNormalizeOpenAIMessagesForPrompt_AssistantArrayContentFallbackWhenTextEmpty(t *testing.T) {
+	raw := []any{
+		map[string]any{
+			"role": "assistant",
+			"content": []any{
+				map[string]any{"type": "text", "text": "", "content": "工具说明文本"},
+			},
+		},
+	}
+
+	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
+	if len(normalized) != 1 {
+		t.Fatalf("expected one normalized message, got %d", len(normalized))
+	}
+	content, _ := normalized[0]["content"].(string)
+	if content != "工具说明文本" {
+		t.Fatalf("expected content fallback text preserved, got %q", content)
 	}
 }
@@ -135,7 +135,7 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
 	}
 }

-func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArguments(t *testing.T) {
+func TestNormalizeResponsesInputAsMessagesFunctionCallItemPreservesConcatenatedArguments(t *testing.T) {
 	msgs := normalizeResponsesInputAsMessages([]any{
 		map[string]any{
 			"type": "function_call",
@@ -151,8 +151,8 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArg
 	toolCalls, _ := m["tool_calls"].([]any)
 	call, _ := toolCalls[0].(map[string]any)
 	fn, _ := call["function"].(map[string]any)
-	if fn["arguments"] != `{"q":"golang"}` {
-		t.Fatalf("expected concatenated call arguments repaired, got %#v", fn["arguments"])
+	if fn["arguments"] != `{}{"q":"golang"}` {
+		t.Fatalf("expected original concatenated call arguments preserved, got %#v", fn["arguments"])
 	}
 }

@@ -113,15 +113,10 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||||
textParsed := util.ParseToolCallsDetailed(result.Text, toolNames)
|
textParsed := util.ParseStandaloneToolCallsDetailed(result.Text, toolNames)
|
||||||
thinkingParsed := util.ParseToolCallsDetailed(result.Thinking, toolNames)
|
|
||||||
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
||||||
logResponsesToolPolicyRejection(traceID, toolChoice, thinkingParsed, "thinking")
|
|
||||||
|
|
||||||
callCount := len(textParsed.Calls)
|
callCount := len(textParsed.Calls)
|
||||||
if callCount == 0 {
|
|
||||||
callCount = len(thinkingParsed.Calls)
|
|
||||||
}
|
|
||||||
if toolChoice.IsRequired() && callCount == 0 {
|
if toolChoice.IsRequired() && callCount == 0 {
|
||||||
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
|
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
|
||||||
return
|
return
|
||||||
|
|||||||
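The non-stream handler counts only tool calls parsed from the visible text channel before enforcing `tool_choice`, answering 422 when a required policy is violated. A minimal standalone sketch of that required-policy check (`toolChoicePolicy` and `validateToolCalls` are illustrative names, not the project's API):

```go
package main

import "fmt"

// toolChoicePolicy mirrors the idea of a tool_choice policy mode:
// "auto", "none", or "required".
type toolChoicePolicy struct {
	mode string
}

func (p toolChoicePolicy) isRequired() bool { return p.mode == "required" }

// validateToolCalls returns an HTTP-style status: 422 when the policy
// requires at least one tool call but none were parsed from the text channel.
func validateToolCalls(p toolChoicePolicy, callCount int) int {
	if p.isRequired() && callCount == 0 {
		return 422
	}
	return 200
}

func main() {
	fmt.Println(validateToolCalls(toolChoicePolicy{mode: "required"}, 0)) // 422
	fmt.Println(validateToolCalls(toolChoicePolicy{mode: "required"}, 1)) // 200
	fmt.Println(validateToolCalls(toolChoicePolicy{mode: "auto"}, 0))     // 200
}
```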
```diff
@@ -29,7 +29,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
 			return nil
 		}
 		return map[string]any{
-			"role":    role,
+			"role":    normalizeOpenAIRoleForPrompt(role),
 			"content": content,
 		}
 	}
@@ -51,7 +51,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
 			role = "user"
 		}
 		return map[string]any{
-			"role":    role,
+			"role":    normalizeOpenAIRoleForPrompt(role),
 			"content": content,
 		}
 	case "function_call_output", "tool_result":
```
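Both normalization branches now pass the role through `normalizeOpenAIRoleForPrompt` before building the prompt message. Its mapping is not shown in this diff; the following is only a guess at the shape of such a helper, with a purely illustrative alias table:

```go
package main

import "fmt"

// normalizeRole is a hypothetical stand-in for normalizeOpenAIRoleForPrompt:
// map role aliases onto the small set a prompt builder understands.
// The actual mapping in ds2api is NOT shown in this diff.
func normalizeRole(role string) string {
	switch role {
	case "developer":
		return "system" // illustrative alias, not confirmed by the source
	case "system", "user", "assistant", "tool":
		return role
	default:
		return "user"
	}
}

func main() {
	fmt.Println(normalizeRole("developer")) // system
	fmt.Println(normalizeRole("unknown"))   // user
}
```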
```diff
@@ -102,16 +102,11 @@ func (s *responsesStreamRuntime) finalize() {
 
 	if s.bufferToolContent {
 		s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
-		s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
 	}
 
-	textParsed := util.ParseToolCallsDetailed(finalText, s.toolNames)
-	thinkingParsed := util.ParseToolCallsDetailed(finalThinking, s.toolNames)
+	textParsed := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
 	detected := textParsed.Calls
-	if len(detected) == 0 {
-		detected = thinkingParsed.Calls
-	}
-	s.logToolPolicyRejections(textParsed, thinkingParsed)
+	s.logToolPolicyRejections(textParsed)
 
 	if len(detected) > 0 {
 		s.toolCallsEmitted = true
@@ -157,7 +152,7 @@ func (s *responsesStreamRuntime) finalize() {
 	s.sendDone()
 }
 
-func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) {
+func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed util.ToolCallParseResult) {
 	logRejected := func(parsed util.ToolCallParseResult, channel string) {
 		rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)
 		if !parsed.RejectedByPolicy || len(rejected) == 0 {
@@ -172,7 +167,6 @@ func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingPar
 		)
 	}
 	logRejected(textParsed, "text")
-	logRejected(thinkingParsed, "thinking")
 }
 
 func (s *responsesStreamRuntime) hasFunctionCallDone() bool {
@@ -207,9 +201,6 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
 			}
 			s.thinking.WriteString(p.Text)
 			s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
-			if s.bufferToolContent {
-				s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
-			}
 			continue
 		}
 
```
```diff
@@ -94,6 +94,16 @@ func (s *responsesStreamRuntime) closeMessageItem() {
 	outputIndex := s.ensureMessageOutputIndex()
 	text := s.visibleText.String()
 	if s.messagePartAdded {
+		s.sendEvent(
+			"response.output_text.done",
+			openaifmt.BuildResponsesTextDonePayload(
+				s.responseID,
+				itemID,
+				outputIndex,
+				0,
+				text,
+			),
+		)
 		s.sendEvent(
 			"response.content_part.done",
 			openaifmt.BuildResponsesContentPartDonePayload(
```
```diff
@@ -99,9 +99,6 @@ func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
 	if !strings.Contains(body, "event: response.output_item.done") {
 		t.Fatalf("expected response.output_item.done event, body=%s", body)
 	}
-	if !strings.Contains(body, "event: response.function_call_arguments.delta") {
-		t.Fatalf("expected response.function_call_arguments.delta event, body=%s", body)
-	}
 	if !strings.Contains(body, "event: response.function_call_arguments.done") {
 		t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
 	}
@@ -229,6 +226,40 @@ func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.
 	}
 }
 
+func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testing.T) {
+	h := &Handler{}
+	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+	rec := httptest.NewRecorder()
+
+	sseLine := func(v string) string {
+		b, _ := json.Marshal(map[string]any{
+			"p": "response/content",
+			"v": v,
+		})
+		return "data: " + string(b) + "\n"
+	}
+
+	streamBody := sseLine("hello") + "data: [DONE]\n"
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Body:       io.NopCloser(strings.NewReader(streamBody)),
+	}
+
+	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
+	body := rec.Body.String()
+	if !strings.Contains(body, "event: response.output_text.done") {
+		t.Fatalf("expected response.output_text.done payload, body=%s", body)
+	}
+	textDoneIdx := strings.Index(body, "event: response.output_text.done")
+	partDoneIdx := strings.Index(body, "event: response.content_part.done")
+	if textDoneIdx < 0 || partDoneIdx < 0 {
+		t.Fatalf("expected output_text.done + content_part.done, body=%s", body)
+	}
+	if textDoneIdx > partDoneIdx {
+		t.Fatalf("expected output_text.done before content_part.done, body=%s", body)
+	}
+}
+
 func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
 	h := &Handler{}
 	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
@@ -266,7 +297,7 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
 	}
 }
 
-func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *testing.T) {
+func TestHandleResponsesStreamThinkingAndMixedToolExampleEmitsFunctionCall(t *testing.T) {
 	h := &Handler{}
 	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
 	rec := httptest.NewRecorder()
@@ -291,23 +322,8 @@ func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *tes
 	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
 
 	addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
-	if len(addedPayloads) < 2 {
-		t.Fatalf("expected message + function_call output_item.added events, got %d body=%s", len(addedPayloads), rec.Body.String())
-	}
-
-	indexes := map[int]struct{}{}
-	typeByIndex := map[int]string{}
-	addedIDs := map[string]string{}
-	for _, payload := range addedPayloads {
-		item, _ := payload["item"].(map[string]any)
-		itemType := strings.TrimSpace(asString(item["type"]))
-		outputIndex := int(asFloat(payload["output_index"]))
-		if _, exists := indexes[outputIndex]; exists {
-			t.Fatalf("found duplicated output_index=%d for item types=%q and %q payload=%#v", outputIndex, typeByIndex[outputIndex], itemType, payload)
-		}
-		indexes[outputIndex] = struct{}{}
-		typeByIndex[outputIndex] = itemType
-		addedIDs[itemType] = strings.TrimSpace(asString(payload["item_id"]))
+	if len(addedPayloads) < 1 {
+		t.Fatalf("expected at least one output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
 	}
 
 	completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
@@ -316,20 +332,25 @@ func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *tes
 	}
 	responseObj, _ := completedPayload["response"].(map[string]any)
 	output, _ := responseObj["output"].([]any)
-	found := map[string]bool{}
+	hasMessage := false
+	hasFunctionCall := false
 	for _, item := range output {
 		m, _ := item.(map[string]any)
-		itemType := strings.TrimSpace(asString(m["type"]))
-		itemID := strings.TrimSpace(asString(m["id"]))
-		if itemType == "" || itemID == "" {
+		if m == nil {
 			continue
 		}
-		if wantID := strings.TrimSpace(addedIDs[itemType]); wantID != "" && wantID == itemID {
-			found[itemType] = true
+		if asString(m["type"]) == "message" {
+			hasMessage = true
+		}
+		if asString(m["type"]) == "function_call" {
+			hasFunctionCall = true
 		}
 	}
-	if !found["message"] || !found["function_call"] {
-		t.Fatalf("expected completed output to contain streamed message/function_call item ids, found=%#v output=%#v", found, output)
+	if !hasMessage {
+		t.Fatalf("expected message output for mixed prose tool example, output=%#v", output)
+	}
+	if !hasFunctionCall {
+		t.Fatalf("expected function_call output for mixed prose tool example, output=%#v", output)
 	}
 }
 
@@ -360,7 +381,7 @@ func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
 	}
 }
 
-func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *testing.T) {
+func TestHandleResponsesStreamMalformedToolJSONFallsBackToText(t *testing.T) {
 	h := &Handler{}
 	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
 	rec := httptest.NewRecorder()
@@ -373,7 +394,7 @@ func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *t
 		return "data: " + string(b) + "\n"
 	}
 
-	// invalid JSON (NaN) can still trigger incremental tool deltas before final parse rejects it
+	// invalid JSON (NaN) should remain plain text in strict mode.
 	streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
 	resp := &http.Response{
 		StatusCode: http.StatusOK,
@@ -382,14 +403,11 @@ func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *t
 
 	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
 	body := rec.Body.String()
-	if !strings.Contains(body, "event: response.function_call_arguments.delta") {
-		t.Fatalf("expected response.function_call_arguments.delta event for malformed payload, body=%s", body)
+	if strings.Contains(body, "event: response.function_call_arguments.delta") || strings.Contains(body, "event: response.function_call_arguments.done") {
+		t.Fatalf("did not expect function_call events for malformed payload in strict mode, body=%s", body)
 	}
-	if !strings.Contains(body, "event: response.function_call_arguments.done") {
-		t.Fatalf("expected runtime to close in-progress function_call with done event, body=%s", body)
-	}
-	if !strings.Contains(body, "event: response.output_item.done") {
-		t.Fatalf("expected runtime to close function output item, body=%s", body)
+	if !strings.Contains(body, "event: response.output_text.delta") {
+		t.Fatalf("expected response.output_text.delta for malformed payload, body=%s", body)
 	}
 	if !strings.Contains(body, "event: response.completed") {
 		t.Fatalf("expected response.completed event, body=%s", body)
@@ -430,6 +448,42 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
 	}
 }
 
+func TestHandleResponsesStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
+	h := &Handler{}
+	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+	rec := httptest.NewRecorder()
+
+	sseLine := func(path, value string) string {
+		b, _ := json.Marshal(map[string]any{
+			"p": path,
+			"v": value,
+		})
+		return "data: " + string(b) + "\n"
+	}
+
+	streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
+		sseLine("response/content", "plain text only") +
+		"data: [DONE]\n"
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Body:       io.NopCloser(strings.NewReader(streamBody)),
+	}
+
+	policy := util.ToolChoicePolicy{
+		Mode:    util.ToolChoiceRequired,
+		Allowed: map[string]struct{}{"read_file": {}},
+	}
+
+	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "")
+	body := rec.Body.String()
+	if !strings.Contains(body, "event: response.failed") {
+		t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
+	}
+	if strings.Contains(body, "event: response.completed") {
+		t.Fatalf("did not expect response.completed after failure, body=%s", body)
+	}
+}
+
 func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
 	h := &Handler{}
 	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
@@ -516,6 +570,33 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
 	}
 }
 
+func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
+	h := &Handler{}
+	rec := httptest.NewRecorder()
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Body: io.NopCloser(strings.NewReader(
+			`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" +
+				`data: {"p":"response/content","v":"plain text only"}` + "\n" +
+				`data: [DONE]` + "\n",
+		)),
+	}
+	policy := util.ToolChoicePolicy{
+		Mode:    util.ToolChoiceRequired,
+		Allowed: map[string]struct{}{"read_file": {}},
+	}
+
+	h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, []string{"read_file"}, policy, "")
+	if rec.Code != http.StatusUnprocessableEntity {
+		t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
+	}
+	out := decodeJSONBody(t, rec.Body.String())
+	errObj, _ := out["error"].(map[string]any)
+	if asString(errObj["code"]) != "tool_choice_violation" {
+		t.Fatalf("expected code=tool_choice_violation, got %#v", out)
+	}
+}
+
 func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
 	h := &Handler{}
 	rec := httptest.NewRecorder()
```
```diff
@@ -53,6 +53,10 @@ func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAut
 	return m.resp, nil
 }
 
+func (m streamStatusDSStub) DeleteAllSessionsForToken(_ context.Context, _ string) error {
+	return nil
+}
+
 func makeOpenAISSEHTTPResponse(lines ...string) *http.Response {
 	body := strings.Join(lines, "\n")
 	if !strings.HasSuffix(body, "\n") {
@@ -168,18 +172,14 @@ func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
 	}
 	outputText, _ := out["output_text"].(string)
 	if outputText != "" {
-		t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText)
+		t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
 	}
 	output, _ := out["output"].([]any)
-	hasFunctionCall := false
-	for _, item := range output {
-		m, _ := item.(map[string]any)
-		if m != nil && m["type"] == "function_call" {
-			hasFunctionCall = true
-			break
-		}
+	if len(output) != 1 {
+		t.Fatalf("expected one output item, got %#v", output)
 	}
-	if !hasFunctionCall {
+	first, _ := output[0].(map[string]any)
+	if first["type"] != "function_call" {
 		t.Fatalf("expected function_call output item, got %#v", output)
 	}
 }
```
```diff
@@ -14,6 +14,11 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
 		state.pending.WriteString(chunk)
 	}
 	events := make([]toolStreamEvent, 0, 2)
+	if len(state.pendingToolCalls) > 0 {
+		events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
+		state.pendingToolRaw = ""
+		state.pendingToolCalls = nil
+	}
 
 	for {
 		if state.capturing {
@@ -21,32 +26,30 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
 			state.capture.WriteString(state.pending.String())
 			state.pending.Reset()
 		}
-		if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
-			events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
-		}
 		prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
 		if !ready {
-			if state.capture.Len() > toolSieveCaptureLimit {
-				content := state.capture.String()
-				state.capture.Reset()
-				state.capturing = false
-				state.resetIncrementalToolState()
-				state.noteText(content)
-				events = append(events, toolStreamEvent{Content: content})
-				continue
-			}
 			break
 		}
+		captured := state.capture.String()
 		state.capture.Reset()
 		state.capturing = false
 		state.resetIncrementalToolState()
+		if len(calls) > 0 {
+			if prefix != "" {
+				state.noteText(prefix)
+				events = append(events, toolStreamEvent{Content: prefix})
+			}
+			if suffix != "" {
+				state.pending.WriteString(suffix)
+			}
+			_ = captured
+			state.pendingToolCalls = calls
+			continue
+		}
 		if prefix != "" {
 			state.noteText(prefix)
 			events = append(events, toolStreamEvent{Content: prefix})
 		}
-		if len(calls) > 0 {
-			events = append(events, toolStreamEvent{ToolCalls: calls})
-		}
 		if suffix != "" {
 			state.pending.WriteString(suffix)
 		}
@@ -89,6 +92,11 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
 		return nil
 	}
 	events := processToolSieveChunk(state, "", toolNames)
+	if len(state.pendingToolCalls) > 0 {
+		events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
+		state.pendingToolRaw = ""
+		state.pendingToolCalls = nil
+	}
 	if state.capturing {
 		consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
 		if ready {
```
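The sieve now parks calls returned by `consumeToolCapture` in `pendingToolCalls` and emits them on the next chunk, or at flush, rather than immediately. A simplified standalone sketch of that one-step deferral, with plain strings standing in for parsed tool calls:

```go
package main

import "fmt"

// sieve parks parsed tool calls for one round before emitting them,
// mirroring the pendingToolCalls deferral in the diff above (simplified types).
type sieve struct {
	pending []string // parsed tool calls parked for the next round
}

// process parks any calls parsed from this chunk and returns the calls
// that were parked on the PREVIOUS round.
func (s *sieve) process(parsedNow []string) []string {
	emitted := s.pending
	s.pending = parsedNow
	return emitted
}

// flush releases whatever is still parked at end of stream.
func (s *sieve) flush() []string {
	emitted := s.pending
	s.pending = nil
	return emitted
}

func main() {
	s := &sieve{}
	fmt.Println(s.process([]string{"read_file"})) // [] — parked, not emitted yet
	fmt.Println(s.process(nil))                   // [read_file] — released next round
	s.process([]string{"search"})
	fmt.Println(s.flush()) // [search] — released at end of stream
}
```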
```diff
@@ -159,22 +167,22 @@ func findToolSegmentStart(s string) int {
 		return -1
 	}
 	lower := strings.ToLower(s)
-	offset := 0
-	for {
-		keyRel := strings.Index(lower[offset:], "tool_calls")
-		if keyRel < 0 {
-			return -1
+	keywords := []string{"tool_calls", "function.name:", "[tool_call_history]"}
+	bestKeyIdx := -1
+	for _, kw := range keywords {
+		idx := strings.Index(lower, kw)
+		if idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx) {
+			bestKeyIdx = idx
 		}
-		keyIdx := offset + keyRel
-		start := strings.LastIndex(s[:keyIdx], "{")
-		if start < 0 {
-			start = keyIdx
-		}
-		if !insideCodeFence(s[:start]) {
-			return start
-		}
-		offset = keyIdx + len("tool_calls")
 	}
+	if bestKeyIdx < 0 {
+		return -1
+	}
+	start := strings.LastIndex(s[:bestKeyIdx], "{")
+	if start < 0 {
+		start = bestKeyIdx
+	}
+	return start
 }
 
 func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
```
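The rewritten `findToolSegmentStart` picks the earliest occurrence among several keywords (`tool_calls`, `function.name:`, `[tool_call_history]`) instead of looping over `tool_calls` alone. The earliest-match selection in isolation:

```go
package main

import (
	"fmt"
	"strings"
)

// earliestKeyword returns the smallest index at which any keyword occurs
// in s (case-insensitive), or -1 when none occurs.
func earliestKeyword(s string, keywords []string) int {
	lower := strings.ToLower(s)
	best := -1
	for _, kw := range keywords {
		idx := strings.Index(lower, kw)
		if idx >= 0 && (best < 0 || idx < best) {
			best = idx
		}
	}
	return best
}

func main() {
	kws := []string{"tool_calls", "function.name:", "[tool_call_history]"}
	fmt.Println(earliestKeyword(`text {"TOOL_CALLS":[]}`, kws)) // 7
	fmt.Println(earliestKeyword("no markers here", kws))        // -1
}
```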
```diff
@@ -183,13 +191,22 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
 		return "", nil, "", false
 	}
 	lower := strings.ToLower(captured)
-	keyIdx := strings.Index(lower, "tool_calls")
+	keyIdx := -1
+	keywords := []string{"tool_calls", "function.name:", "[tool_call_history]"}
+	for _, kw := range keywords {
+		idx := strings.Index(lower, kw)
+		if idx >= 0 && (keyIdx < 0 || idx < keyIdx) {
+			keyIdx = idx
+		}
+	}
+
 	if keyIdx < 0 {
 		return "", nil, "", false
 	}
 	start := strings.LastIndex(captured[:keyIdx], "{")
 	if start < 0 {
-		return "", nil, "", false
+		start = keyIdx
 	}
 	obj, end, ok := extractJSONObjectFrom(captured, start)
 	if !ok {
@@ -197,9 +214,6 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
 	}
 	prefixPart := captured[:start]
 	suffixPart := captured[end:]
-	if insideCodeFence(state.recentTextTail + prefixPart) {
-		return captured, nil, "", true
-	}
 	parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames)
 	if len(parsed.Calls) == 0 {
 		if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {
@@ -207,6 +221,9 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
 			// consume it to avoid leaking raw tool_calls JSON to user content.
 			return prefixPart, nil, suffixPart, true
 		}
+		// If it has obvious keywords but failed to parse even after loose repair,
+		// we still might want to intercept it if it looks like an attempt at tool call.
+		// For now, keep the original logic but rely on loose JSON repair.
 		return captured, nil, "", true
 	}
 	return prefixPart, parsed.Calls, suffixPart, true
```
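`consumeToolCapture` relies on `extractJSONObjectFrom` to cut one balanced JSON object out of mixed prose. That helper's implementation is not shown in this diff; a hypothetical sketch of such an extractor, handling only string literals and backslash escapes:

```go
package main

import "fmt"

// extractObject returns the first balanced {...} object starting at index
// start, plus the index just past it; ok is false when braces never balance.
// A simplified guess at what extractJSONObjectFrom might do.
func extractObject(s string, start int) (obj string, end int, ok bool) {
	depth := 0
	inStr := false
	for i := start; i < len(s); i++ {
		c := s[i]
		if inStr {
			if c == '\\' {
				i++ // skip the escaped character inside a string literal
			} else if c == '"' {
				inStr = false
			}
			continue
		}
		switch c {
		case '"':
			inStr = true
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return s[start : i+1], i + 1, true
			}
		}
	}
	return "", 0, false
}

func main() {
	obj, end, ok := extractObject(`before {"a":{"b":1}} after`, 7)
	fmt.Println(obj, end, ok) // {"a":{"b":1}} 20 true
}
```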
@@ -19,9 +19,6 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
|
|||||||
if start < 0 {
|
if start < 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if insideCodeFence(state.recentTextTail + captured[:start]) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
|
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
|
||||||
if hasMultiple {
|
if hasMultiple {
|
||||||
state.disableDeltas = true
|
state.disableDeltas = true
|
||||||
|
|||||||
@@ -7,17 +7,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type toolStreamSieveState struct {
|
type toolStreamSieveState struct {
|
||||||
pending strings.Builder
|
pending strings.Builder
|
||||||
capture strings.Builder
|
capture strings.Builder
|
||||||
capturing bool
|
capturing bool
|
||||||
recentTextTail string
|
recentTextTail string
|
||||||
disableDeltas bool
|
pendingToolRaw string
|
||||||
toolNameSent bool
|
pendingToolCalls []util.ParsedToolCall
|
||||||
toolName string
|
disableDeltas bool
|
||||||
toolArgsStart int
|
toolNameSent bool
|
||||||
toolArgsSent int
|
toolName string
|
||||||
toolArgsString bool
|
toolArgsStart int
|
||||||
toolArgsDone bool
|
toolArgsSent int
|
||||||
|
toolArgsString bool
|
||||||
|
toolArgsDone bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type toolStreamEvent struct {
|
type toolStreamEvent struct {
|
||||||
@@ -32,7 +34,6 @@ type toolCallDelta struct {
|
|||||||
Arguments string
|
Arguments string
|
||||||
}
|
}
|
||||||
|
|
||||||
const toolSieveCaptureLimit = 8 * 1024
|
|
||||||
const toolSieveContextTailLimit = 256
|
const toolSieveContextTailLimit = 256
|
||||||
|
|
||||||
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
||||||
|
@@ -27,6 +27,7 @@ type ConfigStore interface {
 	RuntimeAccountMaxInflight() int
 	RuntimeAccountMaxQueue(defaultSize int) int
 	RuntimeGlobalMaxInflight(defaultSize int) int
+	AutoDeleteSessions() bool
 }
 
 type PoolController interface {
@@ -40,6 +41,8 @@ type DeepSeekCaller interface {
 	CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
 	GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
 	CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
+	GetSessionCountForToken(ctx context.Context, token string) (*deepseek.SessionStats, error)
+	DeleteAllSessionsForToken(ctx context.Context, token string) error
 }
 
 var _ ConfigStore = (*config.Store)(nil)
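The `var _ ConfigStore = (*config.Store)(nil)` context line above is Go's compile-time interface-satisfaction idiom: when an interface such as `DeepSeekCaller` gains methods, that blank-identifier assignment refuses to compile until every implementation catches up. A minimal standalone sketch of the idiom — `Caller` and `stubCaller` are hypothetical names trimmed to one method for illustration:

```go
package main

import "fmt"

// Caller stands in for an extended caller interface
// (hypothetical, reduced to a single method).
type Caller interface {
	DeleteAllSessionsForToken(token string) error
}

// stubCaller is a hypothetical implementation.
type stubCaller struct{}

func (stubCaller) DeleteAllSessionsForToken(token string) error { return nil }

// Compile-time satisfaction check, same idiom as the
// `var _ ConfigStore = ...` line: removing the method above
// makes this assignment fail to compile.
var _ Caller = stubCaller{}

func main() {
	fmt.Println(stubCaller{}.DeleteAllSessionsForToken("tok") == nil)
}
```

The check costs nothing at runtime; the `(*T)(nil)` form used in the repo works the same way for pointer receivers.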
@@ -31,6 +31,7 @@ func RegisterRoutes(r chi.Router, h *Handler) {
 	pr.Get("/queue/status", h.queueStatus)
 	pr.Post("/accounts/test", h.testSingleAccount)
 	pr.Post("/accounts/test-all", h.testAllAccounts)
+	pr.Post("/accounts/sessions/delete-all", h.deleteAllSessions)
 	pr.Post("/import", h.batchImport)
 	pr.Post("/test", h.testAPI)
 	pr.Post("/vercel/sync", h.syncVercel)
@@ -1,128 +1,133 @@
 package admin
 
 import (
 	"encoding/json"
 	"fmt"
 	"net/http"
+	"net/url"
 	"strings"
 
 	"github.com/go-chi/chi/v5"
 
 	"ds2api/internal/config"
 )
 
 func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
 	page := intFromQuery(r, "page", 1)
 	pageSize := intFromQuery(r, "page_size", 10)
 	if page < 1 {
 		page = 1
 	}
 	if pageSize < 1 {
 		pageSize = 1
 	}
 	if pageSize > 100 {
 		pageSize = 100
 	}
 	accounts := h.Store.Snapshot().Accounts
 	reverseAccounts(accounts)
 	q := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("q")))
 	if q != "" {
 		filtered := make([]config.Account, 0, len(accounts))
 		for _, acc := range accounts {
 			id := strings.ToLower(acc.Identifier())
 			if strings.Contains(id, q) ||
 				strings.Contains(strings.ToLower(acc.Email), q) ||
 				strings.Contains(strings.ToLower(acc.Mobile), q) {
 				filtered = append(filtered, acc)
 			}
 		}
 		accounts = filtered
 	}
 	total := len(accounts)
 	totalPages := 1
 	if total > 0 {
 		totalPages = (total + pageSize - 1) / pageSize
 	}
 	start := (page - 1) * pageSize
 	if start > total {
 		start = total
 	}
 	end := start + pageSize
 	if end > total {
 		end = total
 	}
 	items := make([]map[string]any, 0, end-start)
 	for _, acc := range accounts[start:end] {
 		token := strings.TrimSpace(acc.Token)
 		preview := ""
 		if token != "" {
 			if len(token) > 20 {
 				preview = token[:20] + "..."
 			} else {
 				preview = token
 			}
 		}
 		items = append(items, map[string]any{
 			"identifier":    acc.Identifier(),
 			"email":         acc.Email,
 			"mobile":        acc.Mobile,
 			"has_password":  acc.Password != "",
 			"has_token":     token != "",
 			"token_preview": preview,
 			"test_status":   acc.TestStatus,
 		})
 	}
 	writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
 }
 
 func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
 	var req map[string]any
 	_ = json.NewDecoder(r.Body).Decode(&req)
 	acc := toAccount(req)
 	if acc.Identifier() == "" {
 		writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
 		return
 	}
 	err := h.Store.Update(func(c *config.Config) error {
+		mobileKey := config.CanonicalMobileKey(acc.Mobile)
 		for _, a := range c.Accounts {
 			if acc.Email != "" && a.Email == acc.Email {
 				return fmt.Errorf("邮箱已存在")
 			}
-			if acc.Mobile != "" && a.Mobile == acc.Mobile {
+			if mobileKey != "" && config.CanonicalMobileKey(a.Mobile) == mobileKey {
 				return fmt.Errorf("手机号已存在")
 			}
 		}
 		c.Accounts = append(c.Accounts, acc)
 		return nil
 	})
 	if err != nil {
 		writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
 		return
 	}
 	h.Pool.Reset()
 	writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
 }
 
 func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
 	identifier := chi.URLParam(r, "identifier")
+	if decoded, err := url.PathUnescape(identifier); err == nil {
+		identifier = decoded
+	}
 	err := h.Store.Update(func(c *config.Config) error {
 		idx := -1
 		for i, a := range c.Accounts {
 			if accountMatchesIdentifier(a, identifier) {
 				idx = i
 				break
 			}
 		}
 		if idx < 0 {
 			return fmt.Errorf("账号不存在")
 		}
 		c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
 		return nil
 	})
 	if err != nil {
 		writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
 		return
 	}
 	h.Pool.Reset()
 	writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
 }
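The `deleteAccount` hunk above decodes the path parameter with `url.PathUnescape` so that a mobile identifier percent-encoded by the client (e.g. `%2B8613800138000`) matches the stored `+8613800138000`, while plain identifiers pass through untouched. A minimal standalone sketch of that lenient decode step (`decodeIdentifier` is a name invented here for illustration, not a function in the repo):

```go
package main

import (
	"fmt"
	"net/url"
)

// decodeIdentifier mirrors the handler's lenient decode: if the path
// segment is valid percent-encoding, use the decoded form; on a decode
// error, keep the raw value unchanged so lookups still get a chance.
func decodeIdentifier(raw string) string {
	if decoded, err := url.PathUnescape(raw); err == nil {
		return decoded
	}
	return raw
}

func main() {
	fmt.Println(decodeIdentifier("%2B8613800138000")) // escaped "+" mobile decodes
	fmt.Println(decodeIdentifier("u@example.com"))    // plain identifiers pass through
}
```

Note that `url.PathUnescape` (unlike `url.QueryUnescape`) does not turn a literal `+` into a space, which is exactly why it is the right choice for path segments carrying phone numbers.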
@@ -1,6 +1,7 @@
 package admin
 
 import (
+	"bytes"
 	"encoding/json"
 	"net/http"
 	"net/http/httptest"
@@ -102,6 +103,45 @@ func TestDeleteAccountSupportsMobileAlias(t *testing.T) {
 	}
 }
 
+func TestDeleteAccountSupportsEncodedPlusMobile(t *testing.T) {
+	h := newAdminTestHandler(t, `{
+		"accounts":[{"mobile":"+8613800138000","password":"pwd"}]
+	}`)
+
+	r := chi.NewRouter()
+	r.Delete("/admin/accounts/{identifier}", h.deleteAccount)
+	req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape("+8613800138000"), nil)
+	rec := httptest.NewRecorder()
+	r.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusOK {
+		t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
+	}
+	if got := len(h.Store.Accounts()); got != 0 {
+		t.Fatalf("expected account removed, remaining=%d", got)
+	}
+}
+
+func TestAddAccountRejectsCanonicalMobileDuplicate(t *testing.T) {
+	h := newAdminTestHandler(t, `{
+		"accounts":[{"mobile":"+8613800138000","password":"pwd"}]
+	}`)
+
+	r := chi.NewRouter()
+	r.Post("/admin/accounts", h.addAccount)
+	body := []byte(`{"mobile":"13800138000","password":"pwd2"}`)
+	req := httptest.NewRequest(http.MethodPost, "/admin/accounts", bytes.NewReader(body))
+	rec := httptest.NewRecorder()
+	r.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusBadRequest {
+		t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
+	}
+	if got := len(h.Store.Accounts()); got != 1 {
+		t.Fatalf("expected no duplicate insert, got=%d", got)
+	}
+}
+
 func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
 	h := newAdminTestHandler(t, `{
 		"accounts":[
@@ -117,6 +157,13 @@ func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
 	if accByMobile.Email != "u@example.com" {
 		t.Fatalf("unexpected account by mobile: %#v", accByMobile)
 	}
+	accByMobileWithCountryCode, ok := findAccountByIdentifier(h.Store, "+8613800138000")
+	if !ok {
+		t.Fatal("expected find by +86 mobile")
+	}
+	if accByMobileWithCountryCode.Email != "u@example.com" {
+		t.Fatalf("unexpected account by +86 mobile: %#v", accByMobileWithCountryCode)
+	}
 
 	tokenOnlyID := ""
 	for _, acc := range h.Store.Accounts() {
@@ -1,209 +1,265 @@
 package admin
 
 import (
 	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
 	"strings"
 	"sync"
 	"time"
 
 	authn "ds2api/internal/auth"
 	"ds2api/internal/config"
 	"ds2api/internal/sse"
 )
 
 func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
 	var req map[string]any
 	_ = json.NewDecoder(r.Body).Decode(&req)
 	identifier, _ := req["identifier"].(string)
 	if strings.TrimSpace(identifier) == "" {
 		writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"})
 		return
 	}
 	acc, ok := findAccountByIdentifier(h.Store, identifier)
 	if !ok {
 		writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
 		return
 	}
 	model, _ := req["model"].(string)
 	if model == "" {
 		model = "deepseek-chat"
 	}
 	message, _ := req["message"].(string)
 	result := h.testAccount(r.Context(), acc, model, message)
 	writeJSON(w, http.StatusOK, result)
 }
 
 func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
 	var req map[string]any
 	_ = json.NewDecoder(r.Body).Decode(&req)
 	model, _ := req["model"].(string)
 	if model == "" {
 		model = "deepseek-chat"
 	}
 	accounts := h.Store.Snapshot().Accounts
 	if len(accounts) == 0 {
 		writeJSON(w, http.StatusOK, map[string]any{"total": 0, "success": 0, "failed": 0, "results": []any{}})
 		return
 	}
 
 	// Concurrent testing with a semaphore to limit parallelism.
 	const maxConcurrency = 5
 	results := runAccountTestsConcurrently(accounts, maxConcurrency, func(_ int, account config.Account) map[string]any {
 		return h.testAccount(r.Context(), account, model, "")
 	})
 
 	success := 0
 	for _, res := range results {
 		if ok, _ := res["success"].(bool); ok {
 			success++
 		}
 	}
 	writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results})
 }
 
 func runAccountTestsConcurrently(accounts []config.Account, maxConcurrency int, testFn func(int, config.Account) map[string]any) []map[string]any {
 	if maxConcurrency <= 0 {
 		maxConcurrency = 1
 	}
 	sem := make(chan struct{}, maxConcurrency)
 	results := make([]map[string]any, len(accounts))
 	var wg sync.WaitGroup
 	for i, acc := range accounts {
 		wg.Add(1)
 		go func(idx int, account config.Account) {
 			defer wg.Done()
 			sem <- struct{}{}        // acquire
 			defer func() { <-sem }() // release
 			results[idx] = testFn(idx, account)
 		}(i, acc)
 	}
 	wg.Wait()
 	return results
 }
 
 func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
 	start := time.Now()
 	identifier := acc.Identifier()
-	result := map[string]any{"account": identifier, "success": false, "response_time": 0, "message": "", "model": model}
+	result := map[string]any{"account": identifier, "success": false, "response_time": 0, "message": "", "model": model, "session_count": 0}
 	defer func() {
 		status := "failed"
 		if ok, _ := result["success"].(bool); ok {
 			status = "ok"
 		}
 		_ = h.Store.UpdateAccountTestStatus(identifier, status)
 	}()
 	token := strings.TrimSpace(acc.Token)
 	if token == "" {
 		newToken, err := h.DS.Login(ctx, acc)
 		if err != nil {
 			result["message"] = "登录失败: " + err.Error()
 			return result
 		}
 		token = newToken
 		_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
 	}
 	authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token}
 	sessionID, err := h.DS.CreateSession(ctx, authCtx, 1)
 	if err != nil {
 		newToken, loginErr := h.DS.Login(ctx, acc)
 		if loginErr != nil {
 			result["message"] = "创建会话失败: " + err.Error()
 			return result
 		}
 		token = newToken
 		authCtx.DeepSeekToken = token
 		_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
 		sessionID, err = h.DS.CreateSession(ctx, authCtx, 1)
 		if err != nil {
 			result["message"] = "创建会话失败: " + err.Error()
 			return result
 		}
 	}
-	if strings.TrimSpace(message) == "" {
-		message = "你是谁?"
-	}
+
+	// 获取会话数量
+	sessionStats, sessionErr := h.DS.GetSessionCountForToken(ctx, token)
+	if sessionErr == nil && sessionStats != nil {
+		result["session_count"] = sessionStats.FirstPageCount
+	}
+
+	if strings.TrimSpace(message) == "" {
+		result["success"] = true
+		result["message"] = "API 测试成功(仅会话创建)"
+		result["response_time"] = int(time.Since(start).Milliseconds())
+		return result
+	}
 	thinking, search, ok := config.GetModelConfig(model)
 	if !ok {
 		thinking, search = false, false
 	}
 	_ = search
 	pow, err := h.DS.GetPow(ctx, authCtx, 1)
 	if err != nil {
 		result["message"] = "获取 PoW 失败: " + err.Error()
 		return result
 	}
 	payload := map[string]any{"chat_session_id": sessionID, "prompt": "<|User|>" + message, "ref_file_ids": []any{}, "thinking_enabled": thinking, "search_enabled": search}
 	resp, err := h.DS.CallCompletion(ctx, authCtx, payload, pow, 1)
 	if err != nil {
 		result["message"] = "请求失败: " + err.Error()
 		return result
 	}
 	if resp.StatusCode != http.StatusOK {
 		defer resp.Body.Close()
 		result["message"] = fmt.Sprintf("请求失败: HTTP %d", resp.StatusCode)
 		return result
 	}
 	collected := sse.CollectStream(resp, thinking, true)
 	result["success"] = true
 	result["response_time"] = int(time.Since(start).Milliseconds())
 	if collected.Text != "" {
 		result["message"] = collected.Text
 	} else {
 		result["message"] = "(无回复内容)"
 	}
 	if collected.Thinking != "" {
 		result["thinking"] = collected.Thinking
 	}
 	return result
 }
 
 func (h *Handler) testAPI(w http.ResponseWriter, r *http.Request) {
 	var req map[string]any
 	_ = json.NewDecoder(r.Body).Decode(&req)
 	model, _ := req["model"].(string)
 	message, _ := req["message"].(string)
 	apiKey, _ := req["api_key"].(string)
 	if model == "" {
 		model = "deepseek-chat"
 	}
 	if message == "" {
 		message = "你好"
 	}
 	if apiKey == "" {
 		keys := h.Store.Snapshot().Keys
 		if len(keys) == 0 {
 			writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "没有可用的 API Key"})
 			return
 		}
 		apiKey = keys[0]
 	}
 	host := r.Host
 	scheme := "http"
 	if strings.Contains(strings.ToLower(host), "vercel") || strings.Contains(strings.ToLower(r.Header.Get("X-Forwarded-Proto")), "https") {
 		scheme = "https"
 	}
 	payload := map[string]any{"model": model, "messages": []map[string]any{{"role": "user", "content": message}}, "stream": false}
 	b, _ := json.Marshal(payload)
 	request, _ := http.NewRequestWithContext(r.Context(), http.MethodPost, fmt.Sprintf("%s://%s/v1/chat/completions", scheme, host), bytes.NewReader(b))
 	request.Header.Set("Authorization", "Bearer "+apiKey)
 	request.Header.Set("Content-Type", "application/json")
 	resp, err := (&http.Client{Timeout: 60 * time.Second}).Do(request)
 	if err != nil {
 		writeJSON(w, http.StatusOK, map[string]any{"success": false, "error": err.Error()})
 		return
 	}
 	defer resp.Body.Close()
 	body, _ := io.ReadAll(resp.Body)
 	if resp.StatusCode == http.StatusOK {
 		var parsed any
 		_ = json.Unmarshal(body, &parsed)
 		writeJSON(w, http.StatusOK, map[string]any{"success": true, "status_code": resp.StatusCode, "response": parsed})
 		return
 	}
 	writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": string(body)})
 }
+
+func (h *Handler) deleteAllSessions(w http.ResponseWriter, r *http.Request) {
+	var req map[string]any
+	_ = json.NewDecoder(r.Body).Decode(&req)
+	identifier, _ := req["identifier"].(string)
+	if strings.TrimSpace(identifier) == "" {
+		writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"})
+		return
+	}
+	acc, ok := findAccountByIdentifier(h.Store, identifier)
+	if !ok {
+		writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
+		return
+	}
+
+	// 获取 token
+	token := strings.TrimSpace(acc.Token)
+	if token == "" {
+		newToken, err := h.DS.Login(r.Context(), acc)
+		if err != nil {
+			writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "登录失败: " + err.Error()})
+			return
+		}
+		token = newToken
+		_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
+	}
+
+	// 删除所有会话
+	err := h.DS.DeleteAllSessionsForToken(r.Context(), token)
+	if err != nil {
+		// token 可能过期,尝试重新登录并重试一次
+		newToken, loginErr := h.DS.Login(r.Context(), acc)
+		if loginErr != nil {
+			writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "删除失败: " + err.Error()})
+			return
+		}
+		token = newToken
+		_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
+		if retryErr := h.DS.DeleteAllSessionsForToken(r.Context(), token); retryErr != nil {
+			writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "删除失败: " + retryErr.Error()})
+			return
+		}
+	}
+
+	writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "删除成功"})
+}
134  internal/admin/handler_accounts_testing_test.go  Normal file
@@ -0,0 +1,134 @@
+package admin
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+
+	"ds2api/internal/auth"
+	"ds2api/internal/config"
+	"ds2api/internal/deepseek"
+)
+
+type testingDSMock struct {
+	loginCalls                 int
+	createSessionCalls         int
+	getPowCalls                int
+	callCompletionCalls        int
+	deleteAllSessionsCalls     int
+	deleteAllSessionsError     error
+	deleteAllSessionsErrorOnce bool
+}
+
+func (m *testingDSMock) Login(_ context.Context, _ config.Account) (string, error) {
+	m.loginCalls++
+	return "new-token", nil
+}
+
+func (m *testingDSMock) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
+	m.createSessionCalls++
+	return "session-id", nil
+}
+
+func (m *testingDSMock) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
+	m.getPowCalls++
+	return "", errors.New("should not call GetPow in this test")
+}
+
+func (m *testingDSMock) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
+	m.callCompletionCalls++
+	return nil, errors.New("should not call CallCompletion in this test")
+}
+
+func (m *testingDSMock) DeleteAllSessionsForToken(_ context.Context, _ string) error {
+	m.deleteAllSessionsCalls++
+	if m.deleteAllSessionsError != nil {
+		err := m.deleteAllSessionsError
+		if m.deleteAllSessionsErrorOnce {
+			m.deleteAllSessionsError = nil
+		}
+		return err
+	}
+	return nil
+}
+
+func (m *testingDSMock) GetSessionCountForToken(_ context.Context, _ string) (*deepseek.SessionStats, error) {
+	return &deepseek.SessionStats{Success: true}, nil
+}
+
+func TestTestAccount_BatchModeOnlyCreatesSession(t *testing.T) {
+	t.Setenv("DS2API_CONFIG_JSON", `{"accounts":[{"email":"batch@example.com","password":"pwd","token":""}]}`)
+	store := config.LoadStore()
+	ds := &testingDSMock{}
+	h := &Handler{Store: store, DS: ds}
+	acc, ok := store.FindAccount("batch@example.com")
+	if !ok {
+		t.Fatal("expected test account")
+	}
+
+	result := h.testAccount(context.Background(), acc, "deepseek-chat", "")
+
+	if ok, _ := result["success"].(bool); !ok {
+		t.Fatalf("expected success=true, got %#v", result)
+	}
+	msg, _ := result["message"].(string)
+	if !strings.Contains(msg, "仅会话创建") {
+		t.Fatalf("expected session-only success message, got %q", msg)
+	}
+	if ds.loginCalls != 1 || ds.createSessionCalls != 1 {
+		t.Fatalf("unexpected Login/CreateSession calls: login=%d createSession=%d", ds.loginCalls, ds.createSessionCalls)
+	}
+	if ds.getPowCalls != 0 || ds.callCompletionCalls != 0 {
+		t.Fatalf("expected no completion flow calls, got getPow=%d callCompletion=%d", ds.getPowCalls, ds.callCompletionCalls)
+	}
+	updated, ok := store.FindAccount("batch@example.com")
+	if !ok {
+		t.Fatal("expected updated account")
+	}
+	if updated.Token != "new-token" {
+		t.Fatalf("expected refreshed token to be persisted, got %q", updated.Token)
+	}
+	if updated.TestStatus != "ok" {
+		t.Fatalf("expected test status ok, got %q", updated.TestStatus)
+	}
+}
+
+func TestDeleteAllSessions_RetryWithReloginOnDeleteFailure(t *testing.T) {
+	t.Setenv("DS2API_CONFIG_JSON", `{"accounts":[{"email":"batch@example.com","password":"pwd","token":"expired-token"}]}`)
+	store := config.LoadStore()
+	ds := &testingDSMock{deleteAllSessionsError: errors.New("token expired"), deleteAllSessionsErrorOnce: true}
+	h := &Handler{Store: store, DS: ds}
+
+	req := httptest.NewRequest(http.MethodPost, "/delete-all", bytes.NewBufferString(`{"identifier":"batch@example.com"}`))
+	rec := httptest.NewRecorder()
+	h.deleteAllSessions(rec, req)
+
+	if rec.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", rec.Code)
+	}
+	var resp map[string]any
+	if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
if ok, _ := resp["success"].(bool); !ok {
|
||||||
|
t.Fatalf("expected success response, got %#v", resp)
|
||||||
|
}
|
||||||
|
if ds.loginCalls != 1 {
|
||||||
|
t.Fatalf("expected relogin once, got %d", ds.loginCalls)
|
||||||
|
}
|
||||||
|
if ds.deleteAllSessionsCalls != 2 {
|
||||||
|
t.Fatalf("expected delete called twice, got %d", ds.deleteAllSessionsCalls)
|
||||||
|
}
|
||||||
|
updated, ok := store.FindAccount("batch@example.com")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected account")
|
||||||
|
}
|
||||||
|
if updated.Token != "new-token" {
|
||||||
|
t.Fatalf("expected refreshed token persisted, got %q", updated.Token)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -49,6 +49,7 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
 	next := c.Clone()
 	if mode == "replace" {
 		next = incoming.Clone()
+		next.Accounts = normalizeAndDedupeAccounts(next.Accounts)
 		next.VercelSyncHash = c.VercelSyncHash
 		next.VercelSyncTime = c.VercelSyncTime
 		importedKeys = len(next.Keys)
@@ -73,17 +74,22 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {

 	existingAccounts := map[string]struct{}{}
 	for _, acc := range next.Accounts {
-		existingAccounts[acc.Identifier()] = struct{}{}
+		acc = normalizeAccountForStorage(acc)
+		key := accountDedupeKey(acc)
+		if key != "" {
+			existingAccounts[key] = struct{}{}
+		}
 	}
 	for _, acc := range incoming.Accounts {
-		id := acc.Identifier()
-		if id == "" {
+		acc = normalizeAccountForStorage(acc)
+		key := accountDedupeKey(acc)
+		if key == "" {
 			continue
 		}
-		if _, ok := existingAccounts[id]; ok {
+		if _, ok := existingAccounts[key]; ok {
 			continue
 		}
-		existingAccounts[id] = struct{}{}
+		existingAccounts[key] = struct{}{}
 		next.Accounts = append(next.Accounts, acc)
 		importedAccounts++
 	}
@@ -25,17 +25,28 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
 	if accountsRaw, ok := req["accounts"].([]any); ok {
 		existing := map[string]config.Account{}
 		for _, a := range old.Accounts {
-			existing[a.Identifier()] = a
+			a = normalizeAccountForStorage(a)
+			key := accountDedupeKey(a)
+			if key != "" {
+				existing[key] = a
+			}
 		}
+		seen := map[string]struct{}{}
 		accounts := make([]config.Account, 0, len(accountsRaw))
 		for _, item := range accountsRaw {
 			m, ok := item.(map[string]any)
 			if !ok {
 				continue
 			}
-			acc := toAccount(m)
-			id := acc.Identifier()
-			if prev, ok := existing[id]; ok {
+			acc := normalizeAccountForStorage(toAccount(m))
+			key := accountDedupeKey(acc)
+			if key == "" {
+				continue
+			}
+			if _, ok := seen[key]; ok {
+				continue
+			}
+			if prev, ok := existing[key]; ok {
 				if strings.TrimSpace(acc.Password) == "" {
 					acc.Password = prev.Password
 				}
@@ -43,6 +54,7 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
 					acc.Token = prev.Token
 				}
 			}
+			seen[key] = struct{}{}
 			accounts = append(accounts, acc)
 		}
 		c.Accounts = accounts
@@ -138,20 +150,24 @@ func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
 	if accounts, ok := req["accounts"].([]any); ok {
 		existing := map[string]bool{}
 		for _, a := range c.Accounts {
-			existing[a.Identifier()] = true
+			a = normalizeAccountForStorage(a)
+			key := accountDedupeKey(a)
+			if key != "" {
+				existing[key] = true
+			}
 		}
 		for _, item := range accounts {
 			m, ok := item.(map[string]any)
 			if !ok {
 				continue
 			}
-			acc := toAccount(m)
-			id := acc.Identifier()
-			if id == "" || existing[id] {
+			acc := normalizeAccountForStorage(toAccount(m))
+			key := accountDedupeKey(acc)
+			if key == "" || existing[key] {
 				continue
 			}
 			c.Accounts = append(c.Accounts, acc)
-			existing[id] = true
+			existing[key] = true
 			importedAccounts++
 		}
 	}
@@ -7,15 +7,30 @@ import (
 	"ds2api/internal/config"
 )

-func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) {
+func boolFrom(v any) bool {
+	if v == nil {
+		return false
+	}
+	switch x := v.(type) {
+	case bool:
+		return x
+	case string:
+		return strings.ToLower(strings.TrimSpace(x)) == "true"
+	default:
+		return false
+	}
+}
+
+func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, *config.AutoDeleteConfig, map[string]string, map[string]string, error) {
 	var (
 		adminCfg    *config.AdminConfig
 		runtimeCfg  *config.RuntimeConfig
 		toolcallCfg *config.ToolcallConfig
 		respCfg     *config.ResponsesConfig
 		embCfg      *config.EmbeddingsConfig
-		claudeMap   map[string]string
-		aliasMap    map[string]string
+		autoDeleteCfg *config.AutoDeleteConfig
+		claudeMap     map[string]string
+		aliasMap      map[string]string
 	)

 	if raw, ok := req["admin"].(map[string]any); ok {
@@ -23,7 +38,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
 		if v, exists := raw["jwt_expire_hours"]; exists {
 			n := intFrom(v)
 			if n < 1 || n > 720 {
-				return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
+				return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
 			}
 			cfg.JWTExpireHours = n
 		}
@@ -35,26 +50,26 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
 		if v, exists := raw["account_max_inflight"]; exists {
 			n := intFrom(v)
 			if n < 1 || n > 256 {
-				return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
+				return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
 			}
 			cfg.AccountMaxInflight = n
 		}
 		if v, exists := raw["account_max_queue"]; exists {
 			n := intFrom(v)
 			if n < 1 || n > 200000 {
-				return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
+				return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
 			}
 			cfg.AccountMaxQueue = n
 		}
 		if v, exists := raw["global_max_inflight"]; exists {
 			n := intFrom(v)
 			if n < 1 || n > 200000 {
-				return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
+				return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
 			}
 			cfg.GlobalMaxInflight = n
 		}
 		if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
-			return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
+			return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
 		}
 		runtimeCfg = cfg
 	}
@@ -67,7 +82,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
 			case "feature_match", "off":
 				cfg.Mode = mode
 			default:
-				return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
+				return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
 			}
 		}
 		if v, exists := raw["early_emit_confidence"]; exists {
@@ -76,7 +91,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
 			case "high", "low", "off":
 				cfg.EarlyEmitConfidence = level
 			default:
-				return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
+				return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
 			}
 		}
 		toolcallCfg = cfg
@@ -87,7 +102,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
 		if v, exists := raw["store_ttl_seconds"]; exists {
 			n := intFrom(v)
 			if n < 30 || n > 86400 {
-				return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
+				return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
 			}
 			cfg.StoreTTLSeconds = n
 		}
@@ -98,9 +113,6 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
 		cfg := &config.EmbeddingsConfig{}
 		if v, exists := raw["provider"]; exists {
 			p := strings.TrimSpace(fmt.Sprintf("%v", v))
-			if p == "" {
-				return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty")
-			}
 			cfg.Provider = p
 		}
 		embCfg = cfg
@@ -130,5 +142,13 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
 		}
 	}

-	return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil
+	if raw, ok := req["auto_delete"].(map[string]any); ok {
+		cfg := &config.AutoDeleteConfig{}
+		if v, exists := raw["sessions"]; exists {
+			cfg.Sessions = boolFrom(v)
+		}
+		autoDeleteCfg = cfg
+	}
+
+	return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, autoDeleteCfg, claudeMap, aliasMap, nil
 }
@@ -28,6 +28,7 @@ func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
 		"toolcall":       snap.Toolcall,
 		"responses":      snap.Responses,
 		"embeddings":     snap.Embeddings,
+		"auto_delete":    snap.AutoDelete,
 		"claude_mapping": settingsClaudeMapping(snap),
 		"model_aliases":  snap.ModelAliases,
 		"env_backed":     h.Store.IsEnvBacked(),
@@ -265,3 +265,57 @@ func TestConfigImportRejectsMergedRuntimeConflict(t *testing.T) {
 		t.Fatalf("runtime should remain unchanged, runtime=%+v", snap.Runtime)
 	}
 }
+
+func TestConfigImportMergeDedupesMobileAliases(t *testing.T) {
+	h := newAdminTestHandler(t, `{
+		"keys":["k1"],
+		"accounts":[{"mobile":"+8613800138000","password":"p1"}]
+	}`)
+
+	merge := map[string]any{
+		"mode": "merge",
+		"config": map[string]any{
+			"accounts": []any{
+				map[string]any{"mobile": "13800138000", "password": "p2"},
+			},
+		},
+	}
+	b, _ := json.Marshal(merge)
+	req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(b))
+	rec := httptest.NewRecorder()
+	h.configImport(rec, req)
+	if rec.Code != http.StatusOK {
+		t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
+	}
+	if got := len(h.Store.Accounts()); got != 1 {
+		t.Fatalf("expected merge dedupe by canonical mobile, got=%d", got)
+	}
+}
+
+func TestUpdateConfigDedupesMobileAliases(t *testing.T) {
+	h := newAdminTestHandler(t, `{
+		"keys":["k1"],
+		"accounts":[{"mobile":"+8613800138000","password":"old"}]
+	}`)
+
+	reqBody := map[string]any{
+		"accounts": []any{
+			map[string]any{"mobile": "+8613800138000"},
+			map[string]any{"mobile": "13800138000"},
+		},
+	}
+	b, _ := json.Marshal(reqBody)
+	req := httptest.NewRequest(http.MethodPost, "/admin/config", bytes.NewReader(b))
+	rec := httptest.NewRecorder()
+	h.updateConfig(rec, req)
+	if rec.Code != http.StatusOK {
+		t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
+	}
+	accounts := h.Store.Accounts()
+	if len(accounts) != 1 {
+		t.Fatalf("expected update dedupe by canonical mobile, got=%d", len(accounts))
+	}
+	if accounts[0].Identifier() != "+8613800138000" {
+		t.Fatalf("unexpected identifier: %q", accounts[0].Identifier())
+	}
+}
@@ -17,7 +17,7 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
 		return
 	}

-	adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
+	adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, autoDeleteCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
 	if err != nil {
 		writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
 		return
@@ -60,6 +60,9 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
 	if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" {
 		c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider)
 	}
+	if autoDeleteCfg != nil {
+		c.AutoDelete.Sessions = autoDeleteCfg.Sessions
+	}
 	if claudeMap != nil {
 		c.ClaudeMapping = claudeMap
 		c.ClaudeModelMap = nil
@@ -59,9 +59,11 @@ func toStringSlice(v any) ([]string, bool) {
 }

 func toAccount(m map[string]any) config.Account {
+	email := fieldString(m, "email")
+	mobile := config.NormalizeMobileForStorage(fieldString(m, "mobile"))
 	return config.Account{
-		Email:    fieldString(m, "email"),
-		Mobile:   fieldString(m, "mobile"),
+		Email:    email,
+		Mobile:   mobile,
 		Password: fieldString(m, "password"),
 		Token:    fieldString(m, "token"),
 	}
@@ -90,12 +92,52 @@ func accountMatchesIdentifier(acc config.Account, identifier string) bool {
 	if strings.TrimSpace(acc.Email) == id {
 		return true
 	}
-	if strings.TrimSpace(acc.Mobile) == id {
+	if mobileKey := config.CanonicalMobileKey(id); mobileKey != "" && mobileKey == config.CanonicalMobileKey(acc.Mobile) {
 		return true
 	}
 	return acc.Identifier() == id
 }

+func normalizeAccountForStorage(acc config.Account) config.Account {
+	acc.Email = strings.TrimSpace(acc.Email)
+	acc.Mobile = config.NormalizeMobileForStorage(acc.Mobile)
+	return acc
+}
+
+func accountDedupeKey(acc config.Account) string {
+	if email := strings.TrimSpace(acc.Email); email != "" {
+		return "email:" + email
+	}
+	if mobile := config.CanonicalMobileKey(acc.Mobile); mobile != "" {
+		return "mobile:" + mobile
+	}
+	if id := strings.TrimSpace(acc.Identifier()); id != "" {
+		return "id:" + id
+	}
+	return ""
+}
+
+func normalizeAndDedupeAccounts(accounts []config.Account) []config.Account {
+	if len(accounts) == 0 {
+		return nil
+	}
+	out := make([]config.Account, 0, len(accounts))
+	seen := make(map[string]struct{}, len(accounts))
+	for _, acc := range accounts {
+		acc = normalizeAccountForStorage(acc)
+		key := accountDedupeKey(acc)
+		if key == "" {
+			continue
+		}
+		if _, ok := seen[key]; ok {
+			continue
+		}
+		seen[key] = struct{}{}
+		out = append(out, acc)
+	}
+	return out
+}
+
 func findAccountByIdentifier(store ConfigStore, identifier string) (config.Account, bool) {
 	id := strings.TrimSpace(identifier)
 	if id == "" {
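The helper functions give every account a single canonical dedupe key, preferring email over mobile over any remaining identifier, with a `email:`/`mobile:`/`id:` prefix so the key spaces cannot collide. A minimal standalone sketch of that precedence — `account` here is a local stand-in for `config.Account`, and the mobile canonicalization done by `config.CanonicalMobileKey` is reduced to a trim for illustration:

```go
package main

import (
	"fmt"
	"strings"
)

// account is a hypothetical stand-in for config.Account.
type account struct {
	Email, Mobile, Token string
}

// dedupeKey mirrors accountDedupeKey's precedence: email first, then
// mobile, then any remaining identifier; empty accounts yield "".
func dedupeKey(a account) string {
	if e := strings.TrimSpace(a.Email); e != "" {
		return "email:" + e
	}
	if m := strings.TrimSpace(a.Mobile); m != "" {
		return "mobile:" + m // real code canonicalizes via config.CanonicalMobileKey
	}
	if t := strings.TrimSpace(a.Token); t != "" {
		return "id:" + t
	}
	return ""
}

func main() {
	fmt.Println(dedupeKey(account{Email: "a@b.c", Mobile: "+8613800138000"})) // email wins
	fmt.Println(dedupeKey(account{Mobile: "+8613800138000"}))                 // falls back to mobile
}
```

Because the key is prefixed, an account whose email happens to equal another account's mobile number still produces two distinct keys.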
@@ -182,7 +182,7 @@ func TestToAccountAllFields(t *testing.T) {
 	if acc.Email != "user@test.com" {
 		t.Fatalf("unexpected email: %q", acc.Email)
 	}
-	if acc.Mobile != "13800138000" {
+	if acc.Mobile != "+8613800138000" {
 		t.Fatalf("unexpected mobile: %q", acc.Mobile)
 	}
 	if acc.Password != "secret" {
@@ -5,6 +5,7 @@ import (
 	"os"
 	"path/filepath"
 	"reflect"
+	"strings"
 	"testing"

 	"ds2api/internal/sse"
@@ -67,20 +68,36 @@ func TestGoCompatToolcallFixtures(t *testing.T) {
 		var fixture struct {
 			Text      string   `json:"text"`
 			ToolNames []string `json:"tool_names"`
+			Mode      string   `json:"mode"`
 		}
 		mustLoadJSON(t, fixturePath, &fixture)

 		var expected struct {
 			Calls             []util.ParsedToolCall `json:"calls"`
+			SawToolCallSyntax bool                  `json:"sawToolCallSyntax"`
+			RejectedByPolicy  bool                  `json:"rejectedByPolicy"`
+			RejectedToolNames []string              `json:"rejectedToolNames"`
 		}
 		mustLoadJSON(t, expectedPath, &expected)

-		got := util.ParseToolCalls(fixture.Text, fixture.ToolNames)
-		if len(got) == 0 && len(expected.Calls) == 0 {
-			continue
+		var got util.ToolCallParseResult
+		switch strings.ToLower(strings.TrimSpace(fixture.Mode)) {
+		case "standalone":
+			got = util.ParseStandaloneToolCallsDetailed(fixture.Text, fixture.ToolNames)
+		default:
+			got = util.ParseToolCallsDetailed(fixture.Text, fixture.ToolNames)
 		}
-		if !reflect.DeepEqual(got, expected.Calls) {
-			t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected.Calls)
+		if got.Calls == nil {
+			got.Calls = []util.ParsedToolCall{}
+		}
+		if got.RejectedToolNames == nil {
+			got.RejectedToolNames = []string{}
+		}
+		if !reflect.DeepEqual(got.Calls, expected.Calls) ||
+			got.SawToolCallSyntax != expected.SawToolCallSyntax ||
+			got.RejectedByPolicy != expected.RejectedByPolicy ||
+			!reflect.DeepEqual(got.RejectedToolNames, expected.RejectedToolNames) {
+			t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected)
 		}
 	}
 }
@@ -10,8 +10,8 @@ func (a Account) Identifier() string {
 	if strings.TrimSpace(a.Email) != "" {
 		return strings.TrimSpace(a.Email)
 	}
-	if strings.TrimSpace(a.Mobile) != "" {
-		return strings.TrimSpace(a.Mobile)
+	if mobile := NormalizeMobileForStorage(a.Mobile); mobile != "" {
+		return mobile
 	}
 	// Backward compatibility: old configs may contain token-only accounts.
 	// Use a stable non-sensitive synthetic id so they can still join the pool.
@@ -47,6 +47,7 @@ func (c Config) MarshalJSON() ([]byte, error) {
 	if strings.TrimSpace(c.Embeddings.Provider) != "" {
 		m["embeddings"] = c.Embeddings
 	}
+	m["auto_delete"] = c.AutoDelete
 	if c.VercelSyncHash != "" {
 		m["_vercel_sync_hash"] = c.VercelSyncHash
 	}
@@ -108,6 +109,10 @@ func (c *Config) UnmarshalJSON(b []byte) error {
 			if err := json.Unmarshal(v, &c.Embeddings); err != nil {
 				return fmt.Errorf("invalid field %q: %w", k, err)
 			}
+		case "auto_delete":
+			if err := json.Unmarshal(v, &c.AutoDelete); err != nil {
+				return fmt.Errorf("invalid field %q: %w", k, err)
+			}
 		case "_vercel_sync_hash":
 			if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
 				return fmt.Errorf("invalid field %q: %w", k, err)
@@ -141,6 +146,7 @@ func (c Config) Clone() Config {
 		Toolcall:         c.Toolcall,
 		Responses:        c.Responses,
 		Embeddings:       c.Embeddings,
+		AutoDelete:       c.AutoDelete,
 		VercelSyncHash:   c.VercelSyncHash,
 		VercelSyncTime:   c.VercelSyncTime,
 		AdditionalFields: map[string]any{},
@@ -12,7 +12,8 @@ type Config struct {
 	Toolcall         ToolcallConfig   `json:"toolcall,omitempty"`
 	Responses        ResponsesConfig  `json:"responses,omitempty"`
 	Embeddings       EmbeddingsConfig `json:"embeddings,omitempty"`
-	VercelSyncHash   string           `json:"_vercel_sync_hash,omitempty"`
+	AutoDelete       AutoDeleteConfig `json:"auto_delete"`
+	VercelSyncHash   string           `json:"_vercel_sync_hash,omitempty"`
 	VercelSyncTime   int64            `json:"_vercel_sync_time,omitempty"`
 	AdditionalFields map[string]any   `json:"-"`
 }
@@ -53,3 +54,7 @@ type ResponsesConfig struct {
 type EmbeddingsConfig struct {
 	Provider string `json:"provider,omitempty"`
 }
+
+type AutoDeleteConfig struct {
+	Sessions bool `json:"sessions"`
+}
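The `auto_delete` settings payload is parsed with the new `boolFrom` helper, which accepts either a JSON bool or the string "true". A standalone sketch of that coercion rule, with the helper copied out of the diff into a runnable `main` package for illustration:

```go
package main

import (
	"fmt"
	"strings"
)

// boolFrom mirrors the helper introduced in this diff: a JSON bool passes
// through, the string "true" (trimmed, case-insensitive) coerces to true,
// and anything else — including nil — is false.
func boolFrom(v any) bool {
	switch x := v.(type) {
	case bool:
		return x
	case string:
		return strings.ToLower(strings.TrimSpace(x)) == "true"
	default:
		return false
	}
}

func main() {
	fmt.Println(boolFrom(true), boolFrom(" True "), boolFrom("yes"), boolFrom(nil))
	// prints: true true false false
}
```

This keeps `{"auto_delete":{"sessions":"true"}}` and `{"auto_delete":{"sessions":true}}` equivalent when the admin UI serializes form values as strings.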
@@ -202,7 +202,7 @@ func TestConfigCloneNilMaps(t *testing.T) {

 func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) {
 	acc := Account{Mobile: "13800138000", Token: "tok"}
-	if acc.Identifier() != "13800138000" {
+	if acc.Identifier() != "+8613800138000" {
 		t.Fatalf("expected mobile identifier, got %q", acc.Identifier())
 	}
 }
82
internal/config/mobile.go
Normal file
82
internal/config/mobile.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// NormalizeMobileForStorage normalizes user input to a stable storage format.
|
||||||
|
// It keeps existing country codes and auto-prefixes mainland China numbers with +86.
|
||||||
|
func NormalizeMobileForStorage(raw string) string {
|
||||||
|
digits, hasPlus := extractMobileDigits(raw)
|
||||||
|
if digits == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if hasPlus {
|
||||||
|
return "+" + digits
|
||||||
|
}
|
||||||
|
if isChinaMobileWithCountryCode(digits) {
|
||||||
|
return "+86" + digits[2:]
|
||||||
|
}
|
||||||
|
if isChinaMainlandMobileDigits(digits) {
|
||||||
|
return "+86" + digits
|
||||||
|
}
|
||||||
|
// For non-China numbers without a leading +, preserve semantics by adding it.
|
||||||
|
return "+" + digits
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanonicalMobileKey returns the comparison key used by dedupe/matching logic.
|
||||||
|
func CanonicalMobileKey(raw string) string {
|
||||||
|
return NormalizeMobileForStorage(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractMobileDigits(raw string) (digits string, hasPlus bool) {
|
||||||
|
s := strings.TrimSpace(raw)
|
||||||
|
if s == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range s {
|
||||||
|
switch {
|
||||||
|
case r >= '0' && r <= '9':
|
||||||
|
goto collect
|
||||||
|
case isMobileSeparator(r):
|
||||||
|
continue
|
||||||
|
case r == '+':
|
||||||
|
hasPlus = true
|
||||||
|
goto collect
|
||||||
|
default:
|
||||||
|
goto collect
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
collect:
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(s))
|
||||||
|
for _, r := range s {
|
||||||
|
if r >= '0' && r <= '9' {
|
||||||
|
b.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b.String(), hasPlus
|
||||||
|
}
|
||||||
|
|
||||||
|
func isChinaMainlandMobileDigits(digits string) bool {
|
||||||
|
if len(digits) != 11 || digits[0] != '1' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return digits[1] >= '3' && digits[1] <= '9'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isChinaMobileWithCountryCode(digits string) bool {
|
||||||
|
if len(digits) != 13 || !strings.HasPrefix(digits, "86") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return isChinaMainlandMobileDigits(digits[2:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMobileSeparator(r rune) bool {
|
||||||
|
switch r {
|
||||||
|
case ' ', '\t', '\n', '\r', '-', '(', ')', '.', '/':
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
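The normalization rules added in `internal/config/mobile.go` can be exercised with a standalone sketch. This is a simplified reimplementation for illustration (the function name `normalizeMobileSketch` and the prefix-based `+` detection are assumptions, not the package's exact code), showing how an explicit country code is kept, bare mainland-China numbers gain `+86`, and other digit strings get a `+` prefix:

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeMobileSketch mirrors the storage-normalization rules: keep an
// explicit "+" country code, add "+86" for bare mainland-China numbers
// (11 digits starting with 1[3-9]) or "86"-prefixed ones, and otherwise
// prefix "+" to whatever digits remain.
func normalizeMobileSketch(raw string) string {
	s := strings.TrimSpace(raw)
	hasPlus := strings.HasPrefix(s, "+")
	var b strings.Builder
	for _, r := range s {
		if r >= '0' && r <= '9' {
			b.WriteRune(r)
		}
	}
	digits := b.String()
	if digits == "" {
		return ""
	}
	isCN := func(d string) bool {
		return len(d) == 11 && d[0] == '1' && d[1] >= '3' && d[1] <= '9'
	}
	switch {
	case hasPlus:
		return "+" + digits
	case len(digits) == 13 && strings.HasPrefix(digits, "86") && isCN(digits[2:]):
		return "+86" + digits[2:]
	case isCN(digits):
		return "+86" + digits
	default:
		return "+" + digits
	}
}

func main() {
	for _, in := range []string{"13800138000", "8613800138000", " +1 (415) 555-2671 "} {
		fmt.Printf("%q -> %q\n", in, normalizeMobileSketch(in))
	}
}
```

Because `CanonicalMobileKey` delegates to the same normalization, `+8613800138000`, `13800138000`, and `86 13800138000` all collapse to one dedupe key.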
36	internal/config/mobile_test.go	Normal file
@@ -0,0 +1,36 @@
+package config
+
+import "testing"
+
+func TestNormalizeMobileForStorageChinaMainlandAddsPlus86(t *testing.T) {
+	if got := NormalizeMobileForStorage("13800138000"); got != "+8613800138000" {
+		t.Fatalf("got %q", got)
+	}
+}
+
+func TestNormalizeMobileForStorageChinaWithCountryCode(t *testing.T) {
+	if got := NormalizeMobileForStorage("8613800138000"); got != "+8613800138000" {
+		t.Fatalf("got %q", got)
+	}
+}
+
+func TestNormalizeMobileForStorageKeepsExistingCountryCode(t *testing.T) {
+	if got := NormalizeMobileForStorage(" +1 (415) 555-2671 "); got != "+14155552671" {
+		t.Fatalf("got %q", got)
+	}
+}
+
+func TestCanonicalMobileKeyMatchesChinaAliases(t *testing.T) {
+	a := CanonicalMobileKey("+8613800138000")
+	b := CanonicalMobileKey("13800138000")
+	c := CanonicalMobileKey("86 13800138000")
+	if a == "" || a != b || b != c {
+		t.Fatalf("alias mismatch: a=%q b=%q c=%q", a, b, c)
+	}
+}
+
+func TestCanonicalMobileKeyEmptyForInvalidInput(t *testing.T) {
+	if got := CanonicalMobileKey("() --"); got != "" {
+		t.Fatalf("got %q", got)
+	}
+}
@@ -165,3 +165,9 @@ func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int {
 	}
 	return defaultSize
 }
+
+func (s *Store) AutoDeleteSessions() bool {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	return s.cfg.AutoDelete.Sessions
+}
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net/http"
 	"strings"
+	"unicode"
 
 	"ds2api/internal/auth"
 	"ds2api/internal/config"
@@ -20,8 +21,9 @@ func (c *Client) Login(ctx context.Context, acc config.Account) (string, error)
 	if email := strings.TrimSpace(acc.Email); email != "" {
 		payload["email"] = email
 	} else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" {
-		payload["mobile"] = mobile
-		payload["area_code"] = nil
+		loginMobile, areaCode := normalizeMobileForLogin(mobile)
+		payload["mobile"] = loginMobile
+		payload["area_code"] = areaCode
 	} else {
 		return "", errors.New("missing email/mobile")
 	}
@@ -60,8 +62,8 @@ func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAtte
 			attempts++
 			continue
 		}
-		code := intFrom(resp["code"])
-		if status == http.StatusOK && code == 0 {
+		code, bizCode, msg, bizMsg := extractResponseStatus(resp)
+		if status == http.StatusOK && code == 0 && bizCode == 0 {
 			data, _ := resp["data"].(map[string]any)
 			bizData, _ := data["biz_data"].(map[string]any)
 			sessionID, _ := bizData["id"].(string)
@@ -69,10 +71,9 @@ func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAtte
 				return sessionID, nil
 			}
 		}
-		msg, _ := resp["msg"].(string)
-		config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
+		config.Logger.Warn("[create_session] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
 		if a.UseConfigToken {
-			if isTokenInvalid(status, code, msg) && !refreshed {
+			if isTokenInvalid(status, code, bizCode, msg, bizMsg) && !refreshed {
 				if c.Auth.RefreshToken(ctx, a) {
 					refreshed = true
 					continue
@@ -94,6 +95,7 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
 		maxAttempts = c.maxRetries
 	}
 	attempts := 0
+	refreshed := false
 	for attempts < maxAttempts {
 		headers := c.authHeaders(a.DeepSeekToken)
 		resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
@@ -102,8 +104,8 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
 			attempts++
 			continue
 		}
-		code := intFrom(resp["code"])
-		if status == http.StatusOK && code == 0 {
+		code, bizCode, msg, bizMsg := extractResponseStatus(resp)
+		if status == http.StatusOK && code == 0 && bizCode == 0 {
 			data, _ := resp["data"].(map[string]any)
 			bizData, _ := data["biz_data"].(map[string]any)
 			challenge, _ := bizData["challenge"].(map[string]any)
@@ -114,15 +116,16 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
 			}
 			return BuildPowHeader(challenge, answer)
 		}
-		msg, _ := resp["msg"].(string)
-		config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
+		config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
 		if a.UseConfigToken {
-			if isTokenInvalid(status, code, msg) {
+			if isTokenInvalid(status, code, bizCode, msg, bizMsg) && !refreshed {
 				if c.Auth.RefreshToken(ctx, a) {
+					refreshed = true
 					continue
 				}
 			}
 			if c.Auth.SwitchAccount(ctx, a) {
+				refreshed = false
 				attempts++
 				continue
 			}
@@ -141,13 +144,55 @@ func (c *Client) authHeaders(token string) map[string]string {
 	return headers
 }
 
-func isTokenInvalid(status int, code int, msg string) bool {
-	msg = strings.ToLower(msg)
+func isTokenInvalid(status int, code int, bizCode int, msg string, bizMsg string) bool {
+	msg = strings.ToLower(strings.TrimSpace(msg) + " " + strings.TrimSpace(bizMsg))
 	if status == http.StatusUnauthorized || status == http.StatusForbidden {
 		return true
 	}
-	if code == 40001 || code == 40002 || code == 40003 {
+	if code == 40001 || code == 40002 || code == 40003 || bizCode == 40001 || bizCode == 40002 || bizCode == 40003 {
 		return true
 	}
-	return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
+	return strings.Contains(msg, "token") ||
+		strings.Contains(msg, "unauthorized") ||
+		strings.Contains(msg, "expired") ||
+		strings.Contains(msg, "not login") ||
+		strings.Contains(msg, "login required") ||
+		strings.Contains(msg, "invalid jwt")
+}
+
+func extractResponseStatus(resp map[string]any) (code int, bizCode int, msg string, bizMsg string) {
+	code = intFrom(resp["code"])
+	msg, _ = resp["msg"].(string)
+	data, _ := resp["data"].(map[string]any)
+	bizCode = intFrom(data["biz_code"])
+	bizMsg, _ = data["biz_msg"].(string)
+	if strings.TrimSpace(bizMsg) == "" {
+		if bizData, ok := data["biz_data"].(map[string]any); ok {
+			bizMsg, _ = bizData["msg"].(string)
+		}
+	}
+	return code, bizCode, msg, bizMsg
+}
+
+func normalizeMobileForLogin(raw string) (mobile string, areaCode any) {
+	s := strings.TrimSpace(raw)
+	if s == "" {
+		return "", nil
+	}
+	hasPlus := strings.HasPrefix(s, "+")
+	var b strings.Builder
+	b.Grow(len(s))
+	for _, r := range s {
+		if unicode.IsDigit(r) {
+			b.WriteRune(r)
+		}
+	}
+	digits := b.String()
+	if digits == "" {
+		return "", nil
+	}
+	if (hasPlus || strings.HasPrefix(digits, "86")) && strings.HasPrefix(digits, "86") && len(digits) == 13 {
+		return digits[2:], nil
+	}
+	return digits, nil
 }
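The new `extractResponseStatus` helper reads both the top-level envelope and the nested business fields. The sketch below re-implements that lookup over a decoded JSON map (the helper name `extractStatusSketch` and the sample envelope are illustrative assumptions; the field names `code`, `msg`, `data.biz_code`, `data.biz_msg`, and `data.biz_data.msg` follow the diff):

```go
package main

import (
	"fmt"
	"strings"
)

// extractStatusSketch mirrors extractResponseStatus: read the top-level
// code/msg, then data.biz_code/biz_msg, falling back to data.biz_data.msg
// when biz_msg is empty.
func extractStatusSketch(resp map[string]any) (code, bizCode int, msg, bizMsg string) {
	toInt := func(v any) int {
		if f, ok := v.(float64); ok { // JSON numbers decode to float64
			return int(f)
		}
		return 0
	}
	code = toInt(resp["code"])
	msg, _ = resp["msg"].(string)
	data, _ := resp["data"].(map[string]any) // nil map reads are safe in Go
	bizCode = toInt(data["biz_code"])
	bizMsg, _ = data["biz_msg"].(string)
	if strings.TrimSpace(bizMsg) == "" {
		if bizData, ok := data["biz_data"].(map[string]any); ok {
			bizMsg, _ = bizData["msg"].(string)
		}
	}
	return
}

func main() {
	// A response whose top-level code is 0 but whose business layer reports
	// an auth failure; checking bizCode catches this case.
	resp := map[string]any{
		"code": float64(0),
		"data": map[string]any{
			"biz_code": float64(40003),
			"biz_data": map[string]any{"msg": "token expired"},
		},
	}
	code, bizCode, msg, bizMsg := extractStatusSketch(resp)
	fmt.Println(code, bizCode, msg, bizMsg)
}
```

This is why the success checks in the diff tighten from `code == 0` to `code == 0 && bizCode == 0`: a `200 OK` with `code: 0` can still carry a failing `biz_code`.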
33	internal/deepseek/client_auth_mobile_test.go	Normal file
@@ -0,0 +1,33 @@
+package deepseek
+
+import "testing"
+
+func TestNormalizeMobileForLogin_ChinaWithPlus86(t *testing.T) {
+	mobile, areaCode := normalizeMobileForLogin("+8613800138000")
+	if mobile != "13800138000" {
+		t.Fatalf("unexpected mobile: %q", mobile)
+	}
+	if areaCode != nil {
+		t.Fatalf("expected nil areaCode, got %#v", areaCode)
+	}
+}
+
+func TestNormalizeMobileForLogin_ChinaWith86Prefix(t *testing.T) {
+	mobile, areaCode := normalizeMobileForLogin("8613800138000")
+	if mobile != "13800138000" {
+		t.Fatalf("unexpected mobile: %q", mobile)
+	}
+	if areaCode != nil {
+		t.Fatalf("expected nil areaCode, got %#v", areaCode)
+	}
+}
+
+func TestNormalizeMobileForLogin_KeepPlainDigits(t *testing.T) {
+	mobile, areaCode := normalizeMobileForLogin("13800138000")
+	if mobile != "13800138000" {
+		t.Fatalf("unexpected mobile: %q", mobile)
+	}
+	if areaCode != nil {
+		t.Fatalf("expected nil areaCode, got %#v", areaCode)
+	}
+}
@@ -62,3 +62,51 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url st
 	}
 	return out, resp.StatusCode, nil
 }
+
+func (c *Client) getJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string) (map[string]any, error) {
+	body, status, err := c.getJSONWithStatus(ctx, doer, url, headers)
+	if err != nil {
+		return nil, err
+	}
+	if status == 0 {
+		return nil, errors.New("request failed")
+	}
+	return body, nil
+}
+
+func (c *Client) getJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string) (map[string]any, int, error) {
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	if err != nil {
+		return nil, 0, err
+	}
+	for k, v := range headers {
+		req.Header.Set(k, v)
+	}
+	resp, err := doer.Do(req)
+	if err != nil {
+		config.Logger.Warn("[deepseek] fingerprint GET request failed, fallback to std transport", "url", url, "error", err)
+		req2, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+		if reqErr != nil {
+			return nil, 0, err
+		}
+		for k, v := range headers {
+			req2.Header.Set(k, v)
+		}
+		resp, err = c.fallback.Do(req2)
+		if err != nil {
+			return nil, 0, err
+		}
+	}
+	defer resp.Body.Close()
+	payloadBytes, err := readResponseBody(resp)
+	if err != nil {
+		return nil, resp.StatusCode, err
+	}
+	out := map[string]any{}
+	if len(payloadBytes) > 0 {
+		if err := json.Unmarshal(payloadBytes, &out); err != nil {
+			config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes))
+		}
+	}
+	return out, resp.StatusCode, nil
+}
256	internal/deepseek/client_session.go	Normal file
@@ -0,0 +1,256 @@
+package deepseek
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/url"
+	"strings"
+
+	"ds2api/internal/auth"
+	"ds2api/internal/config"
+)
+
+// SessionInfo describes a chat session.
+type SessionInfo struct {
+	ID        string  `json:"id"`
+	Title     string  `json:"title"`
+	TitleType string  `json:"title_type"`
+	Pinned    bool    `json:"pinned"`
+	UpdatedAt float64 `json:"updated_at"`
+}
+
+// SessionStats holds the session-count result for one account.
+type SessionStats struct {
+	AccountID      string // account identifier (email or mobile)
+	FirstPageCount int    // sessions on the first page (the real total may be larger when HasMore is true)
+	PinnedCount    int    // number of pinned sessions
+	HasMore        bool   // whether more pages exist
+	Success        bool   // whether the request succeeded
+	ErrorMessage   string // error message
+}
+
+// GetSessionCount fetches the session count for a single account.
+func (c *Client) GetSessionCount(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (*SessionStats, error) {
+	if maxAttempts <= 0 {
+		maxAttempts = c.maxRetries
+	}
+
+	stats := &SessionStats{
+		AccountID: a.AccountID,
+	}
+
+	attempts := 0
+	refreshed := false
+
+	for attempts < maxAttempts {
+		headers := c.authHeaders(a.DeepSeekToken)
+
+		// Build the request URL.
+		reqURL := DeepSeekFetchSessionURL + "?lte_cursor.pinned=false"
+
+		resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
+		if err != nil {
+			config.Logger.Warn("[get_session_count] request error", "error", err, "account", a.AccountID)
+			attempts++
+			continue
+		}
+
+		code, bizCode, msg, bizMsg := extractResponseStatus(resp)
+		if status == http.StatusOK && code == 0 && bizCode == 0 {
+			data, _ := resp["data"].(map[string]any)
+			bizData, _ := data["biz_data"].(map[string]any)
+			chatSessions, _ := bizData["chat_sessions"].([]any)
+			hasMore, _ := bizData["has_more"].(bool)
+
+			stats.FirstPageCount = len(chatSessions)
+			stats.HasMore = hasMore
+			stats.Success = true
+
+			// Count pinned sessions.
+			for _, session := range chatSessions {
+				if s, ok := session.(map[string]any); ok {
+					if pinned, ok := s["pinned"].(bool); ok && pinned {
+						stats.PinnedCount++
+					}
+				}
+			}
+
+			return stats, nil
+		}
+
+		stats.ErrorMessage = fmt.Sprintf("status=%d, code=%d, msg=%s", status, code, msg)
+		config.Logger.Warn("[get_session_count] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "account", a.AccountID)
+
+		if a.UseConfigToken {
+			if isTokenInvalid(status, code, bizCode, msg, bizMsg) && !refreshed {
+				if c.Auth.RefreshToken(ctx, a) {
+					refreshed = true
+					continue
+				}
+			}
+			if c.Auth.SwitchAccount(ctx, a) {
+				refreshed = false
+				attempts++
+				continue
+			}
+		}
+		attempts++
+	}
+
+	stats.Success = false
+	stats.ErrorMessage = "get session count failed after retries"
+	return stats, errors.New(stats.ErrorMessage)
+}
+
+// GetSessionCountForToken fetches the session count directly with a token (pass-through mode).
+func (c *Client) GetSessionCountForToken(ctx context.Context, token string) (*SessionStats, error) {
+	headers := c.authHeaders(token)
+	reqURL := DeepSeekFetchSessionURL + "?lte_cursor.pinned=false"
+
+	resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
+	if err != nil {
+		return nil, err
+	}
+
+	code, bizCode, msg, bizMsg := extractResponseStatus(resp)
+	if status != http.StatusOK || code != 0 || bizCode != 0 {
+		if strings.TrimSpace(bizMsg) != "" {
+			msg = bizMsg
+		}
+		return nil, fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
+	}
+
+	data, _ := resp["data"].(map[string]any)
+	bizData, _ := data["biz_data"].(map[string]any)
+	chatSessions, _ := bizData["chat_sessions"].([]any)
+	hasMore, _ := bizData["has_more"].(bool)
+
+	stats := &SessionStats{
+		FirstPageCount: len(chatSessions),
+		HasMore:        hasMore,
+		Success:        true,
+	}
+
+	// Count pinned sessions.
+	for _, session := range chatSessions {
+		if s, ok := session.(map[string]any); ok {
+			if pinned, ok := s["pinned"].(bool); ok && pinned {
+				stats.PinnedCount++
+			}
+		}
+	}
+
+	return stats, nil
+}
+
+// GetSessionCountAll collects session-count statistics for all accounts.
+func (c *Client) GetSessionCountAll(ctx context.Context) []*SessionStats {
+	accounts := c.Store.Accounts()
+	results := make([]*SessionStats, 0, len(accounts))
+
+	for _, acc := range accounts {
+		token := acc.Token
+		accountID := acc.Email
+		if accountID == "" {
+			accountID = acc.Mobile
+		}
+
+		// If there is no token yet, try logging in to obtain one.
+		if token == "" {
+			var err error
+			token, err = c.Login(ctx, acc)
+			if err != nil {
+				results = append(results, &SessionStats{
+					AccountID:    accountID,
+					Success:      false,
+					ErrorMessage: fmt.Sprintf("login failed: %v", err),
+				})
+				continue
+			}
+		}
+
+		stats, err := c.GetSessionCountForToken(ctx, token)
+		if err != nil {
+			results = append(results, &SessionStats{
+				AccountID:    accountID,
+				Success:      false,
+				ErrorMessage: err.Error(),
+			})
+			continue
+		}
+
+		stats.AccountID = accountID
+		results = append(results, stats)
+	}
+
+	return results
+}
+
+// FetchSessionPage fetches the session list (paginated).
+func (c *Client) FetchSessionPage(ctx context.Context, a *auth.RequestAuth, cursor string) ([]SessionInfo, bool, error) {
+	headers := c.authHeaders(a.DeepSeekToken)
+
+	// Build the request URL.
+	params := url.Values{}
+	params.Set("lte_cursor.pinned", "false")
+	if cursor != "" {
+		params.Set("lte_cursor", cursor)
+	}
+	reqURL := DeepSeekFetchSessionURL + "?" + params.Encode()
+
+	resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
+	if err != nil {
+		return nil, false, err
+	}
+
+	code := intFrom(resp["code"])
+	if status != http.StatusOK || code != 0 {
+		msg, _ := resp["msg"].(string)
+		return nil, false, fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
+	}
+
+	data, _ := resp["data"].(map[string]any)
+	bizData, _ := data["biz_data"].(map[string]any)
+	chatSessions, _ := bizData["chat_sessions"].([]any)
+	hasMore, _ := bizData["has_more"].(bool)
+
+	sessions := make([]SessionInfo, 0, len(chatSessions))
+	for _, s := range chatSessions {
+		if m, ok := s.(map[string]any); ok {
+			session := SessionInfo{
+				ID:        stringFromMap(m, "id"),
+				Title:     stringFromMap(m, "title"),
+				TitleType: stringFromMap(m, "title_type"),
+				Pinned:    boolFromMap(m, "pinned"),
+				UpdatedAt: floatFromMap(m, "updated_at"),
+			}
+			sessions = append(sessions, session)
+		}
+	}
+
+	return sessions, hasMore, nil
+}
+
+// Helper functions.
+func stringFromMap(m map[string]any, key string) string {
+	if v, ok := m[key].(string); ok {
+		return v
+	}
+	return ""
+}
+
+func boolFromMap(m map[string]any, key string) bool {
+	if v, ok := m[key].(bool); ok {
+		return v
+	}
+	return false
+}
+
+func floatFromMap(m map[string]any, key string) float64 {
+	if v, ok := m[key].(float64); ok {
+		return v
+	}
+	return 0
+}
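The pinned-session count in `GetSessionCount` and `GetSessionCountForToken` is computed by walking the decoded `chat_sessions` array defensively, since each element is an untyped JSON object. A standalone sketch of that loop (the function name `countPinnedSketch` is illustrative):

```go
package main

import "fmt"

// countPinnedSketch mirrors the loop over data.biz_data.chat_sessions:
// each element is a decoded JSON object, and only entries whose "pinned"
// field is the boolean true are counted. Entries of the wrong type or
// without the field are skipped rather than treated as errors.
func countPinnedSketch(chatSessions []any) int {
	n := 0
	for _, session := range chatSessions {
		if s, ok := session.(map[string]any); ok {
			if pinned, ok := s["pinned"].(bool); ok && pinned {
				n++
			}
		}
	}
	return n
}

func main() {
	sessions := []any{
		map[string]any{"id": "a", "pinned": true},
		map[string]any{"id": "b", "pinned": false},
		map[string]any{"id": "c"}, // no pinned field
	}
	fmt.Println(countPinnedSketch(sessions)) // prints 1
}
```

The two-step type assertion (`map[string]any`, then `bool`) is what makes the loop safe against missing or oddly-typed fields in the upstream payload.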
155	internal/deepseek/client_session_delete.go	Normal file
@@ -0,0 +1,155 @@
+package deepseek
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+
+	"ds2api/internal/auth"
+	"ds2api/internal/config"
+)
+
+// DeleteSessionResult describes the outcome of deleting a session.
+type DeleteSessionResult struct {
+	SessionID    string // session ID
+	Success      bool   // whether deletion succeeded
+	ErrorMessage string // error message
+}
+
+// DeleteSession deletes a single session.
+func (c *Client) DeleteSession(ctx context.Context, a *auth.RequestAuth, sessionID string, maxAttempts int) (*DeleteSessionResult, error) {
+	if maxAttempts <= 0 {
+		maxAttempts = c.maxRetries
+	}
+
+	result := &DeleteSessionResult{
+		SessionID: sessionID,
+	}
+
+	if sessionID == "" {
+		result.ErrorMessage = "session_id is required"
+		return result, errors.New(result.ErrorMessage)
+	}
+
+	attempts := 0
+	refreshed := false
+
+	for attempts < maxAttempts {
+		headers := c.authHeaders(a.DeepSeekToken)
+
+		payload := map[string]any{
+			"chat_session_id": sessionID,
+		}
+
+		resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteSessionURL, headers, payload)
+		if err != nil {
+			config.Logger.Warn("[delete_session] request error", "error", err, "session_id", sessionID)
+			attempts++
+			continue
+		}
+
+		code, bizCode, msg, bizMsg := extractResponseStatus(resp)
+		if status == http.StatusOK && code == 0 && bizCode == 0 {
+			result.Success = true
+			return result, nil
+		}
+
+		result.ErrorMessage = fmt.Sprintf("status=%d, code=%d, msg=%s", status, code, msg)
+		config.Logger.Warn("[delete_session] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "session_id", sessionID)
+
+		if a.UseConfigToken {
+			if isTokenInvalid(status, code, bizCode, msg, bizMsg) && !refreshed {
+				if c.Auth.RefreshToken(ctx, a) {
+					refreshed = true
+					continue
+				}
+			}
+			if c.Auth.SwitchAccount(ctx, a) {
+				refreshed = false
+				attempts++
+				continue
+			}
+		}
+		attempts++
+	}
+
+	result.Success = false
+	result.ErrorMessage = "delete session failed after retries"
+	return result, errors.New(result.ErrorMessage)
+}
+
+// DeleteSessionForToken deletes a session directly with a token (pass-through mode).
+func (c *Client) DeleteSessionForToken(ctx context.Context, token string, sessionID string) (*DeleteSessionResult, error) {
+	result := &DeleteSessionResult{
+		SessionID: sessionID,
+	}
+
+	if sessionID == "" {
+		result.ErrorMessage = "session_id is required"
+		return result, errors.New(result.ErrorMessage)
+	}
+
+	headers := c.authHeaders(token)
+	payload := map[string]any{
+		"chat_session_id": sessionID,
+	}
+
+	resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteSessionURL, headers, payload)
+	if err != nil {
+		result.ErrorMessage = err.Error()
+		return result, err
+	}
+
+	code := intFrom(resp["code"])
+	if status != http.StatusOK || code != 0 {
+		msg, _ := resp["msg"].(string)
+		result.ErrorMessage = fmt.Sprintf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
+		return result, errors.New(result.ErrorMessage)
+	}
+
+	result.Success = true
+	return result, nil
+}
+
+// DeleteAllSessions deletes all sessions (use with caution).
+func (c *Client) DeleteAllSessions(ctx context.Context, a *auth.RequestAuth) error {
+	headers := c.authHeaders(a.DeepSeekToken)
+	payload := map[string]any{}
+
+	resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
+	if err != nil {
+		config.Logger.Warn("[delete_all_sessions] request error", "error", err)
+		return err
+	}
+
+	code := intFrom(resp["code"])
+	if status != http.StatusOK || code != 0 {
+		msg, _ := resp["msg"].(string)
+		config.Logger.Warn("[delete_all_sessions] failed", "status", status, "code", code, "msg", msg)
+		return fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
+	}
+
+	return nil
+}
+
+// DeleteAllSessionsForToken deletes all sessions directly with a token (pass-through mode).
+func (c *Client) DeleteAllSessionsForToken(ctx context.Context, token string) error {
+	headers := c.authHeaders(token)
+	payload := map[string]any{}
+
+	resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
+	if err != nil {
+		config.Logger.Warn("[delete_all_sessions_for_token] request error", "error", err)
+		return err
+	}
+
+	code := intFrom(resp["code"])
+	if status != http.StatusOK || code != 0 {
+		msg, _ := resp["msg"].(string)
+		config.Logger.Warn("[delete_all_sessions_for_token] failed", "status", status, "code", code, "msg", msg)
+		return fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
+	}
+
+	return nil
+}
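`CreateSession`, `GetPow`, `GetSessionCount`, and `DeleteSession` all share the same retry policy: refresh the token at most once per account on a token-invalid error, then switch account (which re-arms the refresh), with every switch or plain failure consuming an attempt. A standalone sketch of that state machine with stand-in callbacks (the name `retrySketch` and the closure parameters are illustrative, and the `UseConfigToken` gate is omitted for brevity):

```go
package main

import "fmt"

// retrySketch models the shared loop: a successful refresh retries without
// consuming an attempt; an account switch resets the refreshed flag and
// consumes one; any other failure also consumes one.
func retrySketch(maxAttempts int, call, tokenInvalid, refresh, switchAccount func() bool) bool {
	attempts := 0
	refreshed := false
	for attempts < maxAttempts {
		if call() {
			return true
		}
		if tokenInvalid() {
			if !refreshed && refresh() {
				refreshed = true
				continue // retry without consuming an attempt
			}
			if switchAccount() {
				refreshed = false
				attempts++
				continue
			}
		}
		attempts++
	}
	return false
}

func main() {
	calls := 0
	ok := retrySketch(3,
		func() bool { calls++; return calls == 2 }, // succeed on 2nd call
		func() bool { return true },                // every failure looks token-invalid
		func() bool { return true },                // refresh succeeds
		func() bool { return false },               // no other account available
	)
	fmt.Println(ok, calls) // prints true 2
}
```

The `refreshed` flag is the key detail the diff adds to `GetPow`: without it, a token that keeps failing after a "successful" refresh would loop forever, since the refresh branch does not increment `attempts`.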
@@ -11,6 +11,9 @@ const (
 	DeepSeekCreateSessionURL = "https://chat.deepseek.com/api/v0/chat_session/create"
 	DeepSeekCreatePowURL     = "https://chat.deepseek.com/api/v0/chat/create_pow_challenge"
 	DeepSeekCompletionURL    = "https://chat.deepseek.com/api/v0/chat/completion"
+	DeepSeekFetchSessionURL      = "https://chat.deepseek.com/api/v0/chat_session/fetch_page"
+	DeepSeekDeleteSessionURL     = "https://chat.deepseek.com/api/v0/chat_session/delete"
+	DeepSeekDeleteAllSessionsURL = "https://chat.deepseek.com/api/v0/chat_session/delete_all"
 )
 
 var defaultBaseHeaders = map[string]string{
@@ -8,15 +8,15 @@ import (
 )
 
 func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
-	detected := util.ParseToolCalls(finalText, toolNames)
+	detected := util.ParseStandaloneToolCallsDetailed(finalText, toolNames)
 	finishReason := "stop"
 	messageObj := map[string]any{"role": "assistant", "content": finalText}
 	if strings.TrimSpace(finalThinking) != "" {
 		messageObj["reasoning_content"] = finalThinking
 	}
-	if len(detected) > 0 {
+	if len(detected.Calls) > 0 {
 		finishReason = "tool_calls"
-		messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
+		messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected.Calls)
 		messageObj["content"] = nil
 	}
 
@@ -11,17 +11,14 @@ import (
 )
 
 func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
-	// Align responses tool-call semantics with chat/completions:
-	// mixed prose + tool_call payloads should still be interpreted as tool calls.
-	detected := util.ParseToolCalls(finalText, toolNames)
-	if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
-		detected = util.ParseToolCalls(finalThinking, toolNames)
-	}
+	// Strict mode: only standalone, structured tool-call payloads are treated
+	// as executable tool calls.
+	detected := util.ParseStandaloneToolCallsDetailed(finalText, toolNames)
 	exposedOutputText := finalText
 	output := make([]any, 0, 2)
-	if len(detected) > 0 {
+	if len(detected.Calls) > 0 {
 		exposedOutputText = ""
-		output = append(output, toResponsesFunctionCallItems(detected)...)
+		output = append(output, toResponsesFunctionCallItems(detected.Calls)...)
 	} else {
 		content := make([]any, 0, 2)
 		if finalThinking != "" {
|||||||
@@ -71,6 +71,19 @@ func BuildResponsesTextDeltaPayload(responseID, itemID string, outputIndex, cont
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func BuildResponsesTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"type": "response.output_text.done",
|
||||||
|
"id": responseID,
|
||||||
|
"response_id": responseID,
|
||||||
|
"item_id": itemID,
|
||||||
|
"output_index": outputIndex,
|
||||||
|
"content_index": contentIndex,
|
||||||
|
"text": text,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
|
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"type": "response.reasoning.delta",
|
"type": "response.reasoning.delta",
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) {
|
func TestBuildResponseObjectPromotesMixedProseToolPayloadToFunctionCall(t *testing.T) {
|
||||||
obj := BuildResponseObject(
|
obj := BuildResponseObject(
|
||||||
"resp_test",
|
"resp_test",
|
||||||
"gpt-4o",
|
"gpt-4o",
|
||||||
@@ -57,20 +57,19 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T)
|
|||||||
|
|
||||||
outputText, _ := obj["output_text"].(string)
|
outputText, _ := obj["output_text"].(string)
|
||||||
if outputText != "" {
|
if outputText != "" {
|
||||||
t.Fatalf("expected output_text hidden once tool calls are detected, got %q", outputText)
|
t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
|
||||||
}
|
}
|
||||||
|
|
||||||
output, _ := obj["output"].([]any)
|
output, _ := obj["output"].([]any)
|
||||||
if len(output) != 1 {
|
if len(output) != 1 {
|
||||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
t.Fatalf("expected one function_call output item, got %#v", obj["output"])
|
||||||
}
|
}
|
||||||
first, _ := output[0].(map[string]any)
|
first, _ := output[0].(map[string]any)
|
||||||
if first["type"] != "function_call" {
|
if first["type"] != "function_call" {
|
||||||
t.Fatalf("expected first output type function_call, got %#v", first["type"])
|
t.Fatalf("expected function_call output type, got %#v", first["type"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildResponseObjectFencedToolPayloadRemainsText(t *testing.T) {
|
func TestBuildResponseObjectPromotesFencedToolPayloadToFunctionCall(t *testing.T) {
|
||||||
obj := BuildResponseObject(
|
obj := BuildResponseObject(
|
||||||
"resp_test",
|
"resp_test",
|
||||||
"gpt-4o",
|
"gpt-4o",
|
||||||
@@ -81,16 +80,16 @@ func TestBuildResponseObjectFencedToolPayloadRemainsText(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
outputText, _ := obj["output_text"].(string)
|
outputText, _ := obj["output_text"].(string)
|
||||||
if outputText == "" {
|
if outputText != "" {
|
||||||
t.Fatalf("expected output_text preserved for fenced example")
|
t.Fatalf("expected output_text hidden for fenced tool payload, got %q", outputText)
|
||||||
}
|
}
|
||||||
output, _ := obj["output"].([]any)
|
output, _ := obj["output"].([]any)
|
||||||
if len(output) != 1 {
|
if len(output) != 1 {
|
||||||
t.Fatalf("expected one message output item, got %#v", obj["output"])
|
t.Fatalf("expected one function_call output item, got %#v", obj["output"])
|
||||||
}
|
}
|
||||||
first, _ := output[0].(map[string]any)
|
first, _ := output[0].(map[string]any)
|
||||||
if first["type"] != "message" {
|
if first["type"] != "function_call" {
|
||||||
t.Fatalf("expected message output type, got %#v", first["type"])
|
t.Fatalf("expected function_call output type, got %#v", first["type"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,7 +126,7 @@ func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
func TestBuildResponseObjectIgnoresToolCallFromThinkingChannel(t *testing.T) {
|
||||||
obj := BuildResponseObject(
|
obj := BuildResponseObject(
|
||||||
"resp_test",
|
"resp_test",
|
||||||
"gpt-4o",
|
"gpt-4o",
|
||||||
@@ -139,10 +138,10 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
|||||||
|
|
||||||
output, _ := obj["output"].([]any)
|
output, _ := obj["output"].([]any)
|
||||||
if len(output) != 1 {
|
if len(output) != 1 {
|
||||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
t.Fatalf("expected one message output item, got %#v", obj["output"])
|
||||||
}
|
}
|
||||||
first, _ := output[0].(map[string]any)
|
first, _ := output[0].(map[string]any)
|
||||||
if first["type"] != "function_call" {
|
if first["type"] != "message" {
|
||||||
t.Fatalf("expected output function_call, got %#v", first["type"])
|
t.Fatalf("expected output message, got %#v", first["type"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ const {
|
|||||||
} = require('./sse_parse');
|
} = require('./sse_parse');
|
||||||
const {
|
const {
|
||||||
resolveToolcallPolicy,
|
resolveToolcallPolicy,
|
||||||
|
formatIncrementalToolCallDeltas,
|
||||||
normalizePreparedToolNames,
|
normalizePreparedToolNames,
|
||||||
boolDefaultTrue,
|
boolDefaultTrue,
|
||||||
|
filterIncrementalToolCallDeltasByAllowed,
|
||||||
} = require('./toolcall_policy');
|
} = require('./toolcall_policy');
|
||||||
const {
|
const {
|
||||||
estimateTokens,
|
estimateTokens,
|
||||||
@@ -82,7 +84,9 @@ module.exports.__test = {
|
|||||||
shouldSkipPath,
|
shouldSkipPath,
|
||||||
asString,
|
asString,
|
||||||
resolveToolcallPolicy,
|
resolveToolcallPolicy,
|
||||||
|
formatIncrementalToolCallDeltas,
|
||||||
normalizePreparedToolNames,
|
normalizePreparedToolNames,
|
||||||
boolDefaultTrue,
|
boolDefaultTrue,
|
||||||
|
filterIncrementalToolCallDeltasByAllowed,
|
||||||
estimateTokens,
|
estimateTokens,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -68,6 +68,47 @@ function formatIncrementalToolCallDeltas(deltas, idStore) {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function filterIncrementalToolCallDeltasByAllowed(deltas, allowedNames, seenNames) {
|
||||||
|
if (!Array.isArray(deltas) || deltas.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
const seen = seenNames instanceof Map ? seenNames : new Map();
|
||||||
|
const allowed = new Set((allowedNames || []).filter((name) => asString(name) !== ''));
|
||||||
|
if (allowed.size === 0) {
|
||||||
|
for (const d of deltas) {
|
||||||
|
if (d && typeof d === 'object' && asString(d.name)) {
|
||||||
|
const index = Number.isInteger(d.index) ? d.index : 0;
|
||||||
|
seen.set(index, '__blocked__');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const out = [];
|
||||||
|
for (const d of deltas) {
|
||||||
|
if (!d || typeof d !== 'object') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const index = Number.isInteger(d.index) ? d.index : 0;
|
||||||
|
const name = asString(d.name);
|
||||||
|
if (name) {
|
||||||
|
if (!allowed.has(name)) {
|
||||||
|
seen.set(index, '__blocked__');
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
seen.set(index, name);
|
||||||
|
out.push(d);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const existing = asString(seen.get(index));
|
||||||
|
if (!existing || existing === '__blocked__') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
out.push(d);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
function ensureStreamToolCallID(idStore, index) {
|
function ensureStreamToolCallID(idStore, index) {
|
||||||
const key = Number.isInteger(index) ? index : 0;
|
const key = Number.isInteger(index) ? index : 0;
|
||||||
const existing = idStore.get(key);
|
const existing = idStore.get(key);
|
||||||
@@ -104,4 +145,5 @@ module.exports = {
|
|||||||
normalizePreparedToolNames,
|
normalizePreparedToolNames,
|
||||||
boolDefaultTrue,
|
boolDefaultTrue,
|
||||||
formatIncrementalToolCallDeltas,
|
formatIncrementalToolCallDeltas,
|
||||||
|
filterIncrementalToolCallDeltasByAllowed,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ const {
|
|||||||
createToolSieveState,
|
createToolSieveState,
|
||||||
processToolSieveChunk,
|
processToolSieveChunk,
|
||||||
flushToolSieve,
|
flushToolSieve,
|
||||||
parseToolCalls,
|
parseStandaloneToolCalls,
|
||||||
formatOpenAIStreamToolCalls,
|
formatOpenAIStreamToolCalls,
|
||||||
} = require('../helpers/stream-tool-sieve');
|
} = require('../helpers/stream-tool-sieve');
|
||||||
const {
|
const {
|
||||||
@@ -24,7 +24,6 @@ const {
|
|||||||
} = require('./token_usage');
|
} = require('./token_usage');
|
||||||
const {
|
const {
|
||||||
resolveToolcallPolicy,
|
resolveToolcallPolicy,
|
||||||
formatIncrementalToolCallDeltas,
|
|
||||||
} = require('./toolcall_policy');
|
} = require('./toolcall_policy');
|
||||||
const {
|
const {
|
||||||
createChatCompletionEmitter,
|
createChatCompletionEmitter,
|
||||||
@@ -130,7 +129,6 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
|||||||
let thinkingText = '';
|
let thinkingText = '';
|
||||||
let outputText = '';
|
let outputText = '';
|
||||||
const toolSieveEnabled = toolPolicy.toolSieveEnabled;
|
const toolSieveEnabled = toolPolicy.toolSieveEnabled;
|
||||||
const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
|
|
||||||
const toolSieveState = createToolSieveState();
|
const toolSieveState = createToolSieveState();
|
||||||
let toolCallsEmitted = false;
|
let toolCallsEmitted = false;
|
||||||
const streamToolCallIDs = new Map();
|
const streamToolCallIDs = new Map();
|
||||||
@@ -155,13 +153,18 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
|||||||
await releaseLease();
|
await releaseLease();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const detected = parseToolCalls(outputText, toolNames);
|
const detected = parseStandaloneToolCalls(outputText, toolNames);
|
||||||
if (detected.length > 0 && !toolCallsEmitted) {
|
if (detected.length > 0 && !toolCallsEmitted) {
|
||||||
toolCallsEmitted = true;
|
toolCallsEmitted = true;
|
||||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) });
|
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected, streamToolCallIDs) });
|
||||||
} else if (toolSieveEnabled) {
|
} else if (toolSieveEnabled) {
|
||||||
const tailEvents = flushToolSieve(toolSieveState, toolNames);
|
const tailEvents = flushToolSieve(toolSieveState, toolNames);
|
||||||
for (const evt of tailEvents) {
|
for (const evt of tailEvents) {
|
||||||
|
if (evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0) {
|
||||||
|
toolCallsEmitted = true;
|
||||||
|
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (evt.text) {
|
if (evt.text) {
|
||||||
sendDeltaFrame({ content: evt.text });
|
sendDeltaFrame({ content: evt.text });
|
||||||
}
|
}
|
||||||
@@ -252,17 +255,9 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
|||||||
}
|
}
|
||||||
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
|
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
|
||||||
for (const evt of events) {
|
for (const evt of events) {
|
||||||
if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) {
|
|
||||||
if (!emitEarlyToolDeltas) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
toolCallsEmitted = true;
|
|
||||||
sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (evt.type === 'tool_calls') {
|
if (evt.type === 'tool_calls') {
|
||||||
toolCallsEmitted = true;
|
toolCallsEmitted = true;
|
||||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) });
|
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (evt.text) {
|
if (evt.text) {
|
||||||
|
|||||||
@@ -2,13 +2,13 @@
|
|||||||
|
|
||||||
const crypto = require('crypto');
|
const crypto = require('crypto');
|
||||||
|
|
||||||
function formatOpenAIStreamToolCalls(calls) {
|
function formatOpenAIStreamToolCalls(calls, idStore) {
|
||||||
if (!Array.isArray(calls) || calls.length === 0) {
|
if (!Array.isArray(calls) || calls.length === 0) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
return calls.map((c, idx) => ({
|
return calls.map((c, idx) => ({
|
||||||
index: idx,
|
index: idx,
|
||||||
id: `call_${newCallID()}`,
|
id: ensureStreamToolCallID(idStore, idx),
|
||||||
type: 'function',
|
type: 'function',
|
||||||
function: {
|
function: {
|
||||||
name: c.name,
|
name: c.name,
|
||||||
@@ -17,6 +17,20 @@ function formatOpenAIStreamToolCalls(calls) {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function ensureStreamToolCallID(idStore, index) {
|
||||||
|
if (!(idStore instanceof Map)) {
|
||||||
|
return `call_${newCallID()}`;
|
||||||
|
}
|
||||||
|
const key = Number.isInteger(index) ? index : 0;
|
||||||
|
const existing = idStore.get(key);
|
||||||
|
if (existing) {
|
||||||
|
return existing;
|
||||||
|
}
|
||||||
|
const next = `call_${newCallID()}`;
|
||||||
|
idStore.set(key, next);
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
|
||||||
function newCallID() {
|
function newCallID() {
|
||||||
if (typeof crypto.randomUUID === 'function') {
|
if (typeof crypto.randomUUID === 'function') {
|
||||||
return crypto.randomUUID().replace(/-/g, '');
|
return crypto.randomUUID().replace(/-/g, '');
|
||||||
|
|||||||
@@ -1,226 +0,0 @@
-'use strict';
-
-const {
-  looksLikeToolExampleContext,
-  insideCodeFence,
-} = require('./state');
-const {
-  findObjectFieldValueStart,
-  parseJSONStringLiteral,
-  skipSpaces,
-} = require('./jsonscan');
-
-function buildIncrementalToolDeltas(state) {
-  const captured = state.capture || '';
-  if (!captured) {
-    return [];
-  }
-  if (looksLikeToolExampleContext(state.recentTextTail)) {
-    return [];
-  }
-  const lower = captured.toLowerCase();
-  const keyIdx = lower.indexOf('tool_calls');
-  if (keyIdx < 0) {
-    return [];
-  }
-  const start = captured.slice(0, keyIdx).lastIndexOf('{');
-  if (start < 0) {
-    return [];
-  }
-  if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) {
-    return [];
-  }
-  const callStart = findFirstToolCallObjectStart(captured, keyIdx);
-  if (callStart < 0) {
-    return [];
-  }
-
-  const deltas = [];
-  if (!state.toolName) {
-    const name = extractToolCallName(captured, callStart);
-    if (!name) {
-      return [];
-    }
-    state.toolName = name;
-  }
-
-  if (state.toolArgsStart < 0) {
-    const args = findToolCallArgsStart(captured, callStart);
-    if (args) {
-      state.toolArgsString = Boolean(args.stringMode);
-      state.toolArgsStart = state.toolArgsString ? args.start + 1 : args.start;
-      state.toolArgsSent = state.toolArgsStart;
-    }
-  }
-  if (!state.toolNameSent) {
-    if (state.toolArgsStart < 0) {
-      return [];
-    }
-    state.toolNameSent = true;
-    deltas.push({ index: 0, name: state.toolName });
-  }
-  if (state.toolArgsStart < 0 || state.toolArgsDone) {
-    return deltas;
-  }
-  const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString);
-  if (!progress) {
-    return deltas;
-  }
-  if (progress.end > state.toolArgsSent) {
-    deltas.push({
-      index: 0,
-      arguments: captured.slice(state.toolArgsSent, progress.end),
-    });
-    state.toolArgsSent = progress.end;
-  }
-  if (progress.complete) {
-    state.toolArgsDone = true;
-  }
-  return deltas;
-}
-
-function findFirstToolCallObjectStart(text, keyIdx) {
-  const arrStart = findToolCallsArrayStart(text, keyIdx);
-  if (arrStart < 0) {
-    return -1;
-  }
-  const i = skipSpaces(text, arrStart + 1);
-  if (i >= text.length || text[i] !== '{') {
-    return -1;
-  }
-  return i;
-}
-
-function findToolCallsArrayStart(text, keyIdx) {
-  let i = keyIdx + 'tool_calls'.length;
-  while (i < text.length && text[i] !== ':') {
-    i += 1;
-  }
-  if (i >= text.length) {
-    return -1;
-  }
-  i = skipSpaces(text, i + 1);
-  if (i >= text.length || text[i] !== '[') {
-    return -1;
-  }
-  return i;
-}
-
-function extractToolCallName(text, callStart) {
-  let valueStart = findObjectFieldValueStart(text, callStart, ['name']);
-  if (valueStart < 0 || text[valueStart] !== '"') {
-    const fnStart = findFunctionObjectStart(text, callStart);
-    if (fnStart < 0) {
-      return '';
-    }
-    valueStart = findObjectFieldValueStart(text, fnStart, ['name']);
-    if (valueStart < 0 || text[valueStart] !== '"') {
-      return '';
-    }
-  }
-  const parsed = parseJSONStringLiteral(text, valueStart);
-  if (!parsed) {
-    return '';
-  }
-  return parsed.value;
-}
-
-function findToolCallArgsStart(text, callStart) {
-  const keys = ['input', 'arguments', 'args', 'parameters', 'params'];
-  let valueStart = findObjectFieldValueStart(text, callStart, keys);
-  if (valueStart < 0) {
-    const fnStart = findFunctionObjectStart(text, callStart);
-    if (fnStart < 0) {
-      return null;
-    }
-    valueStart = findObjectFieldValueStart(text, fnStart, keys);
-    if (valueStart < 0) {
-      return null;
-    }
-  }
-  if (valueStart >= text.length) {
-    return null;
-  }
-  const ch = text[valueStart];
-  if (ch === '{' || ch === '[') {
-    return { start: valueStart, stringMode: false };
-  }
-  if (ch === '"') {
-    return { start: valueStart, stringMode: true };
-  }
-  return null;
-}
-
-function scanToolCallArgsProgress(text, start, stringMode) {
-  if (start < 0 || start > text.length) {
-    return null;
-  }
-  if (stringMode) {
-    let escaped = false;
-    for (let i = start; i < text.length; i += 1) {
-      const ch = text[i];
-      if (escaped) {
-        escaped = false;
-        continue;
-      }
-      if (ch === '\\') {
-        escaped = true;
-        continue;
-      }
-      if (ch === '"') {
-        return { end: i, complete: true };
-      }
-    }
-    return { end: text.length, complete: false };
-  }
-  if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) {
-    return null;
-  }
-  let depth = 0;
-  let quote = '';
-  let escaped = false;
-  for (let i = start; i < text.length; i += 1) {
-    const ch = text[i];
-    if (quote) {
-      if (escaped) {
-        escaped = false;
-        continue;
-      }
-      if (ch === '\\') {
-        escaped = true;
-        continue;
-      }
-      if (ch === quote) {
-        quote = '';
-      }
-      continue;
-    }
-    if (ch === '"' || ch === "'") {
-      quote = ch;
-      continue;
-    }
-    if (ch === '{' || ch === '[') {
-      depth += 1;
-      continue;
-    }
-    if (ch === '}' || ch === ']') {
-      depth -= 1;
-      if (depth === 0) {
-        return { end: i + 1, complete: true };
-      }
-    }
-  }
-  return { end: text.length, complete: false };
-}
-
-function findFunctionObjectStart(text, callStart) {
-  const valueStart = findObjectFieldValueStart(text, callStart, ['function']);
-  if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') {
-    return -1;
-  }
-  return valueStart;
-}
-
-module.exports = {
-  buildIncrementalToolDeltas,
-};
@@ -10,7 +10,9 @@ const {
 const {
   extractToolNames,
   parseToolCalls,
+  parseToolCallsDetailed,
   parseStandaloneToolCalls,
+  parseStandaloneToolCallsDetailed,
 } = require('./parse');
 const {
   formatOpenAIStreamToolCalls,
@@ -22,6 +24,8 @@ module.exports = {
   processToolSieveChunk,
   flushToolSieve,
   parseToolCalls,
+  parseToolCallsDetailed,
   parseStandaloneToolCalls,
+  parseStandaloneToolCallsDetailed,
   formatOpenAIStreamToolCalls,
 };
@@ -1,14 +1,16 @@
 'use strict';
 
-const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
-
 const {
   toStringSafe,
-  looksLikeToolExampleContext,
 } = require('./state');
 const {
-  extractJSONObjectFrom,
-} = require('./jsonscan');
+  buildToolCallCandidates,
+  parseToolCallsPayload,
+  parseMarkupToolCalls,
+  parseTextKVToolCalls,
+} = require('./parse_payload');
+
+const TOOL_NAME_LOOSE_PATTERN = /[^a-z0-9]+/g;
 
 function extractToolNames(tools) {
   if (!Array.isArray(tools) || tools.length === 0) {
@@ -29,245 +31,206 @@ function extractToolNames(tools) {
   }
 }
 
 function parseToolCalls(text, toolNames) {
-  if (!toStringSafe(text)) {
-    return [];
+  return parseToolCallsDetailed(text, toolNames).calls;
+}
+
+function parseToolCallsDetailed(text, toolNames) {
+  const result = emptyParseResult();
+  const normalized = toStringSafe(text);
+  if (!normalized) {
+    return result;
   }
-  const sanitized = stripFencedCodeBlocks(text);
-  if (!toStringSafe(sanitized)) {
-    return [];
-  }
-  const candidates = buildToolCallCandidates(sanitized);
+  result.sawToolCallSyntax = looksLikeToolCallSyntax(normalized);
+  const candidates = buildToolCallCandidates(normalized);
   let parsed = [];
   for (const c of candidates) {
     parsed = parseToolCallsPayload(c);
+    if (parsed.length === 0) {
+      parsed = parseMarkupToolCalls(c);
+    }
+    if (parsed.length === 0) {
+      parsed = parseTextKVToolCalls(c);
+    }
+    if (parsed.length > 0) {
+      result.sawToolCallSyntax = true;
+      break;
+    }
+  }
+  if (parsed.length === 0) {
+    parsed = parseMarkupToolCalls(normalized);
+    if (parsed.length === 0) {
+      parsed = parseTextKVToolCalls(normalized);
+      if (parsed.length === 0) {
+        return result;
+      }
+    }
+    result.sawToolCallSyntax = true;
+  }
+
+  const filtered = filterToolCallsDetailed(parsed, toolNames);
+  result.calls = filtered.calls;
+  result.rejectedToolNames = filtered.rejectedToolNames;
+  result.rejectedByPolicy = filtered.rejectedToolNames.length > 0 && filtered.calls.length === 0;
+  return result;
+}
+
+function parseStandaloneToolCalls(text, toolNames) {
+  return parseStandaloneToolCallsDetailed(text, toolNames).calls;
+}
+
+function parseStandaloneToolCallsDetailed(text, toolNames) {
+  const result = emptyParseResult();
+  const trimmed = toStringSafe(text);
+  if (!trimmed) {
+    return result;
+  }
+  result.sawToolCallSyntax = looksLikeToolCallSyntax(trimmed);
+  const candidates = buildToolCallCandidates(trimmed);
+  let parsed = [];
+  for (const c of candidates) {
+    parsed = parseToolCallsPayload(c);
+    if (parsed.length === 0) {
+      parsed = parseMarkupToolCalls(c);
+    }
+    if (parsed.length === 0) {
+      parsed = parseTextKVToolCalls(c);
+    }
     if (parsed.length > 0) {
       break;
     }
   }
   if (parsed.length === 0) {
-    return [];
-  }
-  return filterToolCalls(parsed, toolNames);
-}
-
-function stripFencedCodeBlocks(text) {
-  const t = typeof text === 'string' ? text : '';
-  if (!t) {
-    return '';
-  }
-  return t.replace(/```[\s\S]*?```/g, ' ');
-}
-
-function parseStandaloneToolCalls(text, toolNames) {
-  const trimmed = toStringSafe(text);
-  if (!trimmed) {
-    return [];
-  }
-  if ((trimmed.startsWith('```') && trimmed.endsWith('```')) || trimmed.includes('```')) {
-    return [];
-  }
-  if (looksLikeToolExampleContext(trimmed)) {
-    return [];
-  }
-  const candidates = [trimmed];
-  if (trimmed.startsWith('```') && trimmed.endsWith('```')) {
-    const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
-    if (m && m[1]) {
-      candidates.push(toStringSafe(m[1]));
-    }
-  }
-  for (const candidate of candidates) {
-    const c = toStringSafe(candidate);
-    if (!c) {
-      continue;
-    }
-    if (!c.startsWith('{') && !c.startsWith('[')) {
-      continue;
-    }
-    const parsed = parseToolCallsPayload(c);
-    if (parsed.length > 0) {
-      return filterToolCalls(parsed, toolNames);
-    }
-  }
-  return [];
-}
-
-function buildToolCallCandidates(text) {
-  const trimmed = toStringSafe(text);
-  const candidates = [trimmed];
-  const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
-  for (const block of fenced) {
-    const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
-    if (m && m[1]) {
-      candidates.push(toStringSafe(m[1]));
-    }
-  }
-  for (const candidate of extractToolCallObjects(trimmed)) {
-    candidates.push(toStringSafe(candidate));
-  }
-  const first = trimmed.indexOf('{');
-  const last = trimmed.lastIndexOf('}');
-  if (first >= 0 && last > first) {
-    candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
-  }
-  const m = trimmed.match(TOOL_CALL_PATTERN);
-  if (m && m[1]) {
-    candidates.push(`{"tool_calls":[${m[1]}]}`);
-  }
-  return [...new Set(candidates.filter(Boolean))];
-}
-
-function extractToolCallObjects(text) {
-  const raw = toStringSafe(text);
-  if (!raw) {
-    return [];
-  }
-  const lower = raw.toLowerCase();
-  const out = [];
-  let offset = 0;
-  // eslint-disable-next-line no-constant-condition
-  while (true) {
-    let idx = lower.indexOf('tool_calls', offset);
-    if (idx < 0) {
-      break;
-    }
-    let start = raw.slice(0, idx).lastIndexOf('{');
-    while (start >= 0) {
-      const obj = extractJSONObjectFrom(raw, start);
-      if (obj.ok) {
-        out.push(raw.slice(start, obj.end).trim());
-        offset = obj.end;
-        idx = -1;
-        break;
-      }
-      start = raw.slice(0, start).lastIndexOf('{');
-    }
-    if (idx >= 0) {
-      offset = idx + 'tool_calls'.length;
-    }
-  }
-  return out;
-}
-
-function parseToolCallsPayload(payload) {
-  let decoded;
-  try {
-    decoded = JSON.parse(payload);
-  } catch (_err) {
-    return [];
-  }
-  if (Array.isArray(decoded)) {
-    return parseToolCallList(decoded);
-  }
-  if (!decoded || typeof decoded !== 'object') {
-    return [];
-  }
-  if (decoded.tool_calls) {
-    return parseToolCallList(decoded.tool_calls);
-  }
-  const one = parseToolCallItem(decoded);
-  return one ? [one] : [];
-}
-
-function parseToolCallList(v) {
-  if (!Array.isArray(v)) {
-    return [];
-  }
-  const out = [];
-  for (const item of v) {
-    if (!item || typeof item !== 'object') {
-      continue;
-    }
-    const one = parseToolCallItem(item);
-    if (one) {
-      out.push(one);
-    }
-  }
-  return out;
-}
-
-function parseToolCallItem(m) {
-  let name = toStringSafe(m.name);
-  let inputRaw = m.input;
-  let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
-  const fn = m.function && typeof m.function === 'object' ? m.function : null;
-  if (fn) {
-    if (!name) {
-      name = toStringSafe(fn.name);
-    }
-    if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
-      inputRaw = fn.arguments;
-      hasInput = true;
-    }
-  }
-  if (!hasInput) {
-    for (const k of ['arguments', 'args', 'parameters', 'params']) {
-      if (Object.prototype.hasOwnProperty.call(m, k)) {
-        inputRaw = m[k];
-        hasInput = true;
-        break;
+    parsed = parseMarkupToolCalls(trimmed);
+    if (parsed.length === 0) {
+      parsed = parseTextKVToolCalls(trimmed);
+      if (parsed.length === 0) {
+        return result;
       }
     }
   }
-  if (!name) {
-    return null;
-  }
+    result.sawToolCallSyntax = true;
+  const filtered = filterToolCallsDetailed(parsed, toolNames);
+  result.calls = filtered.calls;
+  result.rejectedToolNames = filtered.rejectedToolNames;
+  result.rejectedByPolicy = filtered.rejectedToolNames.length > 0 && filtered.calls.length === 0;
+  return result;
+}
+
+function emptyParseResult() {
   return {
-    name,
-    input: parseToolCallInput(inputRaw),
+    calls: [],
+    sawToolCallSyntax: false,
+    rejectedByPolicy: false,
+    rejectedToolNames: [],
   };
 }
 
-function parseToolCallInput(v) {
-  if (v == null) {
-    return {};
-  }
-  if (typeof v === 'string') {
-    const raw = toStringSafe(v);
-    if (!raw) {
-      return {};
+function filterToolCallsDetailed(parsed, toolNames) {
+  const sourceNames = Array.isArray(toolNames) ? toolNames : [];
+  const allowed = new Set();
+  const allowedCanonical = new Map();
+  for (const item of sourceNames) {
+    const name = toStringSafe(item);
+    if (!name) {
+      continue;
     }
-    try {
-      const parsed = JSON.parse(raw);
-      if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
-        return parsed;
-      }
-      return { _raw: raw };
-    } catch (_err) {
-      return { _raw: raw };
+    allowed.add(name);
+    const lower = name.toLowerCase();
+    if (!allowedCanonical.has(lower)) {
+      allowedCanonical.set(lower, name);
}
|
}
|
||||||
}
|
}
|
||||||
if (typeof v === 'object' && !Array.isArray(v)) {
|
|
||||||
return v;
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
const parsed = JSON.parse(JSON.stringify(v));
|
|
||||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
|
||||||
return parsed;
|
|
||||||
}
|
|
||||||
} catch (_err) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
function filterToolCalls(parsed, toolNames) {
|
if (allowed.size === 0) {
|
||||||
const allowed = new Set((toolNames || []).filter(Boolean));
|
const rejected = [];
|
||||||
const out = [];
|
const seen = new Set();
|
||||||
|
for (const tc of parsed) {
|
||||||
|
if (!tc || !tc.name) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (seen.has(tc.name)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
seen.add(tc.name);
|
||||||
|
rejected.push(tc.name);
|
||||||
|
}
|
||||||
|
return { calls: [], rejectedToolNames: rejected };
|
||||||
|
}
|
||||||
|
|
||||||
|
const calls = [];
|
||||||
|
const rejected = [];
|
||||||
|
const seenRejected = new Set();
|
||||||
for (const tc of parsed) {
|
for (const tc of parsed) {
|
||||||
if (!tc || !tc.name) {
|
if (!tc || !tc.name) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (allowed.size > 0 && !allowed.has(tc.name)) {
|
let matchedName = '';
|
||||||
|
if (allowed.has(tc.name)) {
|
||||||
|
matchedName = tc.name;
|
||||||
|
} else {
|
||||||
|
matchedName = resolveAllowedToolName(tc.name, allowed, allowedCanonical);
|
||||||
|
}
|
||||||
|
if (!matchedName) {
|
||||||
|
if (!seenRejected.has(tc.name)) {
|
||||||
|
seenRejected.add(tc.name);
|
||||||
|
rejected.push(tc.name);
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
out.push({ name: tc.name, input: tc.input || {} });
|
calls.push({
|
||||||
|
name: matchedName,
|
||||||
|
input: tc.input && typeof tc.input === 'object' && !Array.isArray(tc.input) ? tc.input : {},
|
||||||
|
});
|
||||||
}
|
}
|
||||||
return out;
|
return { calls, rejectedToolNames: rejected };
|
||||||
|
}
|
||||||
|
|
||||||
|
function resolveAllowedToolName(name, allowed, allowedCanonical) {
|
||||||
|
const normalizedName = toStringSafe(name).trim();
|
||||||
|
if (!normalizedName) {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
if (allowed.has(normalizedName)) {
|
||||||
|
return normalizedName;
|
||||||
|
}
|
||||||
|
const lower = normalizedName.toLowerCase();
|
||||||
|
if (allowedCanonical.has(lower)) {
|
||||||
|
return allowedCanonical.get(lower);
|
||||||
|
}
|
||||||
|
const idx = lower.lastIndexOf('.');
|
||||||
|
if (idx >= 0 && idx < lower.length - 1) {
|
||||||
|
const tail = lower.slice(idx + 1);
|
||||||
|
if (allowedCanonical.has(tail)) {
|
||||||
|
return allowedCanonical.get(tail);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const loose = lower.replace(TOOL_NAME_LOOSE_PATTERN, '');
|
||||||
|
if (!loose) {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
for (const [candidateLower, canonical] of allowedCanonical.entries()) {
|
||||||
|
if (candidateLower.replace(TOOL_NAME_LOOSE_PATTERN, '') === loose) {
|
||||||
|
return canonical;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
|
||||||
|
function looksLikeToolCallSyntax(text) {
|
||||||
|
const lower = toStringSafe(text).toLowerCase();
|
||||||
|
return lower.includes('tool_calls')
|
||||||
|
|| lower.includes('<tool_call')
|
||||||
|
|| lower.includes('<function_call')
|
||||||
|
|| lower.includes('<invoke')
|
||||||
|
|| lower.includes('function.name:');
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
extractToolNames,
|
extractToolNames,
|
||||||
parseToolCalls,
|
parseToolCalls,
|
||||||
|
parseToolCallsDetailed,
|
||||||
parseStandaloneToolCalls,
|
parseStandaloneToolCalls,
|
||||||
|
parseStandaloneToolCallsDetailed,
|
||||||
};
|
};
|
||||||
|
|||||||
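The new `resolveAllowedToolName` above applies a matching ladder when a model emits a tool name that is not an exact match for a declared tool: exact, then case-insensitive, then the suffix after a dot, then a "loose" compare. A minimal standalone sketch of those rules follows; `TOOL_NAME_LOOSE_PATTERN` is not shown in this diff, so here it is assumed to strip every non-alphanumeric character, and `resolveName` is an illustrative name, not the module's export.

```javascript
// Assumption: the module's TOOL_NAME_LOOSE_PATTERN strips non-alphanumerics.
const LOOSE = /[^a-z0-9]/g;

function resolveName(name, allowedNames) {
  const allowed = new Set(allowedNames);
  // lowercase -> first canonical spelling
  const canonical = new Map(allowedNames.map((n) => [n.toLowerCase(), n]));
  const trimmed = String(name).trim();
  if (allowed.has(trimmed)) return trimmed;            // 1. exact
  const lower = trimmed.toLowerCase();
  if (canonical.has(lower)) return canonical.get(lower); // 2. case-insensitive
  const dot = lower.lastIndexOf('.');
  if (dot >= 0 && dot < lower.length - 1) {            // 3. dotted suffix
    const tail = lower.slice(dot + 1);
    if (canonical.has(tail)) return canonical.get(tail);
  }
  const loose = lower.replace(LOOSE, '');              // 4. loose compare
  if (!loose) return '';
  for (const [cand, orig] of canonical.entries()) {
    if (cand.replace(LOOSE, '') === loose) return orig;
  }
  return '';                                           // rejected
}

console.log(resolveName('GET_WEATHER', ['get_weather']));       // case-insensitive hit
console.log(resolveName('tools.get_weather', ['get_weather'])); // dotted-suffix hit
console.log(resolveName('get-weather', ['get_weather']));       // loose hit
console.log(resolveName('unknown_tool', ['get_weather']));      // '' -> rejected
```

An empty string result is what feeds `rejectedToolNames` in `filterToolCallsDetailed`.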
363	internal/js/helpers/stream-tool-sieve/parse_payload.js	Normal file
@@ -0,0 +1,363 @@
'use strict';

const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
const TOOL_CALL_MARKUP_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?(tool_call|function_call|invoke)\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
const TOOL_CALL_MARKUP_SELFCLOSE_PATTERN = /<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)\/>/gi;
const TOOL_CALL_MARKUP_KV_PATTERN = /<(?:[a-z0-9_:-]+:)?([a-z0-9_.-]+)\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
const TOOL_CALL_MARKUP_ATTR_PATTERN = /(name|function|tool)\s*=\s*"([^"]+)"/i;
const TOOL_CALL_MARKUP_NAME_PATTERNS = [
  /<(?:[a-z0-9_:-]+:)?name\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?name>/i,
  /<(?:[a-z0-9_:-]+:)?function\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?function>/i,
];
const TOOL_CALL_MARKUP_ARGS_PATTERNS = [
  /<(?:[a-z0-9_:-]+:)?input\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?input>/i,
  /<(?:[a-z0-9_:-]+:)?arguments\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?arguments>/i,
  /<(?:[a-z0-9_:-]+:)?argument\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?argument>/i,
  /<(?:[a-z0-9_:-]+:)?parameters\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameters>/i,
  /<(?:[a-z0-9_:-]+:)?parameter\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameter>/i,
  /<(?:[a-z0-9_:-]+:)?args\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?args>/i,
  /<(?:[a-z0-9_:-]+:)?params\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?params>/i,
];
const TEXT_KV_NAME_PATTERN = /function\.name:\s*([a-zA-Z0-9_.-]+)/gi;

const {
  toStringSafe,
} = require('./state');
const {
  extractJSONObjectFrom,
} = require('./jsonscan');

function stripFencedCodeBlocks(text) {
  const t = typeof text === 'string' ? text : '';
  if (!t) {
    return '';
  }
  return t.replace(/```[\s\S]*?```/g, ' ');
}

function buildToolCallCandidates(text) {
  const trimmed = toStringSafe(text);
  const candidates = [trimmed];

  const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
  for (const block of fenced) {
    const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
    if (m && m[1]) {
      candidates.push(toStringSafe(m[1]));
    }
  }

  for (const candidate of extractToolCallObjects(trimmed)) {
    candidates.push(toStringSafe(candidate));
  }

  const first = trimmed.indexOf('{');
  const last = trimmed.lastIndexOf('}');
  if (first >= 0 && last > first) {
    candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
  }

  const m = trimmed.match(TOOL_CALL_PATTERN);
  if (m && m[1]) {
    candidates.push(`{"tool_calls":[${m[1]}]}`);
  }

  return [...new Set(candidates.filter(Boolean))];
}

function extractToolCallObjects(text) {
  const raw = toStringSafe(text);
  if (!raw) {
    return [];
  }
  const lower = raw.toLowerCase();
  const out = [];
  let offset = 0;

  // eslint-disable-next-line no-constant-condition
  while (true) {
    let idx = lower.indexOf('tool_calls', offset);
    if (idx < 0) {
      break;
    }
    let start = raw.slice(0, idx).lastIndexOf('{');
    while (start >= 0) {
      const obj = extractJSONObjectFrom(raw, start);
      if (obj.ok) {
        out.push(raw.slice(start, obj.end).trim());
        offset = obj.end;
        idx = -1;
        break;
      }
      start = raw.slice(0, start).lastIndexOf('{');
    }
    if (idx >= 0) {
      offset = idx + 'tool_calls'.length;
    }
  }

  return out;
}

function parseToolCallsPayload(payload) {
  let decoded;
  try {
    decoded = JSON.parse(payload);
  } catch (_err) {
    return [];
  }

  if (Array.isArray(decoded)) {
    return parseToolCallList(decoded);
  }
  if (!decoded || typeof decoded !== 'object') {
    return [];
  }
  if (decoded.tool_calls) {
    return parseToolCallList(decoded.tool_calls);
  }

  const one = parseToolCallItem(decoded);
  return one ? [one] : [];
}

function parseMarkupToolCalls(text) {
  const raw = toStringSafe(text).trim();
  if (!raw) {
    return [];
  }
  const out = [];
  for (const m of raw.matchAll(TOOL_CALL_MARKUP_BLOCK_PATTERN)) {
    const parsed = parseMarkupSingleToolCall(toStringSafe(m[2]).trim(), toStringSafe(m[3]).trim());
    if (parsed) {
      out.push(parsed);
    }
  }
  for (const m of raw.matchAll(TOOL_CALL_MARKUP_SELFCLOSE_PATTERN)) {
    const parsed = parseMarkupSingleToolCall(toStringSafe(m[1]).trim(), '');
    if (parsed) {
      out.push(parsed);
    }
  }
  return out;
}

function parseTextKVToolCalls(text) {
  const raw = toStringSafe(text);
  if (!raw) {
    return [];
  }
  const out = [];
  const matches = [...raw.matchAll(TEXT_KV_NAME_PATTERN)];
  if (matches.length === 0) {
    return out;
  }
  for (let i = 0; i < matches.length; i += 1) {
    const match = matches[i];
    const name = toStringSafe(match[1]).trim();
    if (!name) {
      continue;
    }
    const nameEnd = match.index + toStringSafe(match[0]).length;
    const searchEnd = i + 1 < matches.length ? matches[i + 1].index : raw.length;
    const searchArea = raw.slice(nameEnd, searchEnd);
    const argIdx = searchArea.indexOf('function.arguments:');
    if (argIdx < 0) {
      continue;
    }
    const argStart = nameEnd + argIdx + 'function.arguments:'.length;
    const bracePos = raw.slice(argStart, searchEnd).indexOf('{');
    if (bracePos < 0) {
      continue;
    }
    const objStart = argStart + bracePos;
    const obj = extractJSONObjectFrom(raw, objStart);
    if (!obj.ok) {
      continue;
    }
    out.push({
      name,
      input: parseToolCallInput(raw.slice(objStart, obj.end)),
    });
  }
  return out;
}

function parseMarkupSingleToolCall(attrs, inner) {
  const embedded = parseToolCallsPayload(inner);
  if (embedded.length > 0) {
    return embedded[0];
  }
  let name = '';
  const attrMatch = attrs.match(TOOL_CALL_MARKUP_ATTR_PATTERN);
  if (attrMatch && attrMatch[2]) {
    name = toStringSafe(attrMatch[2]).trim();
  }
  if (!name) {
    name = stripTagText(findMarkupTagValue(inner, TOOL_CALL_MARKUP_NAME_PATTERNS));
  }
  if (!name) {
    return null;
  }

  let input = {};
  const argsRaw = findMarkupTagValue(inner, TOOL_CALL_MARKUP_ARGS_PATTERNS);
  if (argsRaw) {
    input = parseMarkupInput(argsRaw);
  } else {
    const kv = parseMarkupKVObject(inner);
    if (Object.keys(kv).length > 0) {
      input = kv;
    }
  }
  return { name, input };
}

function parseMarkupInput(raw) {
  const s = toStringSafe(raw).trim();
  if (!s) {
    return {};
  }
  const parsed = parseToolCallInput(s);
  if (parsed && typeof parsed === 'object' && !Array.isArray(parsed) && Object.keys(parsed).length > 0) {
    return parsed;
  }
  const kv = parseMarkupKVObject(s);
  if (Object.keys(kv).length > 0) {
    return kv;
  }
  return { _raw: stripTagText(s) };
}

function parseMarkupKVObject(text) {
  const raw = toStringSafe(text).trim();
  if (!raw) {
    return {};
  }
  const out = {};
  for (const m of raw.matchAll(TOOL_CALL_MARKUP_KV_PATTERN)) {
    const key = toStringSafe(m[1]).trim();
    if (!key) {
      continue;
    }
    const valueRaw = stripTagText(m[2]);
    if (!valueRaw) {
      continue;
    }
    try {
      out[key] = JSON.parse(valueRaw);
    } catch (_err) {
      out[key] = valueRaw;
    }
  }
  return out;
}

function stripTagText(text) {
  return toStringSafe(text).replace(/<[^>]+>/g, ' ').trim();
}

function findMarkupTagValue(text, patterns) {
  const source = toStringSafe(text);
  for (const p of patterns) {
    const m = source.match(p);
    if (m && m[1]) {
      return toStringSafe(m[1]);
    }
  }
  return '';
}

function parseToolCallList(v) {
  if (!Array.isArray(v)) {
    return [];
  }
  const out = [];
  for (const item of v) {
    if (!item || typeof item !== 'object') {
      continue;
    }
    const one = parseToolCallItem(item);
    if (one) {
      out.push(one);
    }
  }
  return out;
}

function parseToolCallItem(m) {
  let name = toStringSafe(m.name);
  let inputRaw = m.input;
  let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
  const fn = m.function && typeof m.function === 'object' ? m.function : null;

  if (fn) {
    if (!name) {
      name = toStringSafe(fn.name);
    }
    if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
      inputRaw = fn.arguments;
      hasInput = true;
    }
  }

  if (!hasInput) {
    for (const k of ['arguments', 'args', 'parameters', 'params']) {
      if (Object.prototype.hasOwnProperty.call(m, k)) {
        inputRaw = m[k];
        hasInput = true;
        break;
      }
    }
  }

  if (!name) {
    return null;
  }

  return {
    name,
    input: parseToolCallInput(inputRaw),
  };
}

function parseToolCallInput(v) {
  if (v == null) {
    return {};
  }
  if (typeof v === 'string') {
    const raw = toStringSafe(v);
    if (!raw) {
      return {};
    }
    try {
      const parsed = JSON.parse(raw);
      if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
        return parsed;
      }
      return { _raw: raw };
    } catch (_err) {
      return { _raw: raw };
    }
  }
  if (typeof v === 'object' && !Array.isArray(v)) {
    return v;
  }
  try {
    const parsed = JSON.parse(JSON.stringify(v));
    if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
      return parsed;
    }
  } catch (_err) {
    return {};
  }
  return {};
}

module.exports = {
  stripFencedCodeBlocks,
  buildToolCallCandidates,
  parseToolCallsPayload,
  parseMarkupToolCalls,
  parseTextKVToolCalls,
};
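The markup patterns in parse_payload.js recognize tool calls that models wrap in XML-like tags instead of JSON. A small demonstration of `TOOL_CALL_MARKUP_BLOCK_PATTERN` (copied verbatim from the file above) on an assumed sample reply: group 1 is the tag kind, group 2 the attributes, group 3 the inner body.

```javascript
// Pattern copied from parse_payload.js; the sample reply is illustrative.
const BLOCK = /<(?:[a-z0-9_:-]+:)?(tool_call|function_call|invoke)\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;

const reply = 'Sure.\n<tool_call name="get_weather">{"city":"Tokyo"}</tool_call>';
const matches = [...reply.matchAll(BLOCK)];

console.log(matches[0][1]);               // tool_call
console.log(matches[0][2].trim());        // name="get_weather"
console.log(JSON.parse(matches[0][3]).city); // Tokyo
```

`parseMarkupSingleToolCall` then reads the name from the attribute match and the arguments from the inner JSON, falling back to nested `<name>`/`<arguments>` tags when the attribute form is absent.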
@@ -1,16 +1,12 @@
 'use strict';
 
 const {
-  TOOL_SIEVE_CAPTURE_LIMIT,
   resetIncrementalToolState,
   noteText,
   insideCodeFence,
 } = require('./state');
 const {
-  buildIncrementalToolDeltas,
-} = require('./incremental');
-const {
-  parseStandaloneToolCalls,
+  parseStandaloneToolCallsDetailed,
 } = require('./parse');
 const {
   extractJSONObjectFrom,
@@ -24,64 +20,67 @@ function processToolSieveChunk(state, chunk, toolNames) {
     state.pending += chunk;
   }
   const events = [];
 
   // eslint-disable-next-line no-constant-condition
   while (true) {
+    if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
+      events.push({ type: 'tool_calls', calls: state.pendingToolCalls });
+      state.pendingToolRaw = '';
+      state.pendingToolCalls = [];
+      continue;
+    }
     if (state.capturing) {
       if (state.pending) {
         state.capture += state.pending;
         state.pending = '';
       }
-      const deltas = buildIncrementalToolDeltas(state);
-      if (deltas.length > 0) {
-        events.push({ type: 'tool_call_deltas', deltas });
-      }
       const consumed = consumeToolCapture(state, toolNames);
       if (!consumed.ready) {
-        if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) {
-          noteText(state, state.capture);
-          events.push({ type: 'text', text: state.capture });
-          state.capture = '';
-          state.capturing = false;
-          resetIncrementalToolState(state);
-          continue;
-        }
         break;
       }
+      const captured = state.capture;
       state.capture = '';
       state.capturing = false;
       resetIncrementalToolState(state);
 
+      if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
+        state.pendingToolRaw = captured;
+        state.pendingToolCalls = consumed.calls;
+        if (consumed.suffix) {
+          state.pending = consumed.suffix + state.pending;
+        }
+        continue;
+      }
       if (consumed.prefix) {
         noteText(state, consumed.prefix);
         events.push({ type: 'text', text: consumed.prefix });
       }
-      if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
-        events.push({ type: 'tool_calls', calls: consumed.calls });
-      }
       if (consumed.suffix) {
         state.pending += consumed.suffix;
       }
       continue;
     }
 
-    if (!state.pending) {
+    const pending = state.pending || '';
+    if (!pending) {
       break;
     }
 
-    const start = findToolSegmentStart(state.pending);
+    const start = findToolSegmentStart(pending);
     if (start >= 0) {
-      const prefix = state.pending.slice(0, start);
+      const prefix = pending.slice(0, start);
       if (prefix) {
         noteText(state, prefix);
         events.push({ type: 'text', text: prefix });
       }
-      state.capture = state.pending.slice(start);
       state.pending = '';
+      state.capture += pending.slice(start);
       state.capturing = true;
       resetIncrementalToolState(state);
       continue;
     }
 
-    const [safe, hold] = splitSafeContentForToolDetection(state.pending);
+    const [safe, hold] = splitSafeContentForToolDetection(pending);
     if (!safe) {
       break;
     }
@@ -97,6 +96,13 @@ function flushToolSieve(state, toolNames) {
     return [];
   }
   const events = processToolSieveChunk(state, '', toolNames);
+
+  if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
+    events.push({ type: 'tool_calls', calls: state.pendingToolCalls });
+    state.pendingToolRaw = '';
+    state.pendingToolCalls = [];
+  }
+
   if (state.capturing) {
     const consumed = consumeToolCapture(state, toolNames);
     if (consumed.ready) {
@@ -119,11 +125,13 @@ function flushToolSieve(state, toolNames) {
     state.capturing = false;
     resetIncrementalToolState(state);
   }
+
   if (state.pending) {
     noteText(state, state.pending);
     events.push({ type: 'text', text: state.pending });
     state.pending = '';
   }
+
   return events;
 }
 
@@ -160,43 +168,67 @@ function findToolSegmentStart(s) {
     return -1;
   }
   const lower = s.toLowerCase();
+  const keywords = ['tool_calls', 'function.name:', '[tool_call_history]'];
   let offset = 0;
   // eslint-disable-next-line no-constant-condition
   while (true) {
-    const keyRel = lower.indexOf('tool_calls', offset);
-    if (keyRel < 0) {
+    let bestKeyIdx = -1;
+    let matchedKeyword = '';
+
+    for (const kw of keywords) {
+      const idx = lower.indexOf(kw, offset);
+      if (idx >= 0) {
+        if (bestKeyIdx < 0 || idx < bestKeyIdx) {
+          bestKeyIdx = idx;
+          matchedKeyword = kw;
+        }
+      }
+    }
+
+    if (bestKeyIdx < 0) {
       return -1;
     }
-    const keyIdx = keyRel;
+    const keyIdx = bestKeyIdx;
     const start = s.slice(0, keyIdx).lastIndexOf('{');
     const candidateStart = start >= 0 ? start : keyIdx;
     if (!insideCodeFence(s.slice(0, candidateStart))) {
       return candidateStart;
     }
-    offset = keyIdx + 'tool_calls'.length;
+    offset = keyIdx + matchedKeyword.length;
   }
 }
 
 function consumeToolCapture(state, toolNames) {
-  const captured = state.capture;
+  const captured = state.capture || '';
   if (!captured) {
     return { ready: false, prefix: '', calls: [], suffix: '' };
   }
   const lower = captured.toLowerCase();
-  const keyIdx = lower.indexOf('tool_calls');
+  let keyIdx = -1;
+  const keywords = ['tool_calls', 'function.name:', '[tool_call_history]'];
+  for (const kw of keywords) {
+    const idx = lower.indexOf(kw);
+    if (idx >= 0 && (keyIdx < 0 || idx < keyIdx)) {
+      keyIdx = idx;
+    }
+  }
+
   if (keyIdx < 0) {
     return { ready: false, prefix: '', calls: [], suffix: '' };
   }
   const start = captured.slice(0, keyIdx).lastIndexOf('{');
-  if (start < 0) {
-    return { ready: false, prefix: '', calls: [], suffix: '' };
-  }
-  const obj = extractJSONObjectFrom(captured, start);
+  const actualStart = start >= 0 ? start : keyIdx;
+
+  const obj = extractJSONObjectFrom(captured, actualStart);
   if (!obj.ok) {
     return { ready: false, prefix: '', calls: [], suffix: '' };
   }
-  const prefixPart = captured.slice(0, start);
+  const prefixPart = captured.slice(0, actualStart);
   const suffixPart = captured.slice(obj.end);
 
   if (insideCodeFence((state.recentTextTail || '') + prefixPart)) {
     return {
       ready: true,
@@ -205,18 +237,10 @@ function consumeToolCapture(state, toolNames) {
       suffix: '',
     };
   }
-  const rawParsed = parseStandaloneToolCalls(captured.slice(start, obj.end), []);
-  const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames);
-  if (parsed.length === 0) {
-    if (rawParsed.length > 0 && Array.isArray(toolNames) && toolNames.length > 0) {
-      return {
-        ready: true,
-        prefix: prefixPart,
-        calls: [],
-        suffix: suffixPart,
-      };
-    }
-    if (state.toolNameSent) {
+  const parsed = parseStandaloneToolCallsDetailed(captured.slice(actualStart, obj.end), toolNames);
+  if (!Array.isArray(parsed.calls) || parsed.calls.length === 0) {
+    if (parsed.sawToolCallSyntax && parsed.rejectedByPolicy) {
       return {
         ready: true,
         prefix: prefixPart,
@@ -231,26 +255,11 @@ function consumeToolCapture(state, toolNames) {
       suffix: '',
     };
   }
-  if (state.toolNameSent) {
-    if (parsed.length > 1) {
-      return {
-        ready: true,
-        prefix: prefixPart,
-        calls: parsed.slice(1),
-        suffix: suffixPart,
-      };
-    }
-    return {
-      ready: true,
-      prefix: prefixPart,
-      calls: [],
-      suffix: suffixPart,
-    };
-  }
   return {
     ready: true,
     prefix: prefixPart,
-    calls: parsed,
+    calls: parsed.calls,
     suffix: suffixPart,
   };
 }
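The processToolSieveChunk changes above replace the old immediate emit with a deferred one: a completed capture parks its calls in `state.pendingToolCalls`, and the top of the next loop turn (or the flush) drains them as a single `'tool_calls'` event before any further text is processed. A simplified sketch of that drain step, with an assumed state shape rather than the module's real API:

```javascript
// Illustrative sketch of the pendingToolCalls drain; drainPending is a
// hypothetical helper name, not an export of the sieve module.
function drainPending(state, events) {
  if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
    events.push({ type: 'tool_calls', calls: state.pendingToolCalls });
    state.pendingToolRaw = '';
    state.pendingToolCalls = [];
    return true;
  }
  return false;
}

const state = {
  pendingToolRaw: '{"tool_calls":[{"name":"get_weather"}]}',
  pendingToolCalls: [{ name: 'get_weather', input: {} }],
};
const events = [];
drainPending(state, events);
console.log(events.length);                 // 1
console.log(state.pendingToolCalls.length); // 0
```

Deferring the emit this way lets any suffix text after the JSON object be re-queued into `state.pending` before the calls go out, so call ordering relative to surrounding text stays stable.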
@@ -1,6 +1,5 @@
 'use strict';
 
-const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024;
 const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256;
 
 function createToolSieveState() {
@@ -9,6 +8,9 @@ function createToolSieveState() {
     capture: '',
     capturing: false,
     recentTextTail: '',
+    pendingToolRaw: '',
+    pendingToolCalls: [],
+    disableDeltas: false,
     toolNameSent: false,
     toolName: '',
     toolArgsStart: -1,
@@ -19,6 +21,7 @@ function createToolSieveState() {
 }
 
 function resetIncrementalToolState(state) {
+  state.disableDeltas = false;
   state.toolNameSent = false;
   state.toolName = '';
   state.toolArgsStart = -1;
@@ -78,7 +81,6 @@ function toStringSafe(v) {
 }
 
 module.exports = {
-  TOOL_SIEVE_CAPTURE_LIMIT,
   TOOL_SIEVE_CONTEXT_TAIL_LIMIT,
   createToolSieveState,
   resetIncrementalToolState,
@@ -51,6 +51,9 @@ func MessagesPrepare(messages []map[string]any) string {
 }
 
 func NormalizeContent(v any) string {
+	if v == nil {
+		return ""
+	}
 	switch x := v.(type) {
 	case string:
 		return x
@@ -64,11 +67,11 @@ func NormalizeContent(v any) string {
 			typeStr, _ := m["type"].(string)
 			typeStr = strings.ToLower(strings.TrimSpace(typeStr))
 			if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" {
-				if txt, ok := m["text"].(string); ok {
+				if txt, ok := m["text"].(string); ok && txt != "" {
 					parts = append(parts, txt)
 					continue
 				}
-				if txt, ok := m["content"].(string); ok {
+				if txt, ok := m["content"].(string); ok && txt != "" {
 					parts = append(parts, txt)
 				}
 			}
 32  internal/prompt/messages_test.go  (new file)
@@ -0,0 +1,32 @@
+package prompt
+
+import "testing"
+
+func TestNormalizeContentNilReturnsEmpty(t *testing.T) {
+	if got := NormalizeContent(nil); got != "" {
+		t.Fatalf("expected empty string for nil content, got %q", got)
+	}
+}
+
+func TestMessagesPrepareNilContentNoNullLiteral(t *testing.T) {
+	messages := []map[string]any{
+		{"role": "assistant", "content": nil},
+		{"role": "user", "content": "ok"},
+	}
+	got := MessagesPrepare(messages)
+	if got == "" {
+		t.Fatalf("expected non-empty output")
+	}
+	if got == "null" {
+		t.Fatalf("expected no null literal output, got %q", got)
+	}
+}
+
+func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
+	got := NormalizeContent([]any{
+		map[string]any{"type": "text", "text": "", "content": "from-content"},
+	})
+	if got != "from-content" {
+		t.Fatalf("expected fallback to content when text is empty, got %q", got)
+	}
+}
@@ -57,16 +57,20 @@ func NewApp() *App {
 	r.Use(cors)
 	r.Use(timeout(0))
 
-	r.Get("/healthz", func(w http.ResponseWriter, _ *http.Request) {
+	healthzHandler := func(w http.ResponseWriter, _ *http.Request) {
 		w.Header().Set("Content-Type", "application/json")
 		w.WriteHeader(http.StatusOK)
 		_, _ = w.Write([]byte(`{"status":"ok"}`))
-	})
-	r.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) {
+	}
+	readyzHandler := func(w http.ResponseWriter, _ *http.Request) {
 		w.Header().Set("Content-Type", "application/json")
 		w.WriteHeader(http.StatusOK)
 		_, _ = w.Write([]byte(`{"status":"ready"}`))
-	})
+	}
+	r.Get("/healthz", healthzHandler)
+	r.Head("/healthz", healthzHandler)
+	r.Get("/readyz", readyzHandler)
+	r.Head("/readyz", readyzHandler)
 	openai.RegisterRoutes(r, openaiHandler)
 	claude.RegisterRoutes(r, claudeHandler)
 	gemini.RegisterRoutes(r, geminiHandler)
 20  internal/server/router_health_test.go  (new file)
@@ -0,0 +1,20 @@
+package server
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+)
+
+func TestHealthEndpointsSupportHEAD(t *testing.T) {
+	app := NewApp()
+
+	for _, path := range []string{"/healthz", "/readyz"} {
+		req := httptest.NewRequest(http.MethodHead, path, nil)
+		rec := httptest.NewRecorder()
+		app.Router.ServeHTTP(rec, req)
+		if rec.Code != http.StatusOK {
+			t.Fatalf("expected %s HEAD status 200, got %d", path, rec.Code)
+		}
+	}
+}
@@ -17,6 +17,12 @@ func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error {
 	var m map[string]any
 	_ = json.Unmarshal(resp.Body, &m)
 	cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body)))
+
+	headResp, headErr := cc.request(ctx, requestSpec{Method: http.MethodHead, Path: "/healthz", Retryable: true})
+	if headErr != nil {
+		return headErr
+	}
+	cc.assert("head_status_200", headResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", headResp.StatusCode))
 	return nil
 }
 
@@ -29,6 +35,12 @@ func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error {
 	var m map[string]any
 	_ = json.Unmarshal(resp.Body, &m)
 	cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body)))
+
+	headResp, headErr := cc.request(ctx, requestSpec{Method: http.MethodHead, Path: "/readyz", Retryable: true})
+	if headErr != nil {
+		return headErr
+	}
+	cc.assert("head_status_200", headResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", headResp.StatusCode))
 	return nil
 }
 
@@ -20,7 +20,7 @@ func buildToolCallCandidates(text string) []string {
 		}
 	}
 
-	// best-effort extraction around "tool_calls" key in mixed text payloads.
+	// best-effort extraction around tool call keywords in mixed text payloads.
 	candidates = append(candidates, extractToolCallObjects(trimmed)...)
 
 	// best-effort object slice: from first '{' to last '}'
@@ -57,25 +57,65 @@ func extractToolCallObjects(text string) []string {
 	lower := strings.ToLower(text)
 	out := []string{}
 	offset := 0
+	keywords := []string{"tool_calls", "function.name:", "[tool_call_history]"}
 	for {
-		idx := strings.Index(lower[offset:], "tool_calls")
-		if idx < 0 {
+		bestIdx := -1
+		matchedKeyword := ""
+		for _, kw := range keywords {
+			idx := strings.Index(lower[offset:], kw)
+			if idx >= 0 {
+				absIdx := offset + idx
+				if bestIdx < 0 || absIdx < bestIdx {
+					bestIdx = absIdx
+					matchedKeyword = kw
+				}
+			}
+		}
+
+		if bestIdx < 0 {
 			break
 		}
-		idx += offset
-		start := strings.LastIndex(text[:idx], "{")
-		for start >= 0 {
+
+		idx := bestIdx
+		// Avoid backtracking too far to prevent OOM on malicious or very long strings
+		searchLimit := idx - 2000
+		if searchLimit < offset {
+			searchLimit = offset
+		}
+
+		start := strings.LastIndex(text[searchLimit:idx], "{")
+		if start >= 0 {
+			start += searchLimit
+		}
+
+		if start < 0 {
+			offset = idx + len(matchedKeyword)
+			continue
+		}
+
+		foundObj := false
+		for start >= searchLimit {
 			candidate, end, ok := extractJSONObject(text, start)
 			if ok {
 				// Move forward to avoid repeatedly matching the same object.
 				offset = end
 				out = append(out, strings.TrimSpace(candidate))
+				foundObj = true
 				break
 			}
-			start = strings.LastIndex(text[:start], "{")
+			// Try previous '{'
+			if start > searchLimit {
+				prevStart := strings.LastIndex(text[searchLimit:start], "{")
+				if prevStart >= 0 {
+					start = searchLimit + prevStart
+					continue
+				}
+			}
+			break
 		}
-		if start < 0 {
-			offset = idx + len("tool_calls")
+		if !foundObj {
+			offset = idx + len(matchedKeyword)
 		}
 	}
 	return out
@@ -88,7 +128,12 @@ func extractJSONObject(text string, start int) (string, int, bool) {
 	depth := 0
 	quote := byte(0)
 	escaped := false
-	for i := start; i < len(text); i++ {
+	// Limit scan length to avoid OOM on unclosed objects
+	maxLen := start + 50000
+	if maxLen > len(text) {
+		maxLen = len(text)
+	}
+	for i := start; i < maxLen; i++ {
 		ch := text[i]
 		if quote != 0 {
 			if escaped {
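The balanced-brace scan that `extractJSONObject` performs can be illustrated with a self-contained sketch. It is a simplification of the patched function (no keyword search, the scan cap passed as a parameter rather than the patch's 50000 constant), showing why quote and escape tracking matter: braces inside string values must not affect depth.

```go
package main

import "fmt"

// extractObject returns the first balanced {...} starting at start.
// It tracks string quotes and backslash escapes so that braces inside
// string values are ignored, and bounds the scan with scanCap to avoid
// walking an entire huge input when the object is never closed.
func extractObject(text string, start, scanCap int) (string, bool) {
	depth := 0
	quote := byte(0)
	escaped := false
	maxLen := start + scanCap
	if maxLen > len(text) {
		maxLen = len(text)
	}
	for i := start; i < maxLen; i++ {
		ch := text[i]
		if quote != 0 {
			if escaped {
				escaped = false
			} else if ch == '\\' {
				escaped = true
			} else if ch == quote {
				quote = 0
			}
			continue
		}
		switch ch {
		case '"':
			quote = ch
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return text[start : i+1], true
			}
		}
	}
	return "", false
}

func main() {
	// The "}" inside the quoted "UTC}" value must not close the object.
	src := `prefix {"name": "get_time", "input": {"tz": "UTC}"}} suffix`
	obj, ok := extractObject(src, 7, 50000)
	fmt.Println(ok, obj)
}
```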
 106  internal/util/toolcalls_input_parse.go  (new file)
@@ -0,0 +1,106 @@
+package util
+
+import (
+	"encoding/json"
+	"strings"
+	"unicode"
+)
+
+func parseToolCallInput(v any) map[string]any {
+	switch x := v.(type) {
+	case nil:
+		return map[string]any{}
+	case map[string]any:
+		return x
+	case string:
+		raw := strings.TrimSpace(x)
+		if raw == "" {
+			return map[string]any{}
+		}
+		var parsed map[string]any
+		if err := json.Unmarshal([]byte(raw), &parsed); err == nil && parsed != nil {
+			repairPathLikeControlChars(parsed)
+			return parsed
+		}
+		// Try to repair invalid backslashes (common in Windows paths output by models)
+		repaired := repairInvalidJSONBackslashes(raw)
+		if repaired != raw {
+			if err := json.Unmarshal([]byte(repaired), &parsed); err == nil && parsed != nil {
+				return parsed
+			}
+		}
+		// Try to repair loose JSON in string argument as well
+		repairedLoose := RepairLooseJSON(raw)
+		if repairedLoose != raw {
+			if err := json.Unmarshal([]byte(repairedLoose), &parsed); err == nil && parsed != nil {
+				return parsed
+			}
+		}
+		return map[string]any{"_raw": raw}
+	default:
+		b, err := json.Marshal(x)
+		if err != nil {
+			return map[string]any{}
+		}
+		var parsed map[string]any
+		if err := json.Unmarshal(b, &parsed); err == nil && parsed != nil {
+			return parsed
+		}
+		return map[string]any{}
+	}
+}
+
+func repairPathLikeControlChars(m map[string]any) {
+	for k, v := range m {
+		switch vv := v.(type) {
+		case map[string]any:
+			repairPathLikeControlChars(vv)
+		case []any:
+			for _, item := range vv {
+				if child, ok := item.(map[string]any); ok {
+					repairPathLikeControlChars(child)
+				}
+			}
+		case string:
+			if isPathLikeKey(k) && containsControlRune(vv) {
+				m[k] = escapeControlRunes(vv)
+			}
+		}
+	}
+}
+
+func isPathLikeKey(key string) bool {
+	k := strings.ToLower(strings.TrimSpace(key))
+	return strings.Contains(k, "path") || strings.Contains(k, "file")
+}
+
+func containsControlRune(s string) bool {
+	for _, r := range s {
+		if unicode.IsControl(r) {
+			return true
+		}
+	}
+	return false
+}
+
+func escapeControlRunes(s string) string {
+	var b strings.Builder
+	b.Grow(len(s) + 8)
+	for _, r := range s {
+		switch r {
+		case '\b':
+			b.WriteString(`\b`)
+		case '\f':
+			b.WriteString(`\f`)
+		case '\n':
+			b.WriteString(`\n`)
+		case '\r':
+			b.WriteString(`\r`)
+		case '\t':
+			b.WriteString(`\t`)
+		default:
+			b.WriteRune(r)
+		}
+	}
+	return b.String()
+}
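The control-rune repair above targets Windows paths whose `\t` or `\b` sequences were decoded into literal tab and backspace characters by an earlier lenient parse. A minimal sketch of the escaping step (re-created here with `strings.NewReplacer` rather than the patch's rune loop, so it is an approximation, not the repo's exact code):

```go
package main

import (
	"fmt"
	"strings"
)

// escapeControl mirrors the idea of escapeControlRunes: literal control
// characters inside a path-like string value are rewritten as their
// two-character JSON escape text, so the value survives re-serialization.
func escapeControl(s string) string {
	r := strings.NewReplacer(
		"\b", `\b`,
		"\f", `\f`,
		"\n", `\n`,
		"\r", `\r`,
		"\t", `\t`,
	)
	return r.Replace(s)
}

func main() {
	// A model emitted "C:\tools\bin"; lenient decoding turned \t and \b
	// into a real tab and backspace. Escaping restores readable text.
	path := "C:\tools\bin"
	fmt.Println(escapeControl(path))
}
```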
 79  internal/util/toolcalls_json_repair.go  (new file)
@@ -0,0 +1,79 @@
+package util
+
+import (
+	"regexp"
+	"strings"
+)
+
+func repairInvalidJSONBackslashes(s string) string {
+	if !strings.Contains(s, "\\") {
+		return s
+	}
+	var out strings.Builder
+	out.Grow(len(s) + 10)
+	runes := []rune(s)
+	for i := 0; i < len(runes); i++ {
+		if runes[i] == '\\' {
+			if i+1 < len(runes) {
+				next := runes[i+1]
+				switch next {
+				case '"', '\\', '/', 'b', 'f', 'n', 'r', 't':
+					out.WriteRune('\\')
+					out.WriteRune(next)
+					i++
+					continue
+				case 'u':
+					if i+5 < len(runes) {
+						isHex := true
+						for j := 1; j <= 4; j++ {
+							r := runes[i+1+j]
+							if !((r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')) {
+								isHex = false
+								break
+							}
+						}
+						if isHex {
+							out.WriteRune('\\')
+							out.WriteRune('u')
+							for j := 1; j <= 4; j++ {
+								out.WriteRune(runes[i+1+j])
+							}
+							i += 5
+							continue
+						}
+					}
+				}
+			}
+			// Not a valid escape sequence, double it
+			out.WriteString("\\\\")
+		} else {
+			out.WriteRune(runes[i])
+		}
+	}
+	return out.String()
+}
+
+var unquotedKeyPattern = regexp.MustCompile(`([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:`)
+
+// missingArrayBracketsPattern identifies a sequence of two or more JSON objects separated by commas
+// that immediately follow a colon, which indicates a missing array bracket `[` `]`.
+// E.g., "key": {"a": 1}, {"b": 2} -> "key": [{"a": 1}, {"b": 2}]
+// NOTE: The pattern uses (?:[^{}]|\{[^{}]*\})* to support single-level nested {} objects,
+// which handles cases like {"content": "x", "input": {"q": "y"}}
+var missingArrayBracketsPattern = regexp.MustCompile(`(:\s*)(\{(?:[^{}]|\{[^{}]*\})*\}(?:\s*,\s*\{(?:[^{}]|\{[^{}]*\})*\})+)`)
+
+func RepairLooseJSON(s string) string {
+	s = strings.TrimSpace(s)
+	if s == "" {
+		return s
+	}
+	// 1. Replace unquoted keys: {key: -> {"key":
+	s = unquotedKeyPattern.ReplaceAllString(s, `$1"$2":`)
+
+	// 2. Heuristic: Fix missing array brackets for list of objects
+	// e.g., : {obj1}, {obj2} -> : [{obj1}, {obj2}]
+	// This specifically addresses DeepSeek's "list hallucination"
+	s = missingArrayBracketsPattern.ReplaceAllString(s, `$1[$2]`)
+
+	return s
+}
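A self-contained sketch of the two `RepairLooseJSON` heuristics in action, using the same two regexes the file defines. The sample input is hypothetical, chosen to trigger both the bare-key repair and the missing-array-bracket repair in one pass:

```go
package main

import (
	"fmt"
	"regexp"
)

// Same patterns as the new toolcalls_json_repair.go:
// 1. quote bare object keys: {key: -> {"key":
// 2. wrap a colon-followed run of comma-separated objects in [ ]
//    (the "list hallucination" the patch comments describe).
var unquotedKey = regexp.MustCompile(`([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:`)
var missingBrackets = regexp.MustCompile(`(:\s*)(\{(?:[^{}]|\{[^{}]*\})*\}(?:\s*,\s*\{(?:[^{}]|\{[^{}]*\})*\})+)`)

func repairLoose(s string) string {
	s = unquotedKey.ReplaceAllString(s, `$1"$2":`)
	return missingBrackets.ReplaceAllString(s, `$1[$2]`)
}

func main() {
	// "calls" is unquoted, and its two objects are missing array brackets.
	fmt.Println(repairLoose(`{calls: {"a": 1}, {"b": 2}}`))
}
```

After repair the string parses as ordinary JSON, which is exactly the retry path `parseToolCallsPayload` now takes.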
 161  internal/util/toolcalls_markup.go  (new file)
@@ -0,0 +1,161 @@
+package util
+
+import (
+	"encoding/json"
+	"regexp"
+	"strings"
+)
+
+var toolCallMarkupTagNames = []string{"tool_call", "function_call", "invoke"}
+var toolCallMarkupTagPatternByName = map[string]*regexp.Regexp{
+	"tool_call":     regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?tool_call\b([^>]*)>(.*?)</(?:[a-z0-9_:-]+:)?tool_call>`),
+	"function_call": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function_call\b([^>]*)>(.*?)</(?:[a-z0-9_:-]+:)?function_call>`),
+	"invoke":        regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)>(.*?)</(?:[a-z0-9_:-]+:)?invoke>`),
+}
+var toolCallMarkupSelfClosingPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)/>`)
+var toolCallMarkupKVPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?([a-z0-9_\-.]+)\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?([a-z0-9_\-.]+)>`)
+var toolCallMarkupAttrPattern = regexp.MustCompile(`(?is)(name|function|tool)\s*=\s*"([^"]+)"`)
+var anyTagPattern = regexp.MustCompile(`(?is)<[^>]+>`)
+var toolCallMarkupNameTagNames = []string{"name", "function"}
+var toolCallMarkupNamePatternByTag = map[string]*regexp.Regexp{
+	"name":     regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?name\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?name>`),
+	"function": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?function>`),
+}
+var toolCallMarkupArgsTagNames = []string{"input", "arguments", "argument", "parameters", "parameter", "args", "params"}
+var toolCallMarkupArgsPatternByTag = map[string]*regexp.Regexp{
+	"input":      regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?input\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?input>`),
+	"arguments":  regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?arguments\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?arguments>`),
+	"argument":   regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?argument\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?argument>`),
+	"parameters": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?parameters\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?parameters>`),
+	"parameter":  regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?parameter\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?parameter>`),
+	"args":       regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?args\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?args>`),
+	"params":     regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?params\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?params>`),
+}
+
+func parseMarkupToolCalls(text string) []ParsedToolCall {
+	trimmed := strings.TrimSpace(text)
+	if trimmed == "" {
+		return nil
+	}
+
+	out := make([]ParsedToolCall, 0)
+	for _, tagName := range toolCallMarkupTagNames {
+		pattern := toolCallMarkupTagPatternByName[tagName]
+		for _, m := range pattern.FindAllStringSubmatch(trimmed, -1) {
+			if len(m) < 3 {
+				continue
+			}
+			attrs := strings.TrimSpace(m[1])
+			inner := strings.TrimSpace(m[2])
+			if parsed := parseMarkupSingleToolCall(attrs, inner); parsed.Name != "" {
+				out = append(out, parsed)
+			}
+		}
+	}
+	for _, m := range toolCallMarkupSelfClosingPattern.FindAllStringSubmatch(trimmed, -1) {
+		if len(m) < 2 {
+			continue
+		}
+		if parsed := parseMarkupSingleToolCall(strings.TrimSpace(m[1]), ""); parsed.Name != "" {
+			out = append(out, parsed)
+		}
+	}
+	if len(out) == 0 {
+		return nil
+	}
+	return out
+}
+
+func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall {
+	if parsed := parseToolCallsPayload(inner); len(parsed) > 0 {
+		return parsed[0]
+	}
+
+	name := ""
+	if m := toolCallMarkupAttrPattern.FindStringSubmatch(attrs); len(m) >= 3 {
+		name = strings.TrimSpace(m[2])
+	}
+	if name == "" {
+		name = findMarkupTagValue(inner, toolCallMarkupNameTagNames, toolCallMarkupNamePatternByTag)
+	}
+	if name == "" {
+		return ParsedToolCall{}
+	}
+
+	input := map[string]any{}
+	if argsRaw := findMarkupTagValue(inner, toolCallMarkupArgsTagNames, toolCallMarkupArgsPatternByTag); argsRaw != "" {
+		input = parseMarkupInput(argsRaw)
+	} else if kv := parseMarkupKVObject(inner); len(kv) > 0 {
+		input = kv
+	}
+	return ParsedToolCall{Name: name, Input: input}
+}
+
+func parseMarkupInput(raw string) map[string]any {
+	raw = strings.TrimSpace(raw)
+	if raw == "" {
+		return map[string]any{}
+	}
+	if parsed := parseToolCallInput(raw); len(parsed) > 0 {
+		return parsed
+	}
+	if kv := parseMarkupKVObject(raw); len(kv) > 0 {
+		return kv
+	}
+	return map[string]any{"_raw": stripTagText(raw)}
+}
+
+func parseMarkupKVObject(text string) map[string]any {
+	matches := toolCallMarkupKVPattern.FindAllStringSubmatch(strings.TrimSpace(text), -1)
+	if len(matches) == 0 {
+		return nil
+	}
+	out := map[string]any{}
+	for _, m := range matches {
+		if len(m) < 4 {
+			continue
+		}
+		key := strings.TrimSpace(m[1])
+		endKey := strings.TrimSpace(m[3])
+		if key == "" {
+			continue
+		}
+		if !strings.EqualFold(key, endKey) {
+			continue
+		}
+		value := strings.TrimSpace(stripTagText(m[2]))
+		if value == "" {
+			continue
+		}
+		var jsonValue any
+		if json.Unmarshal([]byte(value), &jsonValue) == nil {
+			out[key] = jsonValue
+			continue
+		}
+		out[key] = value
+	}
+	if len(out) == 0 {
+		return nil
+	}
+	return out
+}
+
+func stripTagText(text string) string {
+	return strings.TrimSpace(anyTagPattern.ReplaceAllString(text, ""))
+}
+
+func findMarkupTagValue(text string, tagNames []string, patternByTag map[string]*regexp.Regexp) string {
+	for _, tag := range tagNames {
+		pattern := patternByTag[tag]
+		if pattern == nil {
+			continue
+		}
+		if m := pattern.FindStringSubmatch(text); len(m) >= 2 {
+			value := strings.TrimSpace(m[1])
+			if value != "" {
+				return value
+			}
+		}
+	}
+	return ""
+}
 33  internal/util/toolcalls_name_match.go  (new file)
@@ -0,0 +1,33 @@
+package util
+
+import (
+	"regexp"
+	"strings"
+)
+
+var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`)
+
+func resolveAllowedToolNameWithLooseMatch(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
+	if _, ok := allowed[name]; ok {
+		return name
+	}
+	lower := strings.ToLower(strings.TrimSpace(name))
+	if canonical, ok := allowedCanonical[lower]; ok {
+		return canonical
+	}
+	if idx := strings.LastIndex(lower, "."); idx >= 0 && idx < len(lower)-1 {
+		if canonical, ok := allowedCanonical[lower[idx+1:]]; ok {
+			return canonical
+		}
+	}
+	loose := toolNameLoosePattern.ReplaceAllString(lower, "")
+	if loose == "" {
+		return ""
+	}
+	for candidateLower, canonical := range allowedCanonical {
+		if toolNameLoosePattern.ReplaceAllString(candidateLower, "") == loose {
+			return canonical
+		}
+	}
+	return ""
+}
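The final fallback in the matcher above, stripping every non-alphanumeric rune before comparing, can be shown with a small self-contained sketch (a simplified re-creation that takes a plain slice of allowed names instead of the patch's two maps):

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// Same normalization as toolNameLoosePattern: lowercase, then drop
// everything that is not a-z or 0-9, so "Web-Search" and "web_search"
// collapse to the same key "websearch".
var nonAlnum = regexp.MustCompile(`[^a-z0-9]+`)

func looseResolve(name string, allowed []string) string {
	loose := nonAlnum.ReplaceAllString(strings.ToLower(name), "")
	if loose == "" {
		return ""
	}
	for _, a := range allowed {
		if nonAlnum.ReplaceAllString(strings.ToLower(a), "") == loose {
			return a // return the canonical spelling, not the model's
		}
	}
	return ""
}

func main() {
	allowed := []string{"web_search", "read_file"}
	fmt.Println(looseResolve("Web-Search", allowed))
}
```

Returning the canonical spelling matters: downstream the call is emitted with the tool name the client declared, not the variant the model hallucinated.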
@@ -16,7 +16,6 @@ type ToolCallParseResult struct {
 	RejectedByPolicy  bool
 	RejectedToolNames []string
 }
-
 func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
 	return ParseToolCallsDetailed(text, availableToolNames).Calls
 }
@@ -26,23 +25,36 @@ func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallPa
 	if strings.TrimSpace(text) == "" {
 		return result
 	}
-	text = stripFencedCodeBlocks(text)
-	if strings.TrimSpace(text) == "" {
-		return result
-	}
-	result.SawToolCallSyntax = strings.Contains(strings.ToLower(text), "tool_calls")
+	result.SawToolCallSyntax = looksLikeToolCallSyntax(text)
 
 	candidates := buildToolCallCandidates(text)
 	var parsed []ParsedToolCall
 	for _, candidate := range candidates {
-		if tc := parseToolCallsPayload(candidate); len(tc) > 0 {
+		tc := parseToolCallsPayload(candidate)
+		if len(tc) == 0 {
+			tc = parseXMLToolCalls(candidate)
+		}
+		if len(tc) == 0 {
+			tc = parseMarkupToolCalls(candidate)
+		}
+		if len(tc) == 0 {
+			tc = parseTextKVToolCalls(candidate)
+		}
+		if len(tc) > 0 {
 			parsed = tc
 			result.SawToolCallSyntax = true
 			break
 		}
 	}
 	if len(parsed) == 0 {
-		return result
+		parsed = parseXMLToolCalls(text)
+		if len(parsed) == 0 {
+			parsed = parseTextKVToolCalls(text)
+			if len(parsed) == 0 {
+				return result
+			}
+		}
+		result.SawToolCallSyntax = true
 	}
@@ -51,7 +63,6 @@ func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallPa
 	result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
 	return result
 }
-
 func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
 	return ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls
 }
@@ -62,28 +73,42 @@ func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string)
 	if trimmed == "" {
 		return result
 	}
-	if looksLikeToolExampleContext(trimmed) {
-		return result
-	}
-	result.SawToolCallSyntax = strings.Contains(strings.ToLower(trimmed), "tool_calls")
-	candidates := []string{trimmed}
+	result.SawToolCallSyntax = looksLikeToolCallSyntax(trimmed)
+	candidates := buildToolCallCandidates(trimmed)
+	var parsed []ParsedToolCall
 	for _, candidate := range candidates {
 		candidate = strings.TrimSpace(candidate)
 		if candidate == "" {
 			continue
 		}
-		if !strings.HasPrefix(candidate, "{") && !strings.HasPrefix(candidate, "[") {
-			continue
+		parsed = parseToolCallsPayload(candidate)
+		if len(parsed) == 0 {
+			parsed = parseXMLToolCalls(candidate)
 		}
-		if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 {
-			result.SawToolCallSyntax = true
-			calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
-			result.Calls = calls
-			result.RejectedToolNames = rejectedNames
-			result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
-			return result
+		if len(parsed) == 0 {
+			parsed = parseMarkupToolCalls(candidate)
+		}
+		if len(parsed) == 0 {
+			parsed = parseTextKVToolCalls(candidate)
+		}
+		if len(parsed) > 0 {
+			break
 		}
 	}
+	if len(parsed) == 0 {
+		parsed = parseXMLToolCalls(trimmed)
+		if len(parsed) == 0 {
+			parsed = parseTextKVToolCalls(trimmed)
+			if len(parsed) == 0 {
+				return result
+			}
+		}
+	}
+	result.SawToolCallSyntax = true
+	calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
+	result.Calls = calls
+	result.RejectedToolNames = rejectedNames
+	result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
 	return result
 }
@@ -103,32 +128,32 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
 	}
 	if len(allowed) == 0 {
 		rejectedSet := map[string]struct{}{}
+		rejected := make([]string, 0, len(parsed))
 		for _, tc := range parsed {
 			if tc.Name == "" {
 				continue
 			}
+			if _, ok := rejectedSet[tc.Name]; ok {
+				continue
+			}
 			rejectedSet[tc.Name] = struct{}{}
-		}
-		rejected := make([]string, 0, len(rejectedSet))
-		for name := range rejectedSet {
-			rejected = append(rejected, name)
+			rejected = append(rejected, tc.Name)
 		}
 		return nil, rejected
 	}
 	out := make([]ParsedToolCall, 0, len(parsed))
 	rejectedSet := map[string]struct{}{}
+	rejected := make([]string, 0)
 	for _, tc := range parsed {
 		if tc.Name == "" {
 			continue
 		}
-		matchedName := ""
-		if _, ok := allowed[tc.Name]; ok {
-			matchedName = tc.Name
-		} else if canonical, ok := allowedCanonical[strings.ToLower(tc.Name)]; ok {
-			matchedName = canonical
-		}
+		matchedName := resolveAllowedToolName(tc.Name, allowed, allowedCanonical)
 		if matchedName == "" {
-			rejectedSet[tc.Name] = struct{}{}
+			if _, ok := rejectedSet[tc.Name]; !ok {
+				rejectedSet[tc.Name] = struct{}{}
+				rejected = append(rejected, tc.Name)
+			}
 			continue
 		}
 		tc.Name = matchedName
@@ -137,17 +162,23 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
 		}
 		out = append(out, tc)
 	}
-	rejected := make([]string, 0, len(rejectedSet))
-	for name := range rejectedSet {
-		rejected = append(rejected, name)
-	}
 	return out, rejected
 }
 
+func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
+	return resolveAllowedToolNameWithLooseMatch(name, allowed, allowedCanonical)
+}
+
 func parseToolCallsPayload(payload string) []ParsedToolCall {
 	var decoded any
 	if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
-		return nil
+		// Try to repair backslashes first! Because LLMs often mix these two problems.
+		repaired := repairInvalidJSONBackslashes(payload)
+		// Try loose repair on top of that
+		repaired = RepairLooseJSON(repaired)
+		if err := json.Unmarshal([]byte(repaired), &decoded); err != nil {
+			return nil
+		}
 	}
 	switch v := decoded.(type) {
 	case map[string]any:
@@ -163,6 +194,15 @@ func parseToolCallsPayload(payload string) []ParsedToolCall {
 		return nil
 	}
 
+func looksLikeToolCallSyntax(text string) bool {
+	lower := strings.ToLower(text)
+	return strings.Contains(lower, "tool_calls") ||
+		strings.Contains(lower, "<tool_call") ||
+		strings.Contains(lower, "<function_call") ||
+		strings.Contains(lower, "<invoke") ||
+		strings.Contains(lower, "function.name:")
+}
+
 func parseToolCallList(v any) []ParsedToolCall {
 	items, ok := v.([]any)
 	if !ok {
@@ -215,32 +255,3 @@ func parseToolCallItem(m map[string]any) (ParsedToolCall, bool) {
 		Input: parseToolCallInput(inputRaw),
 	}, true
 }
-
-func parseToolCallInput(v any) map[string]any {
-	switch x := v.(type) {
|
|
||||||
case nil:
|
|
||||||
return map[string]any{}
|
|
||||||
case map[string]any:
|
|
||||||
return x
|
|
||||||
case string:
|
|
||||||
raw := strings.TrimSpace(x)
|
|
||||||
if raw == "" {
|
|
||||||
return map[string]any{}
|
|
||||||
}
|
|
||||||
var parsed map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(raw), &parsed); err == nil && parsed != nil {
|
|
||||||
return parsed
|
|
||||||
}
|
|
||||||
return map[string]any{"_raw": raw}
|
|
||||||
default:
|
|
||||||
b, err := json.Marshal(x)
|
|
||||||
if err != nil {
|
|
||||||
return map[string]any{}
|
|
||||||
}
|
|
||||||
var parsed map[string]any
|
|
||||||
if err := json.Unmarshal(b, &parsed); err == nil && parsed != nil {
|
|
||||||
return parsed
|
|
||||||
}
|
|
||||||
return map[string]any{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
internal/util/toolcalls_parse_markup.go (new file, 235 lines)
@@ -0,0 +1,235 @@
package util

import (
	"encoding/json"
	"encoding/xml"
	"regexp"
	"strings"
)

var xmlToolCallPattern = regexp.MustCompile(`(?is)<tool_call>\s*(.*?)\s*</tool_call>`)
var functionCallPattern = regexp.MustCompile(`(?is)<function_call>\s*([^<]+?)\s*</function_call>`)
var functionParamPattern = regexp.MustCompile(`(?is)<function\s+parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</function\s+parameter>`)
var antmlFunctionCallPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?function_call[^>]*(?:name|function)="([^"]+)"[^>]*>\s*(.*?)\s*</(?:[a-z0-9_]+:)?function_call>`)
var antmlArgumentPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?argument\s+name="([^"]+)"\s*>\s*(.*?)\s*</(?:[a-z0-9_]+:)?argument>`)
var antmlParametersPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?parameters\s*>\s*(\{.*?\})\s*</(?:[a-z0-9_]+:)?parameters>`)
var invokeCallPattern = regexp.MustCompile(`(?is)<invoke\s+name="([^"]+)"\s*>(.*?)</invoke>`)
var invokeParamPattern = regexp.MustCompile(`(?is)<parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</parameter>`)

func parseXMLToolCalls(text string) []ParsedToolCall {
	matches := xmlToolCallPattern.FindAllString(text, -1)
	out := make([]ParsedToolCall, 0, len(matches)+1)
	for _, block := range matches {
		call, ok := parseSingleXMLToolCall(block)
		if !ok {
			continue
		}
		out = append(out, call)
	}
	if len(out) > 0 {
		return out
	}
	if call, ok := parseFunctionCallTagStyle(text); ok {
		return []ParsedToolCall{call}
	}
	if calls := parseAntmlFunctionCallStyles(text); len(calls) > 0 {
		return calls
	}
	if call, ok := parseInvokeFunctionCallStyle(text); ok {
		return []ParsedToolCall{call}
	}
	return nil
}

func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
	inner := strings.TrimSpace(block)
	inner = strings.TrimPrefix(inner, "<tool_call>")
	inner = strings.TrimSuffix(inner, "</tool_call>")
	inner = strings.TrimSpace(inner)
	if strings.HasPrefix(inner, "{") {
		var payload map[string]any
		if err := json.Unmarshal([]byte(inner), &payload); err == nil {
			name := strings.TrimSpace(asString(payload["tool"]))
			if name == "" {
				name = strings.TrimSpace(asString(payload["tool_name"]))
			}
			if name != "" {
				input := map[string]any{}
				if params, ok := payload["params"].(map[string]any); ok {
					input = params
				} else if params, ok := payload["parameters"].(map[string]any); ok {
					input = params
				}
				return ParsedToolCall{Name: name, Input: input}, true
			}
		}
	}

	dec := xml.NewDecoder(strings.NewReader(block))
	name := ""
	params := map[string]any{}
	inParams := false
	inTool := false
	for {
		tok, err := dec.Token()
		if err != nil {
			break
		}
		switch t := tok.(type) {
		case xml.StartElement:
			tag := strings.ToLower(t.Name.Local)
			switch tag {
			case "tool":
				inTool = true
				for _, attr := range t.Attr {
					if strings.EqualFold(strings.TrimSpace(attr.Name.Local), "name") && strings.TrimSpace(name) == "" {
						name = strings.TrimSpace(attr.Value)
					}
				}
			case "parameters":
				inParams = true
			case "tool_name", "name":
				var v string
				if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
					name = strings.TrimSpace(v)
				}
			case "input", "arguments", "argument", "args", "params":
				var v string
				if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
					if parsed := parseToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 {
						for k, vv := range parsed {
							params[k] = vv
						}
					}
				}
			default:
				if inParams || inTool {
					var v string
					if err := dec.DecodeElement(&v, &t); err == nil {
						params[t.Name.Local] = strings.TrimSpace(v)
					}
				}
			}
		case xml.EndElement:
			tag := strings.ToLower(t.Name.Local)
			if tag == "parameters" {
				inParams = false
			}
			if tag == "tool" {
				inTool = false
			}
		}
	}
	if strings.TrimSpace(name) == "" {
		return ParsedToolCall{}, false
	}
	return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true
}

func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
	m := functionCallPattern.FindStringSubmatch(text)
	if len(m) < 2 {
		return ParsedToolCall{}, false
	}
	name := strings.TrimSpace(m[1])
	if name == "" {
		return ParsedToolCall{}, false
	}
	input := map[string]any{}
	for _, pm := range functionParamPattern.FindAllStringSubmatch(text, -1) {
		if len(pm) < 3 {
			continue
		}
		key := strings.TrimSpace(pm[1])
		val := strings.TrimSpace(pm[2])
		if key != "" {
			input[key] = val
		}
	}
	return ParsedToolCall{Name: name, Input: input}, true
}

func parseAntmlFunctionCallStyles(text string) []ParsedToolCall {
	matches := antmlFunctionCallPattern.FindAllStringSubmatch(text, -1)
	if len(matches) == 0 {
		return nil
	}
	out := make([]ParsedToolCall, 0, len(matches))
	for _, m := range matches {
		if call, ok := parseSingleAntmlFunctionCallMatch(m); ok {
			out = append(out, call)
		}
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
	if len(m) < 3 {
		return ParsedToolCall{}, false
	}
	name := strings.TrimSpace(m[1])
	if name == "" {
		return ParsedToolCall{}, false
	}
	body := strings.TrimSpace(m[2])
	input := map[string]any{}
	if strings.HasPrefix(body, "{") {
		if err := json.Unmarshal([]byte(body), &input); err == nil {
			return ParsedToolCall{Name: name, Input: input}, true
		}
	}
	if pm := antmlParametersPattern.FindStringSubmatch(body); len(pm) >= 2 {
		if err := json.Unmarshal([]byte(strings.TrimSpace(pm[1])), &input); err == nil {
			return ParsedToolCall{Name: name, Input: input}, true
		}
	}
	for _, am := range antmlArgumentPattern.FindAllStringSubmatch(body, -1) {
		if len(am) < 3 {
			continue
		}
		k := strings.TrimSpace(am[1])
		v := strings.TrimSpace(am[2])
		if k != "" {
			input[k] = v
		}
	}
	return ParsedToolCall{Name: name, Input: input}, true
}

func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
	m := invokeCallPattern.FindStringSubmatch(text)
	if len(m) < 3 {
		return ParsedToolCall{}, false
	}
	name := strings.TrimSpace(m[1])
	if name == "" {
		return ParsedToolCall{}, false
	}
	input := map[string]any{}
	for _, pm := range invokeParamPattern.FindAllStringSubmatch(m[2], -1) {
		if len(pm) < 3 {
			continue
		}
		k := strings.TrimSpace(pm[1])
		v := strings.TrimSpace(pm[2])
		if k != "" {
			input[k] = v
		}
	}
	if len(input) == 0 {
		if argsRaw := findMarkupTagValue(m[2], toolCallMarkupArgsTagNames, toolCallMarkupArgsPatternByTag); argsRaw != "" {
			input = parseMarkupInput(argsRaw)
		} else if kv := parseMarkupKVObject(m[2]); len(kv) > 0 {
			input = kv
		}
	}
	return ParsedToolCall{Name: name, Input: input}, true
}

func asString(v any) string {
	s, _ := v.(string)
	return s
}
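For reference, the core extraction step of the markup parser above can be sketched standalone. This is a minimal illustration, not code from the commit: the `toolCallBlock` pattern mirrors `xmlToolCallPattern` (`(?i)` case-insensitive, `(?s)` dot-matches-newline, non-greedy body capture), and the helper name `extractToolCallBody` is invented for the example.

```go
package main

import (
	"fmt"
	"regexp"
)

// Same shape as xmlToolCallPattern above: case-insensitive, dot matches
// newlines, non-greedy capture of everything between the tags.
var toolCallBlock = regexp.MustCompile(`(?is)<tool_call>\s*(.*?)\s*</tool_call>`)

// extractToolCallBody returns the trimmed body of the first <tool_call>
// block in text, or "" when no block is present.
func extractToolCallBody(text string) string {
	m := toolCallBlock.FindStringSubmatch(text)
	if len(m) < 2 {
		return ""
	}
	return m[1]
}

func main() {
	// Tag case differs on purpose: the (?i) flag still matches it.
	text := `prose <TOOL_CALL>{"tool":"Bash","params":{"command":"pwd"}}</TOOL_CALL> prose`
	fmt.Println(extractToolCallBody(text))
}
```

The captured body is then handed to the JSON or XML branch of `parseSingleXMLToolCall`, depending on whether it starts with `{`.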
@@ -1,6 +1,9 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
func TestParseToolCalls(t *testing.T) {
|
func TestParseToolCalls(t *testing.T) {
|
||||||
text := `prefix {"tool_calls":[{"name":"search","input":{"q":"golang"}}]} suffix`
|
text := `prefix {"tool_calls":[{"name":"search","input":{"q":"golang"}}]} suffix`
|
||||||
@@ -19,8 +22,8 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
func TestParseToolCallsFromFencedJSON(t *testing.T) {
|
func TestParseToolCallsFromFencedJSON(t *testing.T) {
|
||||||
text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```"
|
text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```"
|
||||||
calls := ParseToolCalls(text, []string{"search"})
|
calls := ParseToolCalls(text, []string{"search"})
|
||||||
if len(calls) != 0 {
|
if len(calls) != 1 {
|
||||||
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
|
t.Fatalf("expected fenced tool_call payload to be parsed, got %#v", calls)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,10 +99,10 @@ func TestFormatOpenAIToolCalls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) {
|
func TestParseStandaloneToolCallsSupportsMixedProsePayload(t *testing.T) {
|
||||||
mixed := `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
mixed := `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
||||||
if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 0 {
|
if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 1 {
|
||||||
t.Fatalf("expected standalone parser to ignore mixed prose, got %#v", calls)
|
t.Fatalf("expected standalone parser to parse mixed prose payload, got %#v", calls)
|
||||||
}
|
}
|
||||||
|
|
||||||
standalone := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
standalone := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
||||||
@@ -109,9 +112,408 @@ func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) {
|
func TestParseStandaloneToolCallsParsesFencedCodeBlock(t *testing.T) {
|
||||||
fenced := "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```"
|
fenced := "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```"
|
||||||
if calls := ParseStandaloneToolCalls(fenced, []string{"search"}); len(calls) != 0 {
|
if calls := ParseStandaloneToolCalls(fenced, []string{"search"}); len(calls) != 1 {
|
||||||
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
|
t.Fatalf("expected fenced tool_call payload to be parsed, got %#v", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsAllowsQualifiedToolName(t *testing.T) {
|
||||||
|
text := `{"tool_calls":[{"name":"mcp.search_web","input":{"q":"golang"}}]}`
|
||||||
|
calls := ParseToolCalls(text, []string{"search_web"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "search_web" {
|
||||||
|
t.Fatalf("expected canonical tool name search_web, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsAllowsPunctuationVariantToolName(t *testing.T) {
|
||||||
|
text := `{"tool_calls":[{"name":"read-file","input":{"path":"README.md"}}]}`
|
||||||
|
calls := ParseToolCalls(text, []string{"read_file"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "read_file" {
|
||||||
|
t.Fatalf("expected canonical tool name read_file, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsClaudeXMLToolCall(t *testing.T) {
|
||||||
|
text := `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command><description>show cwd</description></parameters></tool_call>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsDetailedMarksXMLToolCallSyntax(t *testing.T) {
|
||||||
|
text := `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command></parameters></tool_call>`
|
||||||
|
res := ParseToolCallsDetailed(text, []string{"bash"})
|
||||||
|
if !res.SawToolCallSyntax {
|
||||||
|
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
|
||||||
|
}
|
||||||
|
if len(res.Calls) != 1 {
|
||||||
|
t.Fatalf("expected one parsed call, got %#v", res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsClaudeXMLJSONToolCall(t *testing.T) {
|
||||||
|
text := `<tool_call>{"tool":"Bash","params":{"command":"pwd","description":"show cwd"}}</tool_call>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsFunctionCallTagStyle(t *testing.T) {
|
||||||
|
text := `<function_call>Bash</function_call><function parameter name="command">ls -la</function parameter><function parameter name="description">list</function parameter>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "ls -la" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsAntmlFunctionCallStyle(t *testing.T) {
|
||||||
|
text := `<antml:function_calls><antml:function_call name="Bash">{"command":"pwd","description":"x"}</antml:function_call></antml:function_calls>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsAntmlArgumentStyle(t *testing.T) {
|
||||||
|
text := `<antml:function_calls><antml:function_call id="1" name="Bash"><antml:argument name="command">pwd</antml:argument><antml:argument name="description">x</antml:argument></antml:function_call></antml:function_calls>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsInvokeFunctionCallStyle(t *testing.T) {
|
||||||
|
text := `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter><parameter name="description">d</parameter></invoke></function_calls>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsNestedToolTagStyle(t *testing.T) {
|
||||||
|
text := `<tool_call><tool name="Bash"><command>pwd</command><description>show cwd</description></tool></tool_call>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsAntmlFunctionAttributeWithParametersTag(t *testing.T) {
|
||||||
|
text := `<antml:function_calls><antml:function_call id="x" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call></antml:function_calls>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsMultipleAntmlFunctionCalls(t *testing.T) {
|
||||||
|
text := `<antml:function_calls><antml:function_call id="1" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call><antml:function_call id="2" function="Read"><antml:parameters>{"file_path":"README.md"}</antml:parameters></antml:function_call></antml:function_calls>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash", "read"})
|
||||||
|
if len(calls) != 2 {
|
||||||
|
t.Fatalf("expected 2 calls, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" || calls[1].Name != "read" {
|
||||||
|
t.Fatalf("expected canonical names [bash read], got %#v", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsDoesNotAcceptMismatchedMarkupTags(t *testing.T) {
|
||||||
|
text := `<tool_call><name>read_file</function><arguments>{"path":"README.md"}</arguments></tool_call>`
|
||||||
|
calls := ParseToolCalls(text, []string{"read_file"})
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected mismatched tags to be rejected, got %#v", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRepairInvalidJSONBackslashes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{`{"path": "C:\Users\name"}`, `{"path": "C:\\Users\name"}`},
|
||||||
|
{`{"cmd": "cd D:\git_codes"}`, `{"cmd": "cd D:\\git_codes"}`},
|
||||||
|
{`{"text": "line1\nline2"}`, `{"text": "line1\nline2"}`},
|
||||||
|
{`{"path": "D:\\back\\slash"}`, `{"path": "D:\\back\\slash"}`},
|
||||||
|
{`{"unicode": "\u2705"}`, `{"unicode": "\u2705"}`},
|
||||||
|
{`{"invalid_u": "\u123"}`, `{"invalid_u": "\\u123"}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := repairInvalidJSONBackslashes(tt.input)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("repairInvalidJSONBackslashes(%s) = %s; want %s", tt.input, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRepairLooseJSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{`{tool_calls: [{"name": "search", "input": {"q": "go"}}]}`, `{"tool_calls": [{"name": "search", "input": {"q": "go"}}]}`},
|
||||||
|
{`{name: "search", input: {q: "go"}}`, `{"name": "search", "input": {"q": "go"}}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := RepairLooseJSON(tt.input)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("RepairLooseJSON(%s) = %s; want %s", tt.input, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsWithUnquotedKeys(t *testing.T) {
|
||||||
|
text := `这里是列表:{tool_calls: [{"name": "todowrite", "input": {"todos": "test"}}]}`
|
||||||
|
availableTools := []string{"todowrite"}
|
||||||
|
|
||||||
|
parsed := ParseToolCalls(text, availableTools)
|
||||||
|
if len(parsed) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool call, got %d", len(parsed))
|
||||||
|
}
|
||||||
|
if parsed[0].Name != "todowrite" {
|
||||||
|
t.Errorf("expected tool todowrite, got %s", parsed[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsWithInvalidBackslashes(t *testing.T) {
|
||||||
|
// DeepSeek sometimes outputs Windows paths with single backslashes in JSON strings
|
||||||
|
// Note: using raw string to simulate what AI actually sends in the stream
|
||||||
|
text := `好的,执行以下命令:{"name": "execute_command", "input": "{\"command\": \"cd D:\git_codes && dir\"}"}`
|
||||||
|
availableTools := []string{"execute_command"}
|
||||||
|
|
||||||
|
parsed := ParseToolCalls(text, availableTools)
|
||||||
|
// If standard JSON fails, buildToolCallCandidates should still extract the object,
|
||||||
|
// and parseToolCallsPayload should repair it.
|
||||||
|
if len(parsed) != 1 {
|
||||||
|
// If it still fails, let's see why
|
||||||
|
candidates := buildToolCallCandidates(text)
|
||||||
|
t.Logf("Candidates: %v", candidates)
|
||||||
|
t.Fatalf("expected 1 tool call, got %d", len(parsed))
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd, ok := parsed[0].Input["command"].(string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected command string in input, got %v", parsed[0].Input)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := "cd D:\\git_codes && dir"
|
||||||
|
if cmd != expected {
|
||||||
|
t.Errorf("expected command %q, got %q", expected, cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsWithDeepSeekHallucination(t *testing.T) {
|
||||||
|
// 模拟 DeepSeek 典型的幻觉输出:未加引号的键名 + 包含 Windows 路径的嵌套 JSON 字符串 + 漏掉列表的方括号
|
||||||
|
text := `检测到实施意图——实现经典算法。需在misc/目录创建Python文件。
|
||||||
|
关键约束:
|
||||||
|
1. Windows UTF-8编码处理
|
||||||
|
2. 必须用绝对路径导入
|
||||||
|
3. 禁止write覆盖已有文件(misc/目录允许创建新文件)
|
||||||
|
将任务分解并委托:
|
||||||
|
- 研究8皇后算法模式(并行探索)
|
||||||
|
- 实现带可视化输出的解决方案(unspecified-high)
|
||||||
|
先创建todo列表追踪步骤。
|
||||||
|
{tool_calls: [{"name": "todowrite", "input": {"todos": {"content": "研究8皇后问题算法模式(回溯法)和输出格式", "status": "pending", "priority": "high"}, {"content": "在misc/目录创建8皇后Python脚本,包含完整解决方案和可视化输出", "status": "pending", "priority": "high"}, {"content": "验证脚本正确性(运行测试)", "status": "pending", "priority": "medium"}}}]}`
|
||||||
|
|
||||||
|
availableTools := []string{"todowrite"}
|
||||||
|
parsed := ParseToolCalls(text, availableTools)
|
||||||
|
|
||||||
|
if len(parsed) != 1 {
|
||||||
|
cands := buildToolCallCandidates(text)
|
||||||
|
for i, c := range cands {
|
||||||
|
t.Logf("CAND %d: %s", i, c)
|
||||||
|
repaired := RepairLooseJSON(c)
|
||||||
|
t.Logf(" REPAIRED: %s", repaired)
|
||||||
|
}
|
||||||
|
t.Fatalf("expected 1 tool call, got %d. Candidates: %v", len(parsed), buildToolCallCandidates(text))
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsed[0].Name != "todowrite" {
|
||||||
|
t.Errorf("expected tool name 'todowrite', got %q", parsed[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
todos, ok := parsed[0].Input["todos"].([]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected 'todos' to be parsed as a list, got %T: %#v", parsed[0].Input["todos"], parsed[0].Input["todos"])
|
||||||
|
}
|
||||||
|
if len(todos) != 3 {
|
||||||
|
t.Errorf("expected 3 todo items, got %d", len(todos))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsWithMixedWindowsPaths(t *testing.T) {
|
||||||
|
// 更复杂的案例:嵌套 JSON 字符串中的反斜杠未转义
|
||||||
|
text := `关键约束: 1. Windows UTF-8编码处理 2. 必须用绝对路径导入 D:\git_codes\ds2api\misc
|
||||||
|
{tool_calls: [{"name": "write_file", "input": "{\"path\": \"D:\\git_codes\\ds2api\\misc\\queens.py\", \"content\": \"print('hello')\"}"}]}`
|
||||||
|
|
||||||
|
availableTools := []string{"write_file"}
|
||||||
|
parsed := ParseToolCalls(text, availableTools)
|
||||||
|
|
||||||
|
if len(parsed) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool call from mixed text with paths, got %d", len(parsed))
|
||||||
|
}
|
||||||
|
|
||||||
|
path, _ := parsed[0].Input["path"].(string)
|
||||||
|
// 在解析后的 Go map 中,反斜杠应该被还原
|
||||||
|
if !strings.Contains(path, "D:\\git_codes") && !strings.Contains(path, "D:/git_codes") {
|
||||||
|
t.Errorf("expected path to contain Windows style separators, got %q", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallInputRepairsControlCharsInPath(t *testing.T) {
|
||||||
|
in := `{"path":"D:\tmp\new\readme.txt","content":"line1\nline2"}`
|
||||||
|
parsed := parseToolCallInput(in)
|
||||||
|
|
||||||
|
path, ok := parsed["path"].(string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected path string in parsed input, got %#v", parsed["path"])
|
||||||
|
}
|
||||||
|
if path != `D:\tmp\new\readme.txt` {
|
||||||
|
t.Fatalf("expected repaired windows path, got %q", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, ok := parsed["content"].(string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected content string in parsed input, got %#v", parsed["content"])
|
||||||
|
}
|
||||||
|
if content != "line1\nline2" {
|
||||||
|
t.Fatalf("expected non-path field to keep decoded escapes, got %q", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRepairLooseJSONWithNestedObjects(t *testing.T) {
	// Repairing nested objects: DeepSeek hallucinated output in which each
	// element itself contains a nested {}.
	// Note: the regex only supports a single level of nesting, nothing deeper.
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		// 1. Single-level nested objects (the core repair target)
		{
			name:     "single-level nesting - 2 elements",
			input:    `"todos": {"content": "研究算法", "input": {"q": "8 queens"}}, {"content": "实现", "input": {"path": "queens.py"}}`,
			expected: `"todos": [{"content": "研究算法", "input": {"q": "8 queens"}}, {"content": "实现", "input": {"path": "queens.py"}}]`,
		},
		// 2. Three single-level nested objects
		{
			name:     "3 single-level nested objects",
			input:    `"items": {"a": {"x":1}}, {"b": {"y":2}}, {"c": {"z":3}}`,
			expected: `"items": [{"a": {"x":1}}, {"b": {"y":2}}, {"c": {"z":3}}]`,
		},
		// 3. Mixed nesting: some fields are objects, others primitive values
		{
			name:     "mixed nesting - objects and primitives",
			input:    `"items": {"name": "test", "config": {"timeout": 30}}, {"name": "test2", "config": {"timeout": 60}}`,
			expected: `"items": [{"name": "test", "config": {"timeout": 30}}, {"name": "test2", "config": {"timeout": 60}}]`,
		},
		// 4. Four nested objects (boundary test)
		{
			name:     "4 nested objects",
			input:    `"todos": {"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}`,
			expected: `"todos": [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}]`,
		},
		// 5. Typical DeepSeek hallucination: comma-separated without spaces
		{
			name:     "comma-separated without spaces",
			input:    `"results": {"name": "a"}, {"name": "b"}, {"name": "c"}`,
			expected: `"results": [{"name": "a"}, {"name": "b"}, {"name": "c"}]`,
		},
		// 6. Nested array (an array inside an object, not deep object nesting)
		{
			name:     "object containing an array",
			input:    `"data": {"items": [1,2,3]}, {"items": [4,5,6]}`,
			expected: `"data": [{"items": [1,2,3]}, {"items": [4,5,6]}]`,
		},
		// 7. Real DeepSeek output for the eight queens problem
		{
			name:     "real DeepSeek eight queens output",
			input:    `"todos": {"content": "研究8皇后算法", "status": "pending"}, {"content": "实现Python脚本", "status": "pending"}, {"content": "验证结果", "status": "pending"}`,
			expected: `"todos": [{"content": "研究8皇后算法", "status": "pending"}, {"content": "实现Python脚本", "status": "pending"}, {"content": "验证结果", "status": "pending"}]`,
		},
		// 8. Simple objects without nesting (regression test)
		{
			name:     "simple objects without nesting",
			input:    `"items": {"a": 1}, {"b": 2}`,
			expected: `"items": [{"a": 1}, {"b": 2}]`,
		},
		// 9. More complex single-level nesting
		{
			name:     "complex single-level nesting",
			input:    `"functions": {"name": "execute", "input": {"command": "ls"}}, {"name": "read", "input": {"file": "a.txt"}}`,
			expected: `"functions": [{"name": "execute", "input": {"command": "ls"}}, {"name": "read", "input": {"file": "a.txt"}}]`,
		},
		// 10. Five nested objects
		{
			name:     "5 nested objects",
			input:    `"tasks": {"id":1}, {"id":2}, {"id":3}, {"id":4}, {"id":5}`,
			expected: `"tasks": [{"id":1}, {"id":2}, {"id":3}, {"id":4}, {"id":5}]`,
		},
	}

	for _, tt := range tests {
		got := RepairLooseJSON(tt.input)
		if got != tt.expected {
			t.Errorf("[%s] RepairLooseJSON with nested objects:\n  input:    %s\n  got:      %s\n  expected: %s", tt.name, tt.input, got, tt.expected)
		}
	}
}
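`RepairLooseJSON` itself is not part of this hunk; under the single-level-nesting constraint the tests describe, a hypothetical regex-based repair could look like the sketch below. The names and the pattern are assumptions for illustration, not the repository's actual implementation, and (as the tests note) braces inside string values or deeper nesting would defeat the pattern.

```go
package main

import (
	"fmt"
	"regexp"
)

// looseListPattern (hypothetical) matches a key followed by two or more
// comma-separated objects, each allowing at most one level of inner {}.
var looseListPattern = regexp.MustCompile(
	`("[^"]+"):\s*(\{(?:[^{}]|\{[^{}]*\})*\}(?:\s*,\s*\{(?:[^{}]|\{[^{}]*\})*\})+)`)

// repairLooseJSONSketch wraps the bare object list in [ ... ].
func repairLooseJSONSketch(s string) string {
	return looseListPattern.ReplaceAllString(s, `${1}: [${2}]`)
}

func main() {
	fmt.Println(repairLooseJSONSketch(`"items": {"a": 1}, {"b": 2}`))
	// → "items": [{"a": 1}, {"b": 2}]
}
```

A single object after the key does not match (the `+` requires at least one comma-separated follower), so already-valid JSON passes through untouched.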

55  internal/util/toolcalls_textkv.go  Normal file
@@ -0,0 +1,55 @@
package util

import (
	"regexp"
	"strings"
)

var textKVNamePattern = regexp.MustCompile(`(?is)function\.name:\s*([a-zA-Z0-9_\-.]+)`)

func parseTextKVToolCalls(text string) []ParsedToolCall {
	var out []ParsedToolCall
	matches := textKVNamePattern.FindAllStringSubmatchIndex(text, -1)
	if len(matches) == 0 {
		return nil
	}

	for i, match := range matches {
		name := text[match[2]:match[3]]

		offset := match[1]
		endSearch := len(text)
		if i+1 < len(matches) {
			endSearch = matches[i+1][0]
		}

		searchArea := text[offset:endSearch]
		argIdx := strings.Index(searchArea, "function.arguments:")
		if argIdx < 0 {
			continue
		}

		startIdx := offset + argIdx + len("function.arguments:")
		braceIdx := strings.IndexByte(text[startIdx:endSearch], '{')
		if braceIdx < 0 {
			continue
		}

		actualStart := startIdx + braceIdx
		objJson, _, ok := extractJSONObject(text, actualStart)
		if !ok {
			continue
		}

		input := parseToolCallInput(objJson)
		out = append(out, ParsedToolCall{
			Name:  name,
			Input: input,
		})
	}

	if len(out) == 0 {
		return nil
	}
	return out
}
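`extractJSONObject` is referenced above but defined elsewhere in the package and not shown in this diff. A self-contained, hypothetical sketch of what such a brace-balancing helper typically does (the name and signature here are assumptions modeled on the call site):

```go
package main

import "fmt"

// extractJSONObjectSketch returns the balanced {...} starting at start, the
// index just past it, and whether extraction succeeded. It tracks string
// literals so braces inside values do not skew the depth count.
func extractJSONObjectSketch(s string, start int) (string, int, bool) {
	if start >= len(s) || s[start] != '{' {
		return "", start, false
	}
	depth, inStr := 0, false
	for i := start; i < len(s); i++ {
		c := s[i]
		if inStr {
			switch c {
			case '\\':
				i++ // skip the escaped character
			case '"':
				inStr = false
			}
			continue
		}
		switch c {
		case '"':
			inStr = true
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return s[start : i+1], i + 1, true
			}
		}
	}
	return "", start, false // unbalanced input
}

func main() {
	obj, next, ok := extractJSONObjectSketch(`{"a": {"b": "}"}} tail`, 0)
	fmt.Println(obj, next, ok)
}
```

Being string-aware matters here: the `function.arguments` payloads above may contain `{`, `}`, or `"` inside quoted values, and a naive depth counter would terminate the object early.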
63  internal/util/toolcalls_textkv_test.go  Normal file
@@ -0,0 +1,63 @@
package util

import (
	"testing"
)

func TestParseTextKVToolCalls_Basic(t *testing.T) {
	text := `
[TOOL_CALL_HISTORY]
status: already_called
origin: assistant
not_user_input: true
tool_call_id: call_3fcd15235eb94f7eae3a8de5a9cfa36b
function.name: execute_command
function.arguments: {"command":"cd scripts && python check_syntax.py example.py","cwd":null,"timeout":30}
[/TOOL_CALL_HISTORY]

Some other text thinking...
`
	calls := ParseToolCalls(text, []string{"execute_command"})
	if len(calls) != 1 {
		t.Fatalf("expected 1 call, got %d", len(calls))
	}
	if calls[0].Name != "execute_command" {
		t.Fatalf("unexpected name: %s", calls[0].Name)
	}
	if calls[0].Input["command"] != "cd scripts && python check_syntax.py example.py" {
		t.Fatalf("unexpected command arg: %v", calls[0].Input["command"])
	}
}

func TestParseTextKVToolCalls_Multiple(t *testing.T) {
	text := `
function.name: read_file
function.arguments: {
  "path": "abc.txt"
}

function.name: bash
function.arguments: {"command": "ls"}
`
	calls := ParseToolCalls(text, []string{"read_file", "bash"})
	if len(calls) != 2 {
		t.Fatalf("expected 2 calls, got %d", len(calls))
	}
	if calls[0].Name != "read_file" {
		t.Fatalf("unexpected 1st name: %s", calls[0].Name)
	}
	if calls[1].Name != "bash" {
		t.Fatalf("unexpected 2nd name: %s", calls[1].Name)
	}
}

func TestParseTextKVToolCalls_Standalone(t *testing.T) {
	text := "function.name: read_file\nfunction.arguments: {\"path\":\"README.md\"}"
	calls := ParseStandaloneToolCalls(text, []string{"read_file"})
	if len(calls) != 1 {
		t.Fatalf("expected 1 call, got %d", len(calls))
	}
	if calls[0].Name != "read_file" {
		t.Fatalf("unexpected name: %s", calls[0].Name)
	}
}
@@ -409,8 +409,8 @@ func TestParseToolCallsWithFunctionWrapper(t *testing.T) {
 func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) {
 	fenced := "Here's an example:\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\nDon't execute this."
 	calls := ParseStandaloneToolCalls(fenced, []string{"search"})
-	if len(calls) != 0 {
-		t.Fatalf("expected fenced code block ignored, got %d calls", len(calls))
+	if len(calls) != 1 {
+		t.Fatalf("expected fenced code block to be parsed, got %d calls", len(calls))
 	}
 }

101  misc/deepseek_functioncalling_bug/report.md  Normal file
@@ -0,0 +1,101 @@
# DeepSeek Function Calling Defect Analysis and ds2api's Enhanced Repair Strategy

> **Related PRs**: #74 (core implementation) and #75 (merge into dev)
> **Background**: Some models, DeepSeek included, emit function-calling (tool-call) output that is not well-formed enough, causing tool calls to fail. This report analyses the defects and the fix.

## 1. Architecture Comparison: Why Do Function Calling Defects Arise?

Before examining the defects, we need to understand the structural difference between two function-calling designs:

### 1. OpenAI's Native Structured Response (API-Level Separation)

In the OpenAI specification, **chat text and tool calls are hard-separated in the underlying JSON structure**:

* Conversational filler lives in `response.choices[0].message.content`.
* Tool requests live in a separate array, `response.choices[0].message.tool_calls`.

**Advantage:** This design is extremely client-friendly. The client only needs to check whether `tool_calls` is empty to decide between executing code and rendering text. It supports several concurrent tool requests, and the underlying generation engine is strictly trained and constrained, so it very rarely emits syntactically invalid JSON.
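For reference, an abbreviated chat-completion response carrying one tool call has roughly this shape (field values here are illustrative, not taken from this repository):

```json
{
  "choices": [
    {
      "message": {
        "role": "assistant",
        "content": null,
        "tool_calls": [
          {
            "id": "call_abc123",
            "type": "function",
            "function": {
              "name": "read_file",
              "arguments": "{\"path\": \"README.md\"}"
            }
          }
        ]
      },
      "finish_reason": "tool_calls"
    }
  ]
}
```

Note that `content` is `null` when the turn is a pure tool request: the separation is structural, not a text-parsing convention.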
### 2. The "Single Text Stream" Mechanism of DeepSeek-Style Models

By contrast, some models that lack deep task-specific fine-tuning (or certain adaptation layers in the communication path) still tend to pack everything into one plain-text stream. That is why their output often contains not only the JSON that belongs in the `tool_calls` structure but also, dutifully mixed in, the prose that belongs in `content`.

---
## 2. DeepSeek's Specific Function Calling Defects

Compared with OpenAI's strict adherence to the API contract, DeepSeek and similar open-source/domestic reasoning models frequently exhibit three typical kinds of non-conforming output when calling tools:

### 1. Mixed Output: Prose Interleaved with Tool JSON (Mixed Prose Streams)

When the application asks the model to return a tool request directly, DeepSeek sometimes **cannot resist chatting with the user**.
It often prepends an explanatory preamble, inserts the tool-call JSON in the middle, and appends a closing remark:

```text
好的，我这就帮你读取 README.md 的内容：
{"tool_calls":[{"name":"read_file","input":{"path":"README.md"}}]}
请稍等片刻，我马上把它读出来。
```

**Pain point in the old system:**
The previous code enforced a **strict-mode** check:

```go
// If any non-empty text surrounds the parsed JSON block, give up treating it as a tool call!
if strings.TrimSpace(state.recentTextTail) != "" || strings.TrimSpace(prefixPart) != "" ... {
    return captured, nil, "", true
}
```

As a result, output shaped like the above was classified by the gateway as ordinary chat and returned to the user verbatim, which killed the downstream automatic tool execution flow.
### 2. Tool-Name Hallucination: Renaming or Prefixing Tool Names

Because DeepSeek's pre-training data contains large amounts of code from many different platforms, it often fails to reproduce the plain tool name given in the system prompt (i.e. `name: "read_file"`) faithfully, adding prefixes or altering the spelling instead, for example:

* `{"name": "mcp.search_web"}` (namespace prefix added)
* `{"name": "tools.read_file"}`
* `{"name": "search-web"}` (underscore turned into a hyphen)

**Pain point in the old system:**
The old system matched tool names almost exclusively by exact dictionary-level comparison; a single differing character or an added prefix meant no legal tool was found, and the call failed outright.

### 3. Non-Standard `role` Values

In some tool-communication responses, the `role` of the returned content was not normalized: it could carry unexpected attributes or clash with strict comparisons downstream.
---

## 3. The Enhanced Repair Strategy in PR #74

To cope with these non-conforming model behaviors, PR #74 wires an **extremely tolerant error-recovery engine** into the gateway's middle layer. Instead of forcing the model to reform, it adds three enhancements:

### 1. Separating Mixed Content in the Stream (Strict Mode Removed)

`internal/adapter/openai/tool_sieve_core.go` was modified.
The logic that rejected JSON wrapped in surrounding text was removed. When the system spots a complete `{"tool_calls":...}` in the streamed structure, it routes the prose and the JSON into different event streams:

```go
if prefix != "" {
    // Peel off the leading "OK, let me read that file for you" as regular text output
    state.noteText(prefix)
    events = append(events, toolStreamEvent{Content: prefix})
}
// Capture and intercept the tool request in the middle for background execution
state.pendingToolCalls = calls
```

**Effect:** The user's screen shows only normal conversational text, while the backend tool is invoked immediately.
### 2. A Multi-Level Tolerant Matching Engine (Resolve Allowed Tool Name)

`internal/util/toolcalls_parse.go` gains `resolveAllowedToolName`, a funnel of matching strategies that degrade from strict to loose:

1. **Exact match**: as before, `read_file` == `read_file`.
2. **Case-insensitive match**: `Read_File` counts as legal.
3. **Namespace stripping**: the prefix up to the last `.` is removed, restoring `mcp.search_web` to the real `search_web`.
4. **Final regex scrub**:
   Introduces ``var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`)``.
   This regex strips every symbol, space, and formatting character from the string.
   An incoming `read-file` is scrubbed down to `readfile` and compared against equally scrubbed versions of every legal tool in the system. As long as the core letters agree, the name resolves.
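The diff for `resolveAllowedToolName` itself is not reproduced in this report; a minimal, hypothetical sketch of the four-step funnel described above (function and helper names are assumptions) might look like:

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`)

// looseKey lowercases a name and drops every non-alphanumeric character.
func looseKey(name string) string {
	return toolNameLoosePattern.ReplaceAllString(strings.ToLower(name), "")
}

func resolveAllowedToolNameSketch(name string, allowed []string) (string, bool) {
	for _, a := range allowed { // 1. exact match
		if name == a {
			return a, true
		}
	}
	for _, a := range allowed { // 2. case-insensitive match
		if strings.EqualFold(name, a) {
			return a, true
		}
	}
	// 3. strip a namespace prefix: keep only what follows the last '.'
	if i := strings.LastIndex(name, "."); i >= 0 && i+1 < len(name) {
		if r, ok := resolveAllowedToolNameSketch(name[i+1:], allowed); ok {
			return r, true
		}
	}
	// 4. loose scrub: compare alphanumeric cores only
	key := looseKey(name)
	for _, a := range allowed {
		if looseKey(a) == key {
			return a, true
		}
	}
	return "", false
}

func main() {
	allowed := []string{"read_file", "search_web"}
	for _, n := range []string{"Read_File", "mcp.search_web", "search-web", "delete_all"} {
		r, ok := resolveAllowedToolNameSketch(n, allowed)
		fmt.Printf("%-16s -> %q %v\n", n, r, ok)
	}
}
```

Resolving always returns the canonical name from the allow list, so downstream execution never sees the hallucinated variant.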
### 3. Role Normalization (normalizeOpenAIRoleForPrompt)

`internal/adapter/openai/responses_input_items.go` and related files introduce a dedicated `normalizeOpenAIRoleForPrompt(role)` scrub, guaranteeing that the role enum accepted and passed upstream stays within a controlled set and eliminating crashes caused by unexpected role values.

---

## Conclusion: What tool_sieve Essentially Does

PR #74 / #75 does not operate on the model itself; it follows the design philosophy that **the gateway should be robust enough**.

**The whole enhancement effectively implements a middle-layer gateway named `tool_sieve`.**
Faced with DeepSeek's non-standard stream, a lump of dough mixing chat text with JSON, `tool_sieve` acts as a diligent, high-precision sieve that works the lump apart:

1. It picks out the prose and puts it back into the standard `content` field for display;
2. It peels off and cleans the imperfect JSON blocks and carefully places them into the `tool_calls` structure, in OpenAI's standard format, to await execution.

This means that even if the model has been given odd reply settings, emphasizes text with bold markup, or makes small punctuation or spelling mistakes, **as long as it emits a JSON core that can be assembled into a tool instruction, the relay layer can salvage it and deliver a correct tool result to the model and the user**. This not only fixes the defect but greatly improves the tool gateway's generality and robustness.
@@ -16,7 +16,6 @@ internal/js/helpers/stream-tool-sieve.js
 internal/js/helpers/stream-tool-sieve/index.js
 internal/js/helpers/stream-tool-sieve/state.js
 internal/js/helpers/stream-tool-sieve/sieve.js
-internal/js/helpers/stream-tool-sieve/incremental.js
 internal/js/helpers/stream-tool-sieve/jsonscan.js
 internal/js/helpers/stream-tool-sieve/parse.js
 internal/js/helpers/stream-tool-sieve/format.js

@@ -1,6 +1,8 @@
 # Line gate targets for large-file decoupling refactor.
-# Default limit: 300 lines
+# Backend default limit: 300 lines
+# Frontend (webui/) default limit: 500 lines
 # Entry/facade limit: 120 lines (enforced in script)
+# Test files are ignored by the gate script.

 internal/config/config.go
 internal/config/logger.go
@@ -105,7 +107,6 @@ internal/js/helpers/stream-tool-sieve.js
 internal/js/helpers/stream-tool-sieve/index.js
 internal/js/helpers/stream-tool-sieve/state.js
 internal/js/helpers/stream-tool-sieve/sieve.js
-internal/js/helpers/stream-tool-sieve/incremental.js
 internal/js/helpers/stream-tool-sieve/jsonscan.js
 internal/js/helpers/stream-tool-sieve/parse.js
 internal/js/helpers/stream-tool-sieve/format.js

@@ -2,10 +2,11 @@

 ## Rules

-1. Production file default upper bound: `<= 300` lines.
-2. Entry/facade files upper bound: `<= 120` lines.
-3. Scope is limited to target files in `plans/refactor-line-gate-targets.txt`.
-4. Test files are out of scope for this gate.
+1. Backend production files upper bound: `<= 300` lines.
+2. Frontend (`webui/`) production files upper bound: `<= 500` lines.
+3. Entry/facade files upper bound: `<= 120` lines.
+4. Scope is limited to target files in `plans/refactor-line-gate-targets.txt`.
+5. Test files are out of scope for this gate.

 ## Command

8  tests/compat/expected/toolcalls_allowlist_empty.json  Normal file
@@ -0,0 +1,8 @@
{
  "calls": [],
  "sawToolCallSyntax": true,
  "rejectedByPolicy": true,
  "rejectedToolNames": [
    "unknown_tool"
  ]
}
@@ -0,0 +1,13 @@
{
  "calls": [
    {
      "name": "read_file",
      "input": {
        "path": "README.MD"
      }
    }
  ],
  "sawToolCallSyntax": true,
  "rejectedByPolicy": false,
  "rejectedToolNames": []
}
@@ -1,3 +1,13 @@
 {
-  "calls": []
+  "calls": [
+    {
+      "name": "read_file",
+      "input": {
+        "path": "README.MD"
+      }
+    }
+  ],
+  "sawToolCallSyntax": true,
+  "rejectedByPolicy": false,
+  "rejectedToolNames": []
 }

13  tests/compat/expected/toolcalls_function_call_tag.json  Normal file
@@ -0,0 +1,13 @@
{
  "calls": [
    {
      "name": "read_file",
      "input": {
        "path": "README.MD"
      }
    }
  ],
  "sawToolCallSyntax": true,
  "rejectedByPolicy": false,
  "rejectedToolNames": []
}
13  tests/compat/expected/toolcalls_invoke_attr.json  Normal file
@@ -0,0 +1,13 @@
{
  "calls": [
    {
      "name": "read_file",
      "input": {
        "path": "README.MD"
      }
    }
  ],
  "sawToolCallSyntax": true,
  "rejectedByPolicy": false,
  "rejectedToolNames": []
}
13  tests/compat/expected/toolcalls_loose_normalize.json  Normal file
@@ -0,0 +1,13 @@
{
  "calls": [
    {
      "name": "read_file",
      "input": {
        "path": "README.MD"
      }
    }
  ],
  "sawToolCallSyntax": true,
  "rejectedByPolicy": false,
  "rejectedToolNames": []
}
@@ -0,0 +1,13 @@
{
  "calls": [
    {
      "name": "read_file",
      "input": {
        "path": "README.MD"
      }
    }
  ],
  "sawToolCallSyntax": true,
  "rejectedByPolicy": false,
  "rejectedToolNames": []
}
@@ -0,0 +1,13 @@
{
  "calls": [
    {
      "name": "read_file",
      "input": {
        "path": "README.MD"
      }
    }
  ],
  "sawToolCallSyntax": true,
  "rejectedByPolicy": false,
  "rejectedToolNames": []
}
Some files were not shown because too many files have changed in this diff.