mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 00:45:29 +08:00
Compare commits
48 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2674487c7 | ||
|
|
71cdcb43e8 | ||
|
|
9c46c3a874 | ||
|
|
12d5f136d5 | ||
|
|
00c37d8d2f | ||
|
|
0f1985af4a | ||
|
|
fa8affe1b7 | ||
|
|
c59a0b7799 | ||
|
|
bd72b91f27 | ||
|
|
9240f85246 | ||
|
|
ea4bd1e483 | ||
|
|
9e0de62707 | ||
|
|
128de290db | ||
|
|
286d266723 | ||
|
|
8aad1005b2 | ||
|
|
11b2f24fc2 | ||
|
|
d1f08cbb89 | ||
|
|
60e9d707d4 | ||
|
|
9b93badb57 | ||
|
|
892213071a | ||
|
|
5484d6e59d | ||
|
|
0ce3fd22a7 | ||
|
|
25e40cc3a6 | ||
|
|
af68d21095 | ||
|
|
1fafd25e86 | ||
|
|
5f8f28a943 | ||
|
|
94cf1bfcc7 | ||
|
|
13562cf521 | ||
|
|
d27e700c4f | ||
|
|
d6bce5af93 | ||
|
|
75969e710d | ||
|
|
6c39c8e191 | ||
|
|
0e261ff0a0 | ||
|
|
fab326eca1 | ||
|
|
c033eceee7 | ||
|
|
a10e03ebe0 | ||
|
|
a6aa4a1839 | ||
|
|
1c749b6803 | ||
|
|
c329bf26b6 | ||
|
|
3ae5b57ebe | ||
|
|
0bf5d5440c | ||
|
|
d731a1fd4f | ||
|
|
93e9fb531d | ||
|
|
6daeb2553d | ||
|
|
321b8a89ee | ||
|
|
d84875e466 | ||
|
|
ea8c9a28a9 | ||
|
|
a302fb3c25 |
7
API.md
7
API.md
@@ -284,6 +284,11 @@ data: [DONE]
|
|||||||
|
|
||||||
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。
|
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。
|
||||||
|
|
||||||
|
补充说明:
|
||||||
|
|
||||||
|
- **非代码块上下文**下,工具 JSON 即使与普通文本混合,也会按特征识别并产出可执行 tool call(前后普通文本仍可透传)。
|
||||||
|
- Markdown fenced code block(例如 ```json ... ```)中的 `tool_calls` 仅视为示例文本,不会被执行。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### `GET /v1/models/{id}`
|
### `GET /v1/models/{id}`
|
||||||
@@ -301,7 +306,7 @@ OpenAI Responses 风格接口,兼容 `input` 或 `messages`。
|
|||||||
| `messages` | array | ❌ | 与 `input` 二选一 |
|
| `messages` | array | ❌ | 与 `input` 二选一 |
|
||||||
| `instructions` | string | ❌ | 自动前置为 system 消息 |
|
| `instructions` | string | ❌ | 自动前置为 system 消息 |
|
||||||
| `stream` | boolean | ❌ | 默认 `false` |
|
| `stream` | boolean | ❌ | 默认 `false` |
|
||||||
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 |
|
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略(含代码块示例豁免) |
|
||||||
| `tool_choice` | string/object | ❌ | 支持 `auto`/`none`/`required` 与强制函数(`{"type":"function","name":"..."}`) |
|
| `tool_choice` | string/object | ❌ | 支持 `auto`/`none`/`required` 与强制函数(`{"type":"function","name":"..."}`) |
|
||||||
|
|
||||||
**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。
|
**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。
|
||||||
|
|||||||
11
README.MD
11
README.MD
@@ -1,5 +1,5 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/ds2api-icon.svg" width="128" height="128" alt="DS2API icon" />
|
<img src="webui/public/ds2api-favicon.svg" width="128" height="128" alt="DS2API icon" />
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
# DS2API
|
# DS2API
|
||||||
@@ -10,6 +10,7 @@
|
|||||||
[](https://github.com/CJackHwang/ds2api/releases)
|
[](https://github.com/CJackHwang/ds2api/releases)
|
||||||
[](DEPLOY.md)
|
[](DEPLOY.md)
|
||||||
[](https://zeabur.com/templates/L4CFHP)
|
[](https://zeabur.com/templates/L4CFHP)
|
||||||
|
[](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api)
|
||||||
|
|
||||||
语言 / Language: [中文](README.MD) | [English](README.en.md)
|
语言 / Language: [中文](README.MD) | [English](README.en.md)
|
||||||
|
|
||||||
@@ -105,6 +106,14 @@ flowchart LR
|
|||||||
可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。
|
可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。
|
||||||
另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。
|
另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。
|
||||||
|
|
||||||
|
|
||||||
|
#### Claude Code 接入避坑(实测)
|
||||||
|
|
||||||
|
- `ANTHROPIC_BASE_URL` 推荐直接指向 DS2API 根地址(例如 `http://127.0.0.1:5001`),Claude Code 会请求 `/v1/messages?beta=true`。
|
||||||
|
- `ANTHROPIC_API_KEY` 需要与 `config.json` 中 `keys` 一致;建议同时保留常规 key 与 `sk-ant-*` 形态 key,兼容不同客户端校验习惯。
|
||||||
|
- 若系统设置了代理,建议对 DS2API 地址配置 `NO_PROXY=127.0.0.1,localhost,<你的主机IP>`,避免本地回环请求被代理拦截。
|
||||||
|
- 如遇“工具调用输出成文本、未执行”问题,请升级到包含 Claude 工具调用多格式解析(JSON/XML/ANTML/invoke)的版本。
|
||||||
|
|
||||||
### Gemini 接口
|
### Gemini 接口
|
||||||
|
|
||||||
Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling(`functionDeclarations` → `functionCall` 输出)。
|
Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling(`functionDeclarations` → `functionCall` 输出)。
|
||||||
|
|||||||
12
README.en.md
12
README.en.md
@@ -1,5 +1,5 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/ds2api-icon.svg" width="128" height="128" alt="DS2API icon" />
|
<img src="webui/public/ds2api-favicon.svg" width="128" height="128" alt="DS2API icon" />
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
# DS2API
|
# DS2API
|
||||||
@@ -10,6 +10,7 @@
|
|||||||
[](https://github.com/CJackHwang/ds2api/releases)
|
[](https://github.com/CJackHwang/ds2api/releases)
|
||||||
[](DEPLOY.en.md)
|
[](DEPLOY.en.md)
|
||||||
[](https://zeabur.com/templates/L4CFHP)
|
[](https://zeabur.com/templates/L4CFHP)
|
||||||
|
[](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api)
|
||||||
|
|
||||||
Language: [中文](README.MD) | [English](README.en.md)
|
Language: [中文](README.MD) | [English](README.en.md)
|
||||||
|
|
||||||
@@ -105,6 +106,14 @@ flowchart LR
|
|||||||
Override mapping via `claude_mapping` or `claude_model_mapping` in config.
|
Override mapping via `claude_mapping` or `claude_model_mapping` in config.
|
||||||
In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility.
|
In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility.
|
||||||
|
|
||||||
|
|
||||||
|
#### Claude Code integration pitfalls (validated)
|
||||||
|
|
||||||
|
- Set `ANTHROPIC_BASE_URL` to the DS2API root URL (for example `http://127.0.0.1:5001`). Claude Code sends requests to `/v1/messages?beta=true`.
|
||||||
|
- `ANTHROPIC_API_KEY` must match an entry in `keys` from `config.json`. Keeping both a regular key and an `sk-ant-*` style key improves client compatibility.
|
||||||
|
- If your environment has proxy variables, set `NO_PROXY=127.0.0.1,localhost,<your_host_ip>` for DS2API to avoid proxy interception of local traffic.
|
||||||
|
- If tool calls are rendered as plain text and not executed, upgrade to a build that includes multi-format Claude tool-call parsing (JSON/XML/ANTML/invoke).
|
||||||
|
|
||||||
### Gemini Endpoint
|
### Gemini Endpoint
|
||||||
|
|
||||||
The Gemini adapter maps model names to DeepSeek native models via `model_aliases` or built-in heuristics, supporting both `generateContent` and `streamGenerateContent` call patterns with full Tool Calling support (`functionDeclarations` → `functionCall` output).
|
The Gemini adapter maps model names to DeepSeek native models via `model_aliases` or built-in heuristics, supporting both `generateContent` and `streamGenerateContent` call patterns with full Tool Calling support (`functionDeclarations` → `functionCall` output).
|
||||||
@@ -350,6 +359,7 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
|
|||||||
When `tools` is present in the request, DS2API performs anti-leak handling:
|
When `tools` is present in the request, DS2API performs anti-leak handling:
|
||||||
|
|
||||||
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
|
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
|
||||||
|
- In non-code-block context, tool JSON may still be recognized even when mixed with normal prose; surrounding prose can remain as text output.
|
||||||
2. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`)
|
2. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`)
|
||||||
3. Tool names not declared in the `tools` schema are strictly rejected and will not be emitted as valid tool calls
|
3. Tool names not declared in the `tools` schema are strictly rejected and will not be emitted as valid tool calls
|
||||||
4. `responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream
|
4. `responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ DS2API 提供两个层级的测试:
|
|||||||
1. **Preflight 检查**:
|
1. **Preflight 检查**:
|
||||||
- `go test ./... -count=1`(单元测试)
|
- `go test ./... -count=1`(单元测试)
|
||||||
- `./tests/scripts/check-node-split-syntax.sh`(Node 拆分模块语法门禁)
|
- `./tests/scripts/check-node-split-syntax.sh`(Node 拆分模块语法门禁)
|
||||||
- `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js`(Node 流式拦截 + compat 单测)
|
- `node --test`(如仓库存在 Node 单测文件时执行;当前默认以 Go 测试 + Node 语法门禁为主)
|
||||||
- `npm run build --prefix webui`(WebUI 构建检查)
|
- `npm run build --prefix webui`(WebUI 构建检查)
|
||||||
|
|
||||||
2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程
|
2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
<svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg" role="img" aria-label="DS2API icon">
|
|
||||||
<defs>
|
|
||||||
<linearGradient id="bg" x1="96" y1="96" x2="416" y2="416" gradientUnits="userSpaceOnUse">
|
|
||||||
<stop offset="0" stop-color="#06162D" />
|
|
||||||
<stop offset="0.6" stop-color="#0A3A6A" />
|
|
||||||
<stop offset="1" stop-color="#00B4D8" />
|
|
||||||
</linearGradient>
|
|
||||||
<radialGradient id="glow" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(256 180) rotate(90) scale(260)">
|
|
||||||
<stop offset="0" stop-color="#FFFFFF" stop-opacity="0.18" />
|
|
||||||
<stop offset="1" stop-color="#FFFFFF" stop-opacity="0" />
|
|
||||||
</radialGradient>
|
|
||||||
<linearGradient id="whale" x1="180" y1="140" x2="360" y2="360" gradientUnits="userSpaceOnUse">
|
|
||||||
<stop offset="0" stop-color="#EAF7FF" />
|
|
||||||
<stop offset="1" stop-color="#BDEBFF" />
|
|
||||||
</linearGradient>
|
|
||||||
</defs>
|
|
||||||
|
|
||||||
<circle cx="256" cy="256" r="240" fill="url(#bg)" />
|
|
||||||
<circle cx="256" cy="256" r="240" fill="url(#glow)" />
|
|
||||||
<circle cx="256" cy="256" r="240" stroke="#FFFFFF" stroke-opacity="0.14" stroke-width="8" />
|
|
||||||
|
|
||||||
<!-- subtle waves -->
|
|
||||||
<path d="M104 338 C156 308 204 366 256 334 C308 302 356 360 408 330" stroke="#FFFFFF" stroke-opacity="0.16" stroke-width="12" stroke-linecap="round" />
|
|
||||||
<path d="M124 372 C174 344 212 396 256 372 C300 348 338 396 388 368" stroke="#FFFFFF" stroke-opacity="0.12" stroke-width="10" stroke-linecap="round" />
|
|
||||||
|
|
||||||
<!-- whale tail (DeepSeek-inspired element, original design) -->
|
|
||||||
<path
|
|
||||||
d="M256 162
|
|
||||||
C228 124 184 118 156 146
|
|
||||||
C132 170 138 206 162 230
|
|
||||||
C190 262 230 252 252 220
|
|
||||||
C254 218 255 216 256 214
|
|
||||||
C257 216 258 218 260 220
|
|
||||||
C282 252 322 262 350 230
|
|
||||||
C374 206 380 170 356 146
|
|
||||||
C328 118 284 124 256 162 Z"
|
|
||||||
fill="url(#whale)"
|
|
||||||
/>
|
|
||||||
<rect x="236" y="214" width="40" height="168" rx="20" fill="url(#whale)" />
|
|
||||||
|
|
||||||
<!-- API nodes -->
|
|
||||||
<g opacity="0.55" stroke="#FFFFFF" stroke-opacity="0.35" stroke-width="6" stroke-linecap="round">
|
|
||||||
<path d="M156 236 L208 206" />
|
|
||||||
<path d="M356 236 L304 206" />
|
|
||||||
<path d="M208 206 L232 172" />
|
|
||||||
<circle cx="156" cy="236" r="10" fill="#FFFFFF" fill-opacity="0.28" />
|
|
||||||
<circle cx="208" cy="206" r="10" fill="#FFFFFF" fill-opacity="0.28" />
|
|
||||||
<circle cx="232" cy="172" r="10" fill="#FFFFFF" fill-opacity="0.28" />
|
|
||||||
<circle cx="304" cy="206" r="10" fill="#FFFFFF" fill-opacity="0.28" />
|
|
||||||
<circle cx="356" cy="236" r="10" fill="#FFFFFF" fill-opacity="0.28" />
|
|
||||||
</g>
|
|
||||||
|
|
||||||
<!-- tiny sparkle -->
|
|
||||||
<path
|
|
||||||
d="M378 164
|
|
||||||
C372 170 366 174 358 176
|
|
||||||
C366 178 372 182 378 188
|
|
||||||
C380 180 384 176 392 176
|
|
||||||
C384 174 380 170 378 164 Z"
|
|
||||||
fill="#FFFFFF"
|
|
||||||
fill-opacity="0.32"
|
|
||||||
/>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 2.7 KiB |
@@ -1,6 +1,6 @@
|
|||||||
services:
|
services:
|
||||||
ds2api:
|
ds2api:
|
||||||
image: crpi-cnazxqmg4avmg4fq.cn-beijing.personal.cr.aliyuncs.com/ronghuaxueleng/ds2api:latest
|
image: ghcr.io/cjackhwang/ds2api:latest
|
||||||
container_name: ds2api
|
container_name: ds2api
|
||||||
restart: always
|
restart: always
|
||||||
ports:
|
ports:
|
||||||
|
|||||||
41
docs/toolcall-semantics.md
Normal file
41
docs/toolcall-semantics.md
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# Tool call parsing semantics (Go canonical spec)
|
||||||
|
|
||||||
|
This document defines the cross-runtime contract for `ParseToolCallsDetailed` / `parseToolCallsDetailed`.
|
||||||
|
|
||||||
|
## Output contract
|
||||||
|
|
||||||
|
- `calls`: accepted tool calls with normalized tool names.
|
||||||
|
- `sawToolCallSyntax`: true when tool-call-like syntax is detected (`tool_calls`, `<tool_call>`, `<function_call>`, `<invoke>`) or a valid call is parsed.
|
||||||
|
- `rejectedByPolicy`: true when parser extracted call syntax but all calls are rejected by allow-list policy.
|
||||||
|
- `rejectedToolNames`: de-duplicated rejected tool names in first-seen order.
|
||||||
|
|
||||||
|
## Parse pipeline
|
||||||
|
|
||||||
|
1. Strip fenced code blocks for non-standalone parsing.
|
||||||
|
2. Build candidates from:
|
||||||
|
- full text,
|
||||||
|
- fenced JSON snippets,
|
||||||
|
- extracted JSON objects around `tool_calls`,
|
||||||
|
- first `{` to last `}` object slice.
|
||||||
|
3. Parse each candidate in order:
|
||||||
|
- JSON payload parser (`tool_calls`, list, single call object),
|
||||||
|
- XML/Markup parser (`<tool_call>`, `<function_call>`, `<invoke>`; supports attributes + nested fields),
|
||||||
|
- Text KV fallback parser (`function.name: <name>` ... `function.arguments: {json}`).
|
||||||
|
4. Stop at first candidate that yields at least one call.
|
||||||
|
|
||||||
|
## Name normalization policy
|
||||||
|
|
||||||
|
When matching parsed names against configured tools:
|
||||||
|
|
||||||
|
1. exact match,
|
||||||
|
2. case-insensitive match,
|
||||||
|
3. namespace tail match (`a.b.c` => `c`),
|
||||||
|
4. loose alnum match (remove non `[a-z0-9]`, compare).
|
||||||
|
|
||||||
|
## Standalone mode
|
||||||
|
|
||||||
|
Standalone mode (`ParseStandaloneToolCallsDetailed`) parses the whole input directly (no candidate slicing), while still applying:
|
||||||
|
|
||||||
|
- example-context guard,
|
||||||
|
- JSON then markup fallback,
|
||||||
|
- the same allow-list normalization policy.
|
||||||
@@ -315,3 +315,78 @@ func asString(v any) string {
|
|||||||
s, _ := v.(string)
|
s, _ := v.(string)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
payload string
|
||||||
|
}{
|
||||||
|
{name: "xml_tool_call", payload: `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command></parameters></tool_call>`},
|
||||||
|
{name: "xml_json_tool_call", payload: `<tool_call>{"tool":"Bash","params":{"command":"pwd"}}</tool_call>`},
|
||||||
|
{name: "nested_tool_tag_style", payload: `<tool_call><tool name="Bash"><command>pwd</command></tool></tool_call>`},
|
||||||
|
{name: "function_tag_style", payload: `<function_call>Bash</function_call><function parameter name="command">pwd</function parameter>`},
|
||||||
|
{name: "antml_argument_style", payload: `<antml:function_calls><antml:function_call id="1" name="Bash"><antml:argument name="command">pwd</antml:argument></antml:function_call></antml:function_calls>`},
|
||||||
|
{name: "antml_function_attr_parameters", payload: `<antml:function_calls><antml:function_call id="1" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call></antml:function_calls>`},
|
||||||
|
{name: "invoke_parameter_style", payload: `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></function_calls>`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
resp := makeClaudeSSEHTTPResponse(
|
||||||
|
`data: {"p":"response/content","v":"`+strings.ReplaceAll(tc.payload, `"`, `\"`)+`"}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||||
|
|
||||||
|
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"Bash"})
|
||||||
|
|
||||||
|
frames := parseClaudeFrames(t, rec.Body.String())
|
||||||
|
foundToolUse := false
|
||||||
|
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||||
|
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||||
|
if contentBlock["type"] == "tool_use" {
|
||||||
|
foundToolUse = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundToolUse {
|
||||||
|
t.Fatalf("expected tool_use block for format %s, body=%s", tc.name, rec.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleClaudeStreamRealtimeDoesNotStopOnUnclosedFencedToolExample(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
resp := makeClaudeSSEHTTPResponse(
|
||||||
|
"data: {\"p\":\"response/content\",\"v\":\"Here is an example:\\n```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"Bash\\\",\\\"input\\\":{\\\"command\\\":\\\"pwd\\\"}}]}\"}",
|
||||||
|
"data: {\"p\":\"response/content\",\"v\":\"\\n```\\nDo not execute it.\"}",
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||||
|
|
||||||
|
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "show example only"}}, false, false, []string{"Bash"})
|
||||||
|
|
||||||
|
frames := parseClaudeFrames(t, rec.Body.String())
|
||||||
|
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||||
|
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||||
|
if contentBlock["type"] == "tool_use" {
|
||||||
|
t.Fatalf("unexpected tool_use for fenced example, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
foundEndTurn := false
|
||||||
|
for _, f := range findClaudeFrames(frames, "message_delta") {
|
||||||
|
delta, _ := f.Payload["delta"].(map[string]any)
|
||||||
|
if delta["stop_reason"] == "end_turn" {
|
||||||
|
foundEndTurn = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundEndTurn {
|
||||||
|
t.Fatalf("expected stop_reason=end_turn, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -125,8 +125,11 @@ func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
|
|||||||
if !containsStr(prompt, "Search the web") {
|
if !containsStr(prompt, "Search the web") {
|
||||||
t.Fatalf("expected description in prompt")
|
t.Fatalf("expected description in prompt")
|
||||||
}
|
}
|
||||||
if !containsStr(prompt, "tool_calls") {
|
if !containsStr(prompt, "tool_use") {
|
||||||
t.Fatalf("expected tool_calls instruction in prompt")
|
t.Fatalf("expected tool_use instruction in prompt")
|
||||||
|
}
|
||||||
|
if containsStr(prompt, "tool_calls") {
|
||||||
|
t.Fatalf("expected prompt to avoid tool_calls JSON instruction")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func buildClaudeToolPrompt(tools []any) string {
|
|||||||
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
|
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
|
||||||
}
|
}
|
||||||
parts = append(parts,
|
parts = append(parts,
|
||||||
"When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
|
"When you need a tool, respond with Claude-native tool use (tool_use) using the provided tool schema. Do not print tool-call JSON in text.",
|
||||||
"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
|
"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
|
||||||
"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
|
"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"ds2api/internal/sse"
|
"ds2api/internal/sse"
|
||||||
streamengine "ds2api/internal/stream"
|
streamengine "ds2api/internal/stream"
|
||||||
|
"ds2api/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type claudeStreamRuntime struct {
|
type claudeStreamRuntime struct {
|
||||||
@@ -116,6 +117,18 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
|||||||
|
|
||||||
s.text.WriteString(p.Text)
|
s.text.WriteString(p.Text)
|
||||||
if s.bufferToolContent {
|
if s.bufferToolContent {
|
||||||
|
if hasUnclosedCodeFence(s.text.String()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
detected := util.ParseToolCalls(s.text.String(), s.toolNames)
|
||||||
|
if len(detected) > 0 {
|
||||||
|
s.finalize("tool_use")
|
||||||
|
return streamengine.ParsedDecision{
|
||||||
|
ContentSeen: true,
|
||||||
|
Stop: true,
|
||||||
|
StopReason: streamengine.StopReason("tool_use_detected"),
|
||||||
|
}
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.closeThinkingBlock()
|
s.closeThinkingBlock()
|
||||||
@@ -144,3 +157,7 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
|||||||
|
|
||||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasUnclosedCodeFence(text string) bool {
|
||||||
|
return strings.Count(text, "```")%2 == 1
|
||||||
|
}
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ func TestGeminiRoutesRegistered(t *testing.T) {
|
|||||||
|
|
||||||
func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
|
func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
|
||||||
upstream := makeGeminiUpstreamResponse(
|
upstream := makeGeminiUpstreamResponse(
|
||||||
`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||||
`data: [DONE]`,
|
`data: [DONE]`,
|
||||||
)
|
)
|
||||||
h := &Handler{
|
h := &Handler{
|
||||||
@@ -143,6 +143,42 @@ func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
|
||||||
|
upstream := makeGeminiUpstreamResponse(
|
||||||
|
`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
h := &Handler{Store: testGeminiConfig{}, Auth: testGeminiAuth{}, DS: testGeminiDS{resp: upstream}}
|
||||||
|
r := chi.NewRouter()
|
||||||
|
RegisterRoutes(r, h)
|
||||||
|
|
||||||
|
body := `{
|
||||||
|
"contents":[{"role":"user","parts":[{"text":"call tool"}]}],
|
||||||
|
"tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
|
||||||
|
}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
|
||||||
|
req.Header.Set("Authorization", "Bearer direct-token")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
var out map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||||
|
t.Fatalf("decode response failed: %v", err)
|
||||||
|
}
|
||||||
|
candidates, _ := out["candidates"].([]any)
|
||||||
|
c0, _ := candidates[0].(map[string]any)
|
||||||
|
content, _ := c0["content"].(map[string]any)
|
||||||
|
parts, _ := content["parts"].([]any)
|
||||||
|
part0, _ := parts[0].(map[string]any)
|
||||||
|
functionCall, _ := part0["functionCall"].(map[string]any)
|
||||||
|
if functionCall["name"] != "eval_javascript" {
|
||||||
|
t.Fatalf("expected functionCall name eval_javascript for mixed snippet, got %#v", functionCall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamGenerateContentEmitsSSE(t *testing.T) {
|
func TestStreamGenerateContentEmitsSSE(t *testing.T) {
|
||||||
upstream := makeGeminiUpstreamResponse(
|
upstream := makeGeminiUpstreamResponse(
|
||||||
`data: {"p":"response/content","v":"hello "}`,
|
`data: {"p":"response/content","v":"hello "}`,
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ func (s *chatStreamRuntime) sendDone() {
|
|||||||
func (s *chatStreamRuntime) finalize(finishReason string) {
|
func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||||
finalThinking := s.thinking.String()
|
finalThinking := s.thinking.String()
|
||||||
finalText := s.text.String()
|
finalText := s.text.String()
|
||||||
detected := util.ParseToolCalls(finalText, s.toolNames)
|
detected := util.ParseStandaloneToolCalls(finalText, s.toolNames)
|
||||||
if len(detected) > 0 && !s.toolCallsDoneEmitted {
|
if len(detected) > 0 && !s.toolCallsDoneEmitted {
|
||||||
finishReason = "tool_calls"
|
finishReason = "tool_calls"
|
||||||
delta := map[string]any{
|
delta := map[string]any{
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -210,7 +211,7 @@ func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
|
func TestHandleNonStreamEmbeddedToolCallExampleRemainsText(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
resp := makeSSEHTTPResponse(
|
resp := makeSSEHTTPResponse(
|
||||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||||
@@ -228,16 +229,16 @@ func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
|
|||||||
out := decodeJSONBody(t, rec.Body.String())
|
out := decodeJSONBody(t, rec.Body.String())
|
||||||
choices, _ := out["choices"].([]any)
|
choices, _ := out["choices"].([]any)
|
||||||
choice, _ := choices[0].(map[string]any)
|
choice, _ := choices[0].(map[string]any)
|
||||||
if choice["finish_reason"] != "tool_calls" {
|
if choice["finish_reason"] != "stop" {
|
||||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||||
}
|
}
|
||||||
msg, _ := choice["message"].(map[string]any)
|
msg, _ := choice["message"].(map[string]any)
|
||||||
toolCalls, _ := msg["tool_calls"].([]any)
|
if _, ok := msg["tool_calls"]; ok {
|
||||||
if len(toolCalls) == 0 {
|
t.Fatalf("did not expect tool_calls field for embedded example: %#v", msg["tool_calls"])
|
||||||
t.Fatalf("expected tool_calls field for embedded example: %#v", msg["tool_calls"])
|
|
||||||
}
|
}
|
||||||
if msg["content"] != nil {
|
content, _ := msg["content"].(string)
|
||||||
t.Fatalf("expected content nil when tool_calls detected, got %#v", msg["content"])
|
if !strings.Contains(content, "下面是示例:") || !strings.Contains(content, "请勿执行。") || !strings.Contains(content, `"tool_calls"`) {
|
||||||
|
t.Fatalf("expected embedded example to remain plain text, got %#v", content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,6 +316,36 @@ func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleStreamToolCallLargeArgumentsStillIntercepted(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
large := strings.Repeat("a", 9000)
|
||||||
|
payload := fmt.Sprintf(`{"tool_calls":[{"name":"search","input":{"q":"%s"}}]}`, large)
|
||||||
|
splitAt := len(payload) / 2
|
||||||
|
resp := makeSSEHTTPResponse(
|
||||||
|
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[:splitAt]),
|
||||||
|
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[splitAt:]),
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
h.handleStream(rec, req, resp, "cid3-large", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||||
|
|
||||||
|
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||||
|
if !done {
|
||||||
|
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
if !streamHasToolCallsDelta(frames) {
|
||||||
|
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
if streamHasRawToolJSONContent(frames) {
|
||||||
|
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
if streamFinishReason(frames) != "tool_calls" {
|
||||||
|
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
resp := makeSSEHTTPResponse(
|
resp := makeSSEHTTPResponse(
|
||||||
@@ -500,15 +531,12 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
|||||||
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
||||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||||
}
|
}
|
||||||
if strings.Contains(strings.ToLower(got), `"tool_calls"`) {
|
|
||||||
t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got)
|
|
||||||
}
|
|
||||||
if streamFinishReason(frames) != "tool_calls" {
|
if streamFinishReason(frames) != "tool_calls" {
|
||||||
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
|
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
|
func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
resp := makeSSEHTTPResponse(
|
resp := makeSSEHTTPResponse(
|
||||||
`data: {"p":"response/content","v":"我将调用工具。"}`,
|
`data: {"p":"response/content","v":"我将调用工具。"}`,
|
||||||
@@ -542,15 +570,13 @@ func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
|
|||||||
if !strings.Contains(got, "我将调用工具。") {
|
if !strings.Contains(got, "我将调用工具。") {
|
||||||
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
||||||
}
|
}
|
||||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
|
||||||
t.Fatalf("unexpected raw tool json leak, got=%q", got)
|
|
||||||
}
|
|
||||||
if streamFinishReason(frames) != "tool_calls" {
|
if streamFinishReason(frames) != "tool_calls" {
|
||||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testing.T) {
|
func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
resp := makeSSEHTTPResponse(
|
resp := makeSSEHTTPResponse(
|
||||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
|
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
|
||||||
@@ -583,15 +609,52 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testin
|
|||||||
if !strings.Contains(got, "接下来我会继续说明。") {
|
if !strings.Contains(got, "接下来我会继续说明。") {
|
||||||
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
|
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
|
||||||
}
|
}
|
||||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
|
||||||
t.Fatalf("unexpected raw tool json leak, got=%q", got)
|
|
||||||
}
|
|
||||||
if streamFinishReason(frames) != "tool_calls" {
|
if streamFinishReason(frames) != "tool_calls" {
|
||||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
func TestHandleStreamFencedToolCallSnippetRemainsText(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
resp := makeSSEHTTPResponse(
|
||||||
|
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "下面是调用示例:\n```json\n"),
|
||||||
|
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\n仅示例,不要执行。"),
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
h.handleStream(rec, req, resp, "cid7f", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||||
|
|
||||||
|
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||||
|
if !done {
|
||||||
|
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
if streamHasToolCallsDelta(frames) {
|
||||||
|
t.Fatalf("did not expect tool_calls delta for fenced snippet, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
content := strings.Builder{}
|
||||||
|
for _, frame := range frames {
|
||||||
|
choices, _ := frame["choices"].([]any)
|
||||||
|
for _, item := range choices {
|
||||||
|
choice, _ := item.(map[string]any)
|
||||||
|
delta, _ := choice["delta"].(map[string]any)
|
||||||
|
if c, ok := delta["content"].(string); ok {
|
||||||
|
content.WriteString(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
got := content.String()
|
||||||
|
if !strings.Contains(got, "```json") || !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||||
|
t.Fatalf("expected fenced tool snippet in content, got=%q", got)
|
||||||
|
}
|
||||||
|
if streamFinishReason(frames) != "stop" {
|
||||||
|
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
spaces := strings.Repeat(" ", 200)
|
spaces := strings.Repeat(" ", 200)
|
||||||
resp := makeSSEHTTPResponse(
|
resp := makeSSEHTTPResponse(
|
||||||
@@ -612,9 +675,6 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
|||||||
if !streamHasToolCallsDelta(frames) {
|
if !streamHasToolCallsDelta(frames) {
|
||||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||||
}
|
}
|
||||||
if streamHasRawToolJSONContent(frames) {
|
|
||||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
|
||||||
}
|
|
||||||
content := strings.Builder{}
|
content := strings.Builder{}
|
||||||
for _, frame := range frames {
|
for _, frame := range frames {
|
||||||
choices, _ := frame["choices"].([]any)
|
choices, _ := frame["choices"].([]any)
|
||||||
@@ -627,9 +687,6 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
got := content.String()
|
got := content.String()
|
||||||
if strings.Contains(got, "{") {
|
|
||||||
t.Fatalf("unexpected suspicious prefix leak in content: %q", got)
|
|
||||||
}
|
|
||||||
if !strings.Contains(got, "后置正文C。") {
|
if !strings.Contains(got, "后置正文C。") {
|
||||||
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
|
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
|
||||||
}
|
}
|
||||||
@@ -712,7 +769,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
|
func TestHandleStreamToolCallArgumentsEmitAsSingleCompletedChunk(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
resp := makeSSEHTTPResponse(
|
resp := makeSSEHTTPResponse(
|
||||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
|
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
|
||||||
@@ -735,8 +792,8 @@ func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
|
|||||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||||
}
|
}
|
||||||
argChunks := streamToolCallArgumentChunks(frames)
|
argChunks := streamToolCallArgumentChunks(frames)
|
||||||
if len(argChunks) < 2 {
|
if len(argChunks) == 0 {
|
||||||
t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String())
|
t.Fatalf("expected tool call arguments chunk, got=%v body=%s", argChunks, rec.Body.String())
|
||||||
}
|
}
|
||||||
joined := strings.Join(argChunks, "")
|
joined := strings.Join(argChunks, "")
|
||||||
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
|
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"ds2api/internal/config"
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/prompt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
|
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
|
||||||
@@ -34,9 +34,9 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
|||||||
"role": "user",
|
"role": "user",
|
||||||
"content": formatToolResultForPrompt(msg),
|
"content": formatToolResultForPrompt(msg),
|
||||||
})
|
})
|
||||||
case "user", "system":
|
case "user", "system", "developer":
|
||||||
out = append(out, map[string]any{
|
out = append(out, map[string]any{
|
||||||
"role": role,
|
"role": normalizeOpenAIRoleForPrompt(role),
|
||||||
"content": normalizeOpenAIContentForPrompt(msg["content"]),
|
"content": normalizeOpenAIContentForPrompt(msg["content"]),
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
@@ -48,7 +48,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
|||||||
role = "user"
|
role = "user"
|
||||||
}
|
}
|
||||||
out = append(out, map[string]any{
|
out = append(out, map[string]any{
|
||||||
"role": role,
|
"role": normalizeOpenAIRoleForPrompt(role),
|
||||||
"content": content,
|
"content": content,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -78,7 +78,7 @@ func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) strin
|
|||||||
args = normalizeOpenAIArgumentsForPrompt(fn["arguments"])
|
args = normalizeOpenAIArgumentsForPrompt(fn["arguments"])
|
||||||
}
|
}
|
||||||
if name == "" {
|
if name == "" {
|
||||||
name = "unknown"
|
continue
|
||||||
}
|
}
|
||||||
if args == "" {
|
if args == "" {
|
||||||
args = normalizeOpenAIArgumentsForPrompt(call["arguments"])
|
args = normalizeOpenAIArgumentsForPrompt(call["arguments"])
|
||||||
@@ -133,32 +133,7 @@ func formatToolResultForPrompt(msg map[string]any) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func normalizeOpenAIContentForPrompt(v any) string {
|
func normalizeOpenAIContentForPrompt(v any) string {
|
||||||
switch x := v.(type) {
|
return prompt.NormalizeContent(v)
|
||||||
case string:
|
|
||||||
return x
|
|
||||||
case []any:
|
|
||||||
parts := make([]string, 0, len(x))
|
|
||||||
for _, item := range x {
|
|
||||||
m, ok := item.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
t := strings.ToLower(strings.TrimSpace(asString(m["type"])))
|
|
||||||
if t != "text" && t != "output_text" && t != "input_text" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if text := asString(m["text"]); text != "" {
|
|
||||||
parts = append(parts, text)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if text := asString(m["content"]); text != "" {
|
|
||||||
parts = append(parts, text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(parts, "\n")
|
|
||||||
default:
|
|
||||||
return marshalToPromptString(v)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeOpenAIArgumentsForPrompt(v any) string {
|
func normalizeOpenAIArgumentsForPrompt(v any) string {
|
||||||
@@ -175,30 +150,11 @@ func normalizeToolArgumentString(raw string) string {
|
|||||||
if trimmed == "" {
|
if trimmed == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
if !looksLikeConcatenatedJSON(trimmed) {
|
if looksLikeConcatenatedJSON(trimmed) {
|
||||||
return trimmed
|
// Keep original payload to avoid silent argument rewrites.
|
||||||
|
return raw
|
||||||
}
|
}
|
||||||
dec := json.NewDecoder(strings.NewReader(trimmed))
|
return trimmed
|
||||||
values := make([]any, 0, 2)
|
|
||||||
for {
|
|
||||||
var v any
|
|
||||||
if err := dec.Decode(&v); err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return trimmed
|
|
||||||
}
|
|
||||||
values = append(values, v)
|
|
||||||
}
|
|
||||||
if len(values) < 2 {
|
|
||||||
return trimmed
|
|
||||||
}
|
|
||||||
last := values[len(values)-1]
|
|
||||||
b, err := json.Marshal(last)
|
|
||||||
if err != nil || len(b) == 0 {
|
|
||||||
return trimmed
|
|
||||||
}
|
|
||||||
return string(b)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func marshalToPromptString(v any) string {
|
func marshalToPromptString(v any) string {
|
||||||
@@ -209,6 +165,14 @@ func marshalToPromptString(v any) string {
|
|||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeOpenAIRoleForPrompt(role string) string {
|
||||||
|
role = strings.ToLower(strings.TrimSpace(role))
|
||||||
|
if role == "developer" {
|
||||||
|
return "system"
|
||||||
|
}
|
||||||
|
return role
|
||||||
|
}
|
||||||
|
|
||||||
func asString(v any) string {
|
func asString(v any) string {
|
||||||
if s, ok := v.(string); ok {
|
if s, ok := v.(string); ok {
|
||||||
return s
|
return s
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *testing.T) {
|
func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t *testing.T) {
|
||||||
raw := []any{
|
raw := []any{
|
||||||
map[string]any{
|
map[string]any{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
@@ -189,10 +189,94 @@ func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *te
|
|||||||
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||||
}
|
}
|
||||||
content, _ := normalized[0]["content"].(string)
|
content, _ := normalized[0]["content"].(string)
|
||||||
if !strings.Contains(content, `function.arguments: {"query":"测试工具调用"}`) {
|
if !strings.Contains(content, `function.arguments: {}{"query":"测试工具调用"}`) {
|
||||||
t.Fatalf("expected repaired arguments in tool history, got %q", content)
|
t.Fatalf("expected original concatenated arguments in tool history, got %q", content)
|
||||||
}
|
}
|
||||||
if strings.Contains(content, `{}{"query":"测试工具调用"}`) {
|
}
|
||||||
t.Fatalf("expected concatenated JSON to be repaired, got %q", content)
|
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsMissingNameAreDropped(t *testing.T) {
|
||||||
|
raw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": []any{
|
||||||
|
map[string]any{
|
||||||
|
"id": "call_missing_name",
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"arguments": `{"path":"README.MD"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||||
|
if len(normalized) != 0 {
|
||||||
|
t.Fatalf("expected nameless assistant tool_calls to be dropped, got %#v", normalized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIMessagesForPrompt_AssistantNilContentDoesNotInjectNullLiteral(t *testing.T) {
|
||||||
|
raw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": nil,
|
||||||
|
"tool_calls": []any{
|
||||||
|
map[string]any{
|
||||||
|
"id": "call_screenshot",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "send_file_to_user",
|
||||||
|
"arguments": `{"file_path":"/tmp/a.png"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||||
|
if len(normalized) != 1 {
|
||||||
|
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||||
|
}
|
||||||
|
content, _ := normalized[0]["content"].(string)
|
||||||
|
if strings.Contains(content, "<|Assistant|>null") || strings.HasPrefix(strings.TrimSpace(content), "null") {
|
||||||
|
t.Fatalf("unexpected null literal injected into assistant tool history: %q", content)
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "function.name: send_file_to_user") {
|
||||||
|
t.Fatalf("expected tool history block preserved, got %q", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIMessagesForPrompt_DeveloperRoleMapsToSystem(t *testing.T) {
|
||||||
|
raw := []any{
|
||||||
|
map[string]any{"role": "developer", "content": "必须先走工具调用"},
|
||||||
|
map[string]any{"role": "user", "content": "你好"},
|
||||||
|
}
|
||||||
|
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||||
|
if len(normalized) != 2 {
|
||||||
|
t.Fatalf("expected 2 normalized messages, got %d", len(normalized))
|
||||||
|
}
|
||||||
|
if normalized[0]["role"] != "system" {
|
||||||
|
t.Fatalf("expected developer role converted to system, got %#v", normalized[0]["role"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIMessagesForPrompt_AssistantArrayContentFallbackWhenTextEmpty(t *testing.T) {
|
||||||
|
raw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "text", "text": "", "content": "工具说明文本"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||||
|
if len(normalized) != 1 {
|
||||||
|
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||||
|
}
|
||||||
|
content, _ := normalized[0]["content"].(string)
|
||||||
|
if content != "工具说明文本" {
|
||||||
|
t.Fatalf("expected content fallback text preserved, got %q", content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArguments(t *testing.T) {
|
func TestNormalizeResponsesInputAsMessagesFunctionCallItemPreservesConcatenatedArguments(t *testing.T) {
|
||||||
msgs := normalizeResponsesInputAsMessages([]any{
|
msgs := normalizeResponsesInputAsMessages([]any{
|
||||||
map[string]any{
|
map[string]any{
|
||||||
"type": "function_call",
|
"type": "function_call",
|
||||||
@@ -151,8 +151,8 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArg
|
|||||||
toolCalls, _ := m["tool_calls"].([]any)
|
toolCalls, _ := m["tool_calls"].([]any)
|
||||||
call, _ := toolCalls[0].(map[string]any)
|
call, _ := toolCalls[0].(map[string]any)
|
||||||
fn, _ := call["function"].(map[string]any)
|
fn, _ := call["function"].(map[string]any)
|
||||||
if fn["arguments"] != `{"q":"golang"}` {
|
if fn["arguments"] != `{}{"q":"golang"}` {
|
||||||
t.Fatalf("expected concatenated call arguments repaired, got %#v", fn["arguments"])
|
t.Fatalf("expected original concatenated call arguments preserved, got %#v", fn["arguments"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -113,15 +113,10 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||||
textParsed := util.ParseToolCallsDetailed(result.Text, toolNames)
|
textParsed := util.ParseStandaloneToolCallsDetailed(result.Text, toolNames)
|
||||||
thinkingParsed := util.ParseToolCallsDetailed(result.Thinking, toolNames)
|
|
||||||
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
||||||
logResponsesToolPolicyRejection(traceID, toolChoice, thinkingParsed, "thinking")
|
|
||||||
|
|
||||||
callCount := len(textParsed.Calls)
|
callCount := len(textParsed.Calls)
|
||||||
if callCount == 0 {
|
|
||||||
callCount = len(thinkingParsed.Calls)
|
|
||||||
}
|
|
||||||
if toolChoice.IsRequired() && callCount == 0 {
|
if toolChoice.IsRequired() && callCount == 0 {
|
||||||
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
|
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"role": role,
|
"role": normalizeOpenAIRoleForPrompt(role),
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -51,7 +51,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
|
|||||||
role = "user"
|
role = "user"
|
||||||
}
|
}
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"role": role,
|
"role": normalizeOpenAIRoleForPrompt(role),
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
case "function_call_output", "tool_result":
|
case "function_call_output", "tool_result":
|
||||||
|
|||||||
@@ -102,16 +102,11 @@ func (s *responsesStreamRuntime) finalize() {
|
|||||||
|
|
||||||
if s.bufferToolContent {
|
if s.bufferToolContent {
|
||||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||||
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
textParsed := util.ParseToolCallsDetailed(finalText, s.toolNames)
|
textParsed := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
|
||||||
thinkingParsed := util.ParseToolCallsDetailed(finalThinking, s.toolNames)
|
|
||||||
detected := textParsed.Calls
|
detected := textParsed.Calls
|
||||||
if len(detected) == 0 {
|
s.logToolPolicyRejections(textParsed)
|
||||||
detected = thinkingParsed.Calls
|
|
||||||
}
|
|
||||||
s.logToolPolicyRejections(textParsed, thinkingParsed)
|
|
||||||
|
|
||||||
if len(detected) > 0 {
|
if len(detected) > 0 {
|
||||||
s.toolCallsEmitted = true
|
s.toolCallsEmitted = true
|
||||||
@@ -157,7 +152,7 @@ func (s *responsesStreamRuntime) finalize() {
|
|||||||
s.sendDone()
|
s.sendDone()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) {
|
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed util.ToolCallParseResult) {
|
||||||
logRejected := func(parsed util.ToolCallParseResult, channel string) {
|
logRejected := func(parsed util.ToolCallParseResult, channel string) {
|
||||||
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)
|
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)
|
||||||
if !parsed.RejectedByPolicy || len(rejected) == 0 {
|
if !parsed.RejectedByPolicy || len(rejected) == 0 {
|
||||||
@@ -172,7 +167,6 @@ func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingPar
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
logRejected(textParsed, "text")
|
logRejected(textParsed, "text")
|
||||||
logRejected(thinkingParsed, "thinking")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *responsesStreamRuntime) hasFunctionCallDone() bool {
|
func (s *responsesStreamRuntime) hasFunctionCallDone() bool {
|
||||||
@@ -207,9 +201,6 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
|||||||
}
|
}
|
||||||
s.thinking.WriteString(p.Text)
|
s.thinking.WriteString(p.Text)
|
||||||
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
|
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
|
||||||
if s.bufferToolContent {
|
|
||||||
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -94,6 +94,16 @@ func (s *responsesStreamRuntime) closeMessageItem() {
|
|||||||
outputIndex := s.ensureMessageOutputIndex()
|
outputIndex := s.ensureMessageOutputIndex()
|
||||||
text := s.visibleText.String()
|
text := s.visibleText.String()
|
||||||
if s.messagePartAdded {
|
if s.messagePartAdded {
|
||||||
|
s.sendEvent(
|
||||||
|
"response.output_text.done",
|
||||||
|
openaifmt.BuildResponsesTextDonePayload(
|
||||||
|
s.responseID,
|
||||||
|
itemID,
|
||||||
|
outputIndex,
|
||||||
|
0,
|
||||||
|
text,
|
||||||
|
),
|
||||||
|
)
|
||||||
s.sendEvent(
|
s.sendEvent(
|
||||||
"response.content_part.done",
|
"response.content_part.done",
|
||||||
openaifmt.BuildResponsesContentPartDonePayload(
|
openaifmt.BuildResponsesContentPartDonePayload(
|
||||||
|
|||||||
@@ -99,9 +99,6 @@ func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
|
|||||||
if !strings.Contains(body, "event: response.output_item.done") {
|
if !strings.Contains(body, "event: response.output_item.done") {
|
||||||
t.Fatalf("expected response.output_item.done event, body=%s", body)
|
t.Fatalf("expected response.output_item.done event, body=%s", body)
|
||||||
}
|
}
|
||||||
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
|
|
||||||
t.Fatalf("expected response.function_call_arguments.delta event, body=%s", body)
|
|
||||||
}
|
|
||||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||||
t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
|
t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
|
||||||
}
|
}
|
||||||
@@ -229,6 +226,40 @@ func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
sseLine := func(v string) string {
|
||||||
|
b, _ := json.Marshal(map[string]any{
|
||||||
|
"p": "response/content",
|
||||||
|
"v": v,
|
||||||
|
})
|
||||||
|
return "data: " + string(b) + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
streamBody := sseLine("hello") + "data: [DONE]\n"
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||||
|
body := rec.Body.String()
|
||||||
|
if !strings.Contains(body, "event: response.output_text.done") {
|
||||||
|
t.Fatalf("expected response.output_text.done payload, body=%s", body)
|
||||||
|
}
|
||||||
|
textDoneIdx := strings.Index(body, "event: response.output_text.done")
|
||||||
|
partDoneIdx := strings.Index(body, "event: response.content_part.done")
|
||||||
|
if textDoneIdx < 0 || partDoneIdx < 0 {
|
||||||
|
t.Fatalf("expected output_text.done + content_part.done, body=%s", body)
|
||||||
|
}
|
||||||
|
if textDoneIdx > partDoneIdx {
|
||||||
|
t.Fatalf("expected output_text.done before content_part.done, body=%s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
@@ -266,7 +297,7 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *testing.T) {
|
func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
@@ -291,23 +322,8 @@ func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *tes
|
|||||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||||
|
|
||||||
addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
|
addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
|
||||||
if len(addedPayloads) < 2 {
|
if len(addedPayloads) < 1 {
|
||||||
t.Fatalf("expected message + function_call output_item.added events, got %d body=%s", len(addedPayloads), rec.Body.String())
|
t.Fatalf("expected at least one output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
|
||||||
}
|
|
||||||
|
|
||||||
indexes := map[int]struct{}{}
|
|
||||||
typeByIndex := map[int]string{}
|
|
||||||
addedIDs := map[string]string{}
|
|
||||||
for _, payload := range addedPayloads {
|
|
||||||
item, _ := payload["item"].(map[string]any)
|
|
||||||
itemType := strings.TrimSpace(asString(item["type"]))
|
|
||||||
outputIndex := int(asFloat(payload["output_index"]))
|
|
||||||
if _, exists := indexes[outputIndex]; exists {
|
|
||||||
t.Fatalf("found duplicated output_index=%d for item types=%q and %q payload=%#v", outputIndex, typeByIndex[outputIndex], itemType, payload)
|
|
||||||
}
|
|
||||||
indexes[outputIndex] = struct{}{}
|
|
||||||
typeByIndex[outputIndex] = itemType
|
|
||||||
addedIDs[itemType] = strings.TrimSpace(asString(payload["item_id"]))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||||
@@ -316,20 +332,21 @@ func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *tes
|
|||||||
}
|
}
|
||||||
responseObj, _ := completedPayload["response"].(map[string]any)
|
responseObj, _ := completedPayload["response"].(map[string]any)
|
||||||
output, _ := responseObj["output"].([]any)
|
output, _ := responseObj["output"].([]any)
|
||||||
found := map[string]bool{}
|
hasMessage := false
|
||||||
for _, item := range output {
|
for _, item := range output {
|
||||||
m, _ := item.(map[string]any)
|
m, _ := item.(map[string]any)
|
||||||
itemType := strings.TrimSpace(asString(m["type"]))
|
if m == nil {
|
||||||
itemID := strings.TrimSpace(asString(m["id"]))
|
|
||||||
if itemType == "" || itemID == "" {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if wantID := strings.TrimSpace(addedIDs[itemType]); wantID != "" && wantID == itemID {
|
if asString(m["type"]) == "message" {
|
||||||
found[itemType] = true
|
hasMessage = true
|
||||||
|
}
|
||||||
|
if asString(m["type"]) == "function_call" {
|
||||||
|
t.Fatalf("did not expect function_call output for mixed prose tool example, output=%#v", output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found["message"] || !found["function_call"] {
|
if !hasMessage {
|
||||||
t.Fatalf("expected completed output to contain streamed message/function_call item ids, found=%#v output=%#v", found, output)
|
t.Fatalf("expected message output for mixed prose tool example, output=%#v", output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -360,7 +377,7 @@ func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *testing.T) {
|
func TestHandleResponsesStreamMalformedToolJSONFallsBackToText(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
@@ -373,7 +390,7 @@ func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *t
|
|||||||
return "data: " + string(b) + "\n"
|
return "data: " + string(b) + "\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
// invalid JSON (NaN) can still trigger incremental tool deltas before final parse rejects it
|
// invalid JSON (NaN) should remain plain text in strict mode.
|
||||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
|
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
StatusCode: http.StatusOK,
|
StatusCode: http.StatusOK,
|
||||||
@@ -382,14 +399,11 @@ func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *t
|
|||||||
|
|
||||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||||
body := rec.Body.String()
|
body := rec.Body.String()
|
||||||
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
|
if strings.Contains(body, "event: response.function_call_arguments.delta") || strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||||
t.Fatalf("expected response.function_call_arguments.delta event for malformed payload, body=%s", body)
|
t.Fatalf("did not expect function_call events for malformed payload in strict mode, body=%s", body)
|
||||||
}
|
}
|
||||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
if !strings.Contains(body, "event: response.output_text.delta") {
|
||||||
t.Fatalf("expected runtime to close in-progress function_call with done event, body=%s", body)
|
t.Fatalf("expected response.output_text.delta for malformed payload, body=%s", body)
|
||||||
}
|
|
||||||
if !strings.Contains(body, "event: response.output_item.done") {
|
|
||||||
t.Fatalf("expected runtime to close function output item, body=%s", body)
|
|
||||||
}
|
}
|
||||||
if !strings.Contains(body, "event: response.completed") {
|
if !strings.Contains(body, "event: response.completed") {
|
||||||
t.Fatalf("expected response.completed event, body=%s", body)
|
t.Fatalf("expected response.completed event, body=%s", body)
|
||||||
@@ -430,6 +444,42 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleResponsesStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
sseLine := func(path, value string) string {
|
||||||
|
b, _ := json.Marshal(map[string]any{
|
||||||
|
"p": path,
|
||||||
|
"v": value,
|
||||||
|
})
|
||||||
|
return "data: " + string(b) + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
|
||||||
|
sseLine("response/content", "plain text only") +
|
||||||
|
"data: [DONE]\n"
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
policy := util.ToolChoicePolicy{
|
||||||
|
Mode: util.ToolChoiceRequired,
|
||||||
|
Allowed: map[string]struct{}{"read_file": {}},
|
||||||
|
}
|
||||||
|
|
||||||
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "")
|
||||||
|
body := rec.Body.String()
|
||||||
|
if !strings.Contains(body, "event: response.failed") {
|
||||||
|
t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
|
||||||
|
}
|
||||||
|
if strings.Contains(body, "event: response.completed") {
|
||||||
|
t.Fatalf("did not expect response.completed after failure, body=%s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
|
func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
@@ -516,6 +566,33 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(
|
||||||
|
`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" +
|
||||||
|
`data: {"p":"response/content","v":"plain text only"}` + "\n" +
|
||||||
|
`data: [DONE]` + "\n",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
policy := util.ToolChoicePolicy{
|
||||||
|
Mode: util.ToolChoiceRequired,
|
||||||
|
Allowed: map[string]struct{}{"read_file": {}},
|
||||||
|
}
|
||||||
|
|
||||||
|
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, []string{"read_file"}, policy, "")
|
||||||
|
if rec.Code != http.StatusUnprocessableEntity {
|
||||||
|
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
out := decodeJSONBody(t, rec.Body.String())
|
||||||
|
errObj, _ := out["error"].(map[string]any)
|
||||||
|
if asString(errObj["code"]) != "tool_choice_violation" {
|
||||||
|
t.Fatalf("expected code=tool_choice_violation, got %#v", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|||||||
@@ -167,19 +167,15 @@ func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
|
|||||||
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
|
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
|
||||||
}
|
}
|
||||||
outputText, _ := out["output_text"].(string)
|
outputText, _ := out["output_text"].(string)
|
||||||
if outputText != "" {
|
if outputText == "" {
|
||||||
t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText)
|
t.Fatalf("expected output_text preserved for mixed prose payload")
|
||||||
}
|
}
|
||||||
output, _ := out["output"].([]any)
|
output, _ := out["output"].([]any)
|
||||||
hasFunctionCall := false
|
if len(output) != 1 {
|
||||||
for _, item := range output {
|
t.Fatalf("expected one output item, got %#v", output)
|
||||||
m, _ := item.(map[string]any)
|
|
||||||
if m != nil && m["type"] == "function_call" {
|
|
||||||
hasFunctionCall = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if !hasFunctionCall {
|
first, _ := output[0].(map[string]any)
|
||||||
t.Fatalf("expected function_call output item, got %#v", output)
|
if first["type"] != "message" {
|
||||||
|
t.Fatalf("expected message output item, got %#v", output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,11 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
|||||||
state.pending.WriteString(chunk)
|
state.pending.WriteString(chunk)
|
||||||
}
|
}
|
||||||
events := make([]toolStreamEvent, 0, 2)
|
events := make([]toolStreamEvent, 0, 2)
|
||||||
|
if len(state.pendingToolCalls) > 0 {
|
||||||
|
events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
|
||||||
|
state.pendingToolRaw = ""
|
||||||
|
state.pendingToolCalls = nil
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if state.capturing {
|
if state.capturing {
|
||||||
@@ -21,32 +26,30 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
|||||||
state.capture.WriteString(state.pending.String())
|
state.capture.WriteString(state.pending.String())
|
||||||
state.pending.Reset()
|
state.pending.Reset()
|
||||||
}
|
}
|
||||||
if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
|
|
||||||
events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
|
|
||||||
}
|
|
||||||
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
|
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
|
||||||
if !ready {
|
if !ready {
|
||||||
if state.capture.Len() > toolSieveCaptureLimit {
|
|
||||||
content := state.capture.String()
|
|
||||||
state.capture.Reset()
|
|
||||||
state.capturing = false
|
|
||||||
state.resetIncrementalToolState()
|
|
||||||
state.noteText(content)
|
|
||||||
events = append(events, toolStreamEvent{Content: content})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
captured := state.capture.String()
|
||||||
state.capture.Reset()
|
state.capture.Reset()
|
||||||
state.capturing = false
|
state.capturing = false
|
||||||
state.resetIncrementalToolState()
|
state.resetIncrementalToolState()
|
||||||
|
if len(calls) > 0 {
|
||||||
|
if prefix != "" {
|
||||||
|
state.noteText(prefix)
|
||||||
|
events = append(events, toolStreamEvent{Content: prefix})
|
||||||
|
}
|
||||||
|
if suffix != "" {
|
||||||
|
state.pending.WriteString(suffix)
|
||||||
|
}
|
||||||
|
_ = captured
|
||||||
|
state.pendingToolCalls = calls
|
||||||
|
continue
|
||||||
|
}
|
||||||
if prefix != "" {
|
if prefix != "" {
|
||||||
state.noteText(prefix)
|
state.noteText(prefix)
|
||||||
events = append(events, toolStreamEvent{Content: prefix})
|
events = append(events, toolStreamEvent{Content: prefix})
|
||||||
}
|
}
|
||||||
if len(calls) > 0 {
|
|
||||||
events = append(events, toolStreamEvent{ToolCalls: calls})
|
|
||||||
}
|
|
||||||
if suffix != "" {
|
if suffix != "" {
|
||||||
state.pending.WriteString(suffix)
|
state.pending.WriteString(suffix)
|
||||||
}
|
}
|
||||||
@@ -89,6 +92,11 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
events := processToolSieveChunk(state, "", toolNames)
|
events := processToolSieveChunk(state, "", toolNames)
|
||||||
|
if len(state.pendingToolCalls) > 0 {
|
||||||
|
events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
|
||||||
|
state.pendingToolRaw = ""
|
||||||
|
state.pendingToolCalls = nil
|
||||||
|
}
|
||||||
if state.capturing {
|
if state.capturing {
|
||||||
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
|
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
|
||||||
if ready {
|
if ready {
|
||||||
|
|||||||
@@ -7,17 +7,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type toolStreamSieveState struct {
|
type toolStreamSieveState struct {
|
||||||
pending strings.Builder
|
pending strings.Builder
|
||||||
capture strings.Builder
|
capture strings.Builder
|
||||||
capturing bool
|
capturing bool
|
||||||
recentTextTail string
|
recentTextTail string
|
||||||
disableDeltas bool
|
pendingToolRaw string
|
||||||
toolNameSent bool
|
pendingToolCalls []util.ParsedToolCall
|
||||||
toolName string
|
disableDeltas bool
|
||||||
toolArgsStart int
|
toolNameSent bool
|
||||||
toolArgsSent int
|
toolName string
|
||||||
toolArgsString bool
|
toolArgsStart int
|
||||||
toolArgsDone bool
|
toolArgsSent int
|
||||||
|
toolArgsString bool
|
||||||
|
toolArgsDone bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type toolStreamEvent struct {
|
type toolStreamEvent struct {
|
||||||
@@ -32,7 +34,6 @@ type toolCallDelta struct {
|
|||||||
Arguments string
|
Arguments string
|
||||||
}
|
}
|
||||||
|
|
||||||
const toolSieveCaptureLimit = 8 * 1024
|
|
||||||
const toolSieveContextTailLimit = 256
|
const toolSieveContextTailLimit = 256
|
||||||
|
|
||||||
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
||||||
|
|||||||
@@ -1,128 +1,133 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
"ds2api/internal/config"
|
|
||||||
)
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
|
||||||
page := intFromQuery(r, "page", 1)
|
func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||||
pageSize := intFromQuery(r, "page_size", 10)
|
page := intFromQuery(r, "page", 1)
|
||||||
if page < 1 {
|
pageSize := intFromQuery(r, "page_size", 10)
|
||||||
page = 1
|
if page < 1 {
|
||||||
}
|
page = 1
|
||||||
if pageSize < 1 {
|
}
|
||||||
pageSize = 1
|
if pageSize < 1 {
|
||||||
}
|
pageSize = 1
|
||||||
if pageSize > 100 {
|
}
|
||||||
pageSize = 100
|
if pageSize > 100 {
|
||||||
}
|
pageSize = 100
|
||||||
accounts := h.Store.Snapshot().Accounts
|
}
|
||||||
reverseAccounts(accounts)
|
accounts := h.Store.Snapshot().Accounts
|
||||||
q := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("q")))
|
reverseAccounts(accounts)
|
||||||
if q != "" {
|
q := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("q")))
|
||||||
filtered := make([]config.Account, 0, len(accounts))
|
if q != "" {
|
||||||
for _, acc := range accounts {
|
filtered := make([]config.Account, 0, len(accounts))
|
||||||
id := strings.ToLower(acc.Identifier())
|
for _, acc := range accounts {
|
||||||
if strings.Contains(id, q) ||
|
id := strings.ToLower(acc.Identifier())
|
||||||
strings.Contains(strings.ToLower(acc.Email), q) ||
|
if strings.Contains(id, q) ||
|
||||||
strings.Contains(strings.ToLower(acc.Mobile), q) {
|
strings.Contains(strings.ToLower(acc.Email), q) ||
|
||||||
filtered = append(filtered, acc)
|
strings.Contains(strings.ToLower(acc.Mobile), q) {
|
||||||
}
|
filtered = append(filtered, acc)
|
||||||
}
|
}
|
||||||
accounts = filtered
|
}
|
||||||
}
|
accounts = filtered
|
||||||
total := len(accounts)
|
}
|
||||||
totalPages := 1
|
total := len(accounts)
|
||||||
if total > 0 {
|
totalPages := 1
|
||||||
totalPages = (total + pageSize - 1) / pageSize
|
if total > 0 {
|
||||||
}
|
totalPages = (total + pageSize - 1) / pageSize
|
||||||
start := (page - 1) * pageSize
|
}
|
||||||
if start > total {
|
start := (page - 1) * pageSize
|
||||||
start = total
|
if start > total {
|
||||||
}
|
start = total
|
||||||
end := start + pageSize
|
}
|
||||||
if end > total {
|
end := start + pageSize
|
||||||
end = total
|
if end > total {
|
||||||
}
|
end = total
|
||||||
items := make([]map[string]any, 0, end-start)
|
}
|
||||||
for _, acc := range accounts[start:end] {
|
items := make([]map[string]any, 0, end-start)
|
||||||
token := strings.TrimSpace(acc.Token)
|
for _, acc := range accounts[start:end] {
|
||||||
preview := ""
|
token := strings.TrimSpace(acc.Token)
|
||||||
if token != "" {
|
preview := ""
|
||||||
if len(token) > 20 {
|
if token != "" {
|
||||||
preview = token[:20] + "..."
|
if len(token) > 20 {
|
||||||
} else {
|
preview = token[:20] + "..."
|
||||||
preview = token
|
} else {
|
||||||
}
|
preview = token
|
||||||
}
|
}
|
||||||
items = append(items, map[string]any{
|
}
|
||||||
"identifier": acc.Identifier(),
|
items = append(items, map[string]any{
|
||||||
"email": acc.Email,
|
"identifier": acc.Identifier(),
|
||||||
"mobile": acc.Mobile,
|
"email": acc.Email,
|
||||||
"has_password": acc.Password != "",
|
"mobile": acc.Mobile,
|
||||||
"has_token": token != "",
|
"has_password": acc.Password != "",
|
||||||
"token_preview": preview,
|
"has_token": token != "",
|
||||||
"test_status": acc.TestStatus,
|
"token_preview": preview,
|
||||||
})
|
"test_status": acc.TestStatus,
|
||||||
}
|
})
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
|
}
|
||||||
}
|
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
|
||||||
|
}
|
||||||
func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var req map[string]any
|
func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
var req map[string]any
|
||||||
acc := toAccount(req)
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if acc.Identifier() == "" {
|
acc := toAccount(req)
|
||||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
|
if acc.Identifier() == "" {
|
||||||
return
|
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
|
||||||
}
|
return
|
||||||
err := h.Store.Update(func(c *config.Config) error {
|
}
|
||||||
for _, a := range c.Accounts {
|
err := h.Store.Update(func(c *config.Config) error {
|
||||||
if acc.Email != "" && a.Email == acc.Email {
|
mobileKey := config.CanonicalMobileKey(acc.Mobile)
|
||||||
return fmt.Errorf("邮箱已存在")
|
for _, a := range c.Accounts {
|
||||||
}
|
if acc.Email != "" && a.Email == acc.Email {
|
||||||
if acc.Mobile != "" && a.Mobile == acc.Mobile {
|
return fmt.Errorf("邮箱已存在")
|
||||||
return fmt.Errorf("手机号已存在")
|
}
|
||||||
}
|
if mobileKey != "" && config.CanonicalMobileKey(a.Mobile) == mobileKey {
|
||||||
}
|
return fmt.Errorf("手机号已存在")
|
||||||
c.Accounts = append(c.Accounts, acc)
|
}
|
||||||
return nil
|
}
|
||||||
})
|
c.Accounts = append(c.Accounts, acc)
|
||||||
if err != nil {
|
return nil
|
||||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
})
|
||||||
return
|
if err != nil {
|
||||||
}
|
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||||
h.Pool.Reset()
|
return
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
}
|
||||||
}
|
h.Pool.Reset()
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||||
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
}
|
||||||
identifier := chi.URLParam(r, "identifier")
|
|
||||||
err := h.Store.Update(func(c *config.Config) error {
|
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
idx := -1
|
identifier := chi.URLParam(r, "identifier")
|
||||||
for i, a := range c.Accounts {
|
if decoded, err := url.PathUnescape(identifier); err == nil {
|
||||||
if accountMatchesIdentifier(a, identifier) {
|
identifier = decoded
|
||||||
idx = i
|
}
|
||||||
break
|
err := h.Store.Update(func(c *config.Config) error {
|
||||||
}
|
idx := -1
|
||||||
}
|
for i, a := range c.Accounts {
|
||||||
if idx < 0 {
|
if accountMatchesIdentifier(a, identifier) {
|
||||||
return fmt.Errorf("账号不存在")
|
idx = i
|
||||||
}
|
break
|
||||||
c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
|
}
|
||||||
return nil
|
}
|
||||||
})
|
if idx < 0 {
|
||||||
if err != nil {
|
return fmt.Errorf("账号不存在")
|
||||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
|
}
|
||||||
return
|
c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
|
||||||
}
|
return nil
|
||||||
h.Pool.Reset()
|
})
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
if err != nil {
|
||||||
}
|
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.Pool.Reset()
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -102,6 +103,45 @@ func TestDeleteAccountSupportsMobileAlias(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeleteAccountSupportsEncodedPlusMobile(t *testing.T) {
|
||||||
|
h := newAdminTestHandler(t, `{
|
||||||
|
"accounts":[{"mobile":"+8613800138000","password":"pwd"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Delete("/admin/accounts/{identifier}", h.deleteAccount)
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape("+8613800138000"), nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.Store.Accounts()); got != 0 {
|
||||||
|
t.Fatalf("expected account removed, remaining=%d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddAccountRejectsCanonicalMobileDuplicate(t *testing.T) {
|
||||||
|
h := newAdminTestHandler(t, `{
|
||||||
|
"accounts":[{"mobile":"+8613800138000","password":"pwd"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Post("/admin/accounts", h.addAccount)
|
||||||
|
body := []byte(`{"mobile":"13800138000","password":"pwd2"}`)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/admin/accounts", bytes.NewReader(body))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.Store.Accounts()); got != 1 {
|
||||||
|
t.Fatalf("expected no duplicate insert, got=%d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
|
func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
|
||||||
h := newAdminTestHandler(t, `{
|
h := newAdminTestHandler(t, `{
|
||||||
"accounts":[
|
"accounts":[
|
||||||
@@ -117,6 +157,13 @@ func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
|
|||||||
if accByMobile.Email != "u@example.com" {
|
if accByMobile.Email != "u@example.com" {
|
||||||
t.Fatalf("unexpected account by mobile: %#v", accByMobile)
|
t.Fatalf("unexpected account by mobile: %#v", accByMobile)
|
||||||
}
|
}
|
||||||
|
accByMobileWithCountryCode, ok := findAccountByIdentifier(h.Store, "+8613800138000")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected find by +86 mobile")
|
||||||
|
}
|
||||||
|
if accByMobileWithCountryCode.Email != "u@example.com" {
|
||||||
|
t.Fatalf("unexpected account by +86 mobile: %#v", accByMobileWithCountryCode)
|
||||||
|
}
|
||||||
|
|
||||||
tokenOnlyID := ""
|
tokenOnlyID := ""
|
||||||
for _, acc := range h.Store.Accounts() {
|
for _, acc := range h.Store.Accounts() {
|
||||||
|
|||||||
@@ -1,209 +1,212 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
authn "ds2api/internal/auth"
|
authn "ds2api/internal/auth"
|
||||||
"ds2api/internal/config"
|
"ds2api/internal/config"
|
||||||
"ds2api/internal/sse"
|
"ds2api/internal/sse"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
var req map[string]any
|
var req map[string]any
|
||||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
identifier, _ := req["identifier"].(string)
|
identifier, _ := req["identifier"].(string)
|
||||||
if strings.TrimSpace(identifier) == "" {
|
if strings.TrimSpace(identifier) == "" {
|
||||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"})
|
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
acc, ok := findAccountByIdentifier(h.Store, identifier)
|
acc, ok := findAccountByIdentifier(h.Store, identifier)
|
||||||
if !ok {
|
if !ok {
|
||||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
|
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
model, _ := req["model"].(string)
|
model, _ := req["model"].(string)
|
||||||
if model == "" {
|
if model == "" {
|
||||||
model = "deepseek-chat"
|
model = "deepseek-chat"
|
||||||
}
|
}
|
||||||
message, _ := req["message"].(string)
|
message, _ := req["message"].(string)
|
||||||
result := h.testAccount(r.Context(), acc, model, message)
|
result := h.testAccount(r.Context(), acc, model, message)
|
||||||
writeJSON(w, http.StatusOK, result)
|
writeJSON(w, http.StatusOK, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||||
var req map[string]any
|
var req map[string]any
|
||||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
model, _ := req["model"].(string)
|
model, _ := req["model"].(string)
|
||||||
if model == "" {
|
if model == "" {
|
||||||
model = "deepseek-chat"
|
model = "deepseek-chat"
|
||||||
}
|
}
|
||||||
accounts := h.Store.Snapshot().Accounts
|
accounts := h.Store.Snapshot().Accounts
|
||||||
if len(accounts) == 0 {
|
if len(accounts) == 0 {
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"total": 0, "success": 0, "failed": 0, "results": []any{}})
|
writeJSON(w, http.StatusOK, map[string]any{"total": 0, "success": 0, "failed": 0, "results": []any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Concurrent testing with a semaphore to limit parallelism.
|
// Concurrent testing with a semaphore to limit parallelism.
|
||||||
const maxConcurrency = 5
|
const maxConcurrency = 5
|
||||||
results := runAccountTestsConcurrently(accounts, maxConcurrency, func(_ int, account config.Account) map[string]any {
|
results := runAccountTestsConcurrently(accounts, maxConcurrency, func(_ int, account config.Account) map[string]any {
|
||||||
return h.testAccount(r.Context(), account, model, "")
|
return h.testAccount(r.Context(), account, model, "")
|
||||||
})
|
})
|
||||||
|
|
||||||
success := 0
|
success := 0
|
||||||
for _, res := range results {
|
for _, res := range results {
|
||||||
if ok, _ := res["success"].(bool); ok {
|
if ok, _ := res["success"].(bool); ok {
|
||||||
success++
|
success++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results})
|
writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results})
|
||||||
}
|
}
|
||||||
|
|
||||||
func runAccountTestsConcurrently(accounts []config.Account, maxConcurrency int, testFn func(int, config.Account) map[string]any) []map[string]any {
|
func runAccountTestsConcurrently(accounts []config.Account, maxConcurrency int, testFn func(int, config.Account) map[string]any) []map[string]any {
|
||||||
if maxConcurrency <= 0 {
|
if maxConcurrency <= 0 {
|
||||||
maxConcurrency = 1
|
maxConcurrency = 1
|
||||||
}
|
}
|
||||||
sem := make(chan struct{}, maxConcurrency)
|
sem := make(chan struct{}, maxConcurrency)
|
||||||
results := make([]map[string]any, len(accounts))
|
results := make([]map[string]any, len(accounts))
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i, acc := range accounts {
|
for i, acc := range accounts {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(idx int, account config.Account) {
|
go func(idx int, account config.Account) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
sem <- struct{}{} // acquire
|
sem <- struct{}{} // acquire
|
||||||
defer func() { <-sem }() // release
|
defer func() { <-sem }() // release
|
||||||
results[idx] = testFn(idx, account)
|
results[idx] = testFn(idx, account)
|
||||||
}(i, acc)
|
}(i, acc)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
return results
|
return results
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
|
func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
identifier := acc.Identifier()
|
identifier := acc.Identifier()
|
||||||
result := map[string]any{"account": identifier, "success": false, "response_time": 0, "message": "", "model": model}
|
result := map[string]any{"account": identifier, "success": false, "response_time": 0, "message": "", "model": model}
|
||||||
defer func() {
|
defer func() {
|
||||||
status := "failed"
|
status := "failed"
|
||||||
if ok, _ := result["success"].(bool); ok {
|
if ok, _ := result["success"].(bool); ok {
|
||||||
status = "ok"
|
status = "ok"
|
||||||
}
|
}
|
||||||
_ = h.Store.UpdateAccountTestStatus(identifier, status)
|
_ = h.Store.UpdateAccountTestStatus(identifier, status)
|
||||||
}()
|
}()
|
||||||
token := strings.TrimSpace(acc.Token)
|
token := strings.TrimSpace(acc.Token)
|
||||||
if token == "" {
|
if token == "" {
|
||||||
newToken, err := h.DS.Login(ctx, acc)
|
newToken, err := h.DS.Login(ctx, acc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
result["message"] = "登录失败: " + err.Error()
|
result["message"] = "登录失败: " + err.Error()
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
token = newToken
|
token = newToken
|
||||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||||
}
|
}
|
||||||
authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token}
|
authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token}
|
||||||
sessionID, err := h.DS.CreateSession(ctx, authCtx, 1)
|
sessionID, err := h.DS.CreateSession(ctx, authCtx, 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
newToken, loginErr := h.DS.Login(ctx, acc)
|
newToken, loginErr := h.DS.Login(ctx, acc)
|
||||||
if loginErr != nil {
|
if loginErr != nil {
|
||||||
result["message"] = "创建会话失败: " + err.Error()
|
result["message"] = "创建会话失败: " + err.Error()
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
token = newToken
|
token = newToken
|
||||||
authCtx.DeepSeekToken = token
|
authCtx.DeepSeekToken = token
|
||||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||||
sessionID, err = h.DS.CreateSession(ctx, authCtx, 1)
|
sessionID, err = h.DS.CreateSession(ctx, authCtx, 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
result["message"] = "创建会话失败: " + err.Error()
|
result["message"] = "创建会话失败: " + err.Error()
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(message) == "" {
|
if strings.TrimSpace(message) == "" {
|
||||||
message = "你是谁?"
|
result["success"] = true
|
||||||
}
|
result["message"] = "API 测试成功(仅会话创建)"
|
||||||
thinking, search, ok := config.GetModelConfig(model)
|
result["response_time"] = int(time.Since(start).Milliseconds())
|
||||||
if !ok {
|
return result
|
||||||
thinking, search = false, false
|
}
|
||||||
}
|
thinking, search, ok := config.GetModelConfig(model)
|
||||||
_ = search
|
if !ok {
|
||||||
pow, err := h.DS.GetPow(ctx, authCtx, 1)
|
thinking, search = false, false
|
||||||
if err != nil {
|
}
|
||||||
result["message"] = "获取 PoW 失败: " + err.Error()
|
_ = search
|
||||||
return result
|
pow, err := h.DS.GetPow(ctx, authCtx, 1)
|
||||||
}
|
if err != nil {
|
||||||
payload := map[string]any{"chat_session_id": sessionID, "prompt": "<|User|>" + message, "ref_file_ids": []any{}, "thinking_enabled": thinking, "search_enabled": search}
|
result["message"] = "获取 PoW 失败: " + err.Error()
|
||||||
resp, err := h.DS.CallCompletion(ctx, authCtx, payload, pow, 1)
|
return result
|
||||||
if err != nil {
|
}
|
||||||
result["message"] = "请求失败: " + err.Error()
|
payload := map[string]any{"chat_session_id": sessionID, "prompt": "<|User|>" + message, "ref_file_ids": []any{}, "thinking_enabled": thinking, "search_enabled": search}
|
||||||
return result
|
resp, err := h.DS.CallCompletion(ctx, authCtx, payload, pow, 1)
|
||||||
}
|
if err != nil {
|
||||||
if resp.StatusCode != http.StatusOK {
|
result["message"] = "请求失败: " + err.Error()
|
||||||
defer resp.Body.Close()
|
return result
|
||||||
result["message"] = fmt.Sprintf("请求失败: HTTP %d", resp.StatusCode)
|
}
|
||||||
return result
|
if resp.StatusCode != http.StatusOK {
|
||||||
}
|
defer resp.Body.Close()
|
||||||
collected := sse.CollectStream(resp, thinking, true)
|
result["message"] = fmt.Sprintf("请求失败: HTTP %d", resp.StatusCode)
|
||||||
result["success"] = true
|
return result
|
||||||
result["response_time"] = int(time.Since(start).Milliseconds())
|
}
|
||||||
if collected.Text != "" {
|
collected := sse.CollectStream(resp, thinking, true)
|
||||||
result["message"] = collected.Text
|
result["success"] = true
|
||||||
} else {
|
result["response_time"] = int(time.Since(start).Milliseconds())
|
||||||
result["message"] = "(无回复内容)"
|
if collected.Text != "" {
|
||||||
}
|
result["message"] = collected.Text
|
||||||
if collected.Thinking != "" {
|
} else {
|
||||||
result["thinking"] = collected.Thinking
|
result["message"] = "(无回复内容)"
|
||||||
}
|
}
|
||||||
return result
|
if collected.Thinking != "" {
|
||||||
}
|
result["thinking"] = collected.Thinking
|
||||||
|
}
|
||||||
func (h *Handler) testAPI(w http.ResponseWriter, r *http.Request) {
|
return result
|
||||||
var req map[string]any
|
}
|
||||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
|
||||||
model, _ := req["model"].(string)
|
func (h *Handler) testAPI(w http.ResponseWriter, r *http.Request) {
|
||||||
message, _ := req["message"].(string)
|
var req map[string]any
|
||||||
apiKey, _ := req["api_key"].(string)
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if model == "" {
|
model, _ := req["model"].(string)
|
||||||
model = "deepseek-chat"
|
message, _ := req["message"].(string)
|
||||||
}
|
apiKey, _ := req["api_key"].(string)
|
||||||
if message == "" {
|
if model == "" {
|
||||||
message = "你好"
|
model = "deepseek-chat"
|
||||||
}
|
}
|
||||||
if apiKey == "" {
|
if message == "" {
|
||||||
keys := h.Store.Snapshot().Keys
|
message = "你好"
|
||||||
if len(keys) == 0 {
|
}
|
||||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "没有可用的 API Key"})
|
if apiKey == "" {
|
||||||
return
|
keys := h.Store.Snapshot().Keys
|
||||||
}
|
if len(keys) == 0 {
|
||||||
apiKey = keys[0]
|
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "没有可用的 API Key"})
|
||||||
}
|
return
|
||||||
host := r.Host
|
}
|
||||||
scheme := "http"
|
apiKey = keys[0]
|
||||||
if strings.Contains(strings.ToLower(host), "vercel") || strings.Contains(strings.ToLower(r.Header.Get("X-Forwarded-Proto")), "https") {
|
}
|
||||||
scheme = "https"
|
host := r.Host
|
||||||
}
|
scheme := "http"
|
||||||
payload := map[string]any{"model": model, "messages": []map[string]any{{"role": "user", "content": message}}, "stream": false}
|
if strings.Contains(strings.ToLower(host), "vercel") || strings.Contains(strings.ToLower(r.Header.Get("X-Forwarded-Proto")), "https") {
|
||||||
b, _ := json.Marshal(payload)
|
scheme = "https"
|
||||||
request, _ := http.NewRequestWithContext(r.Context(), http.MethodPost, fmt.Sprintf("%s://%s/v1/chat/completions", scheme, host), bytes.NewReader(b))
|
}
|
||||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
payload := map[string]any{"model": model, "messages": []map[string]any{{"role": "user", "content": message}}, "stream": false}
|
||||||
request.Header.Set("Content-Type", "application/json")
|
b, _ := json.Marshal(payload)
|
||||||
resp, err := (&http.Client{Timeout: 60 * time.Second}).Do(request)
|
request, _ := http.NewRequestWithContext(r.Context(), http.MethodPost, fmt.Sprintf("%s://%s/v1/chat/completions", scheme, host), bytes.NewReader(b))
|
||||||
if err != nil {
|
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "error": err.Error()})
|
request.Header.Set("Content-Type", "application/json")
|
||||||
return
|
resp, err := (&http.Client{Timeout: 60 * time.Second}).Do(request)
|
||||||
}
|
if err != nil {
|
||||||
defer resp.Body.Close()
|
writeJSON(w, http.StatusOK, map[string]any{"success": false, "error": err.Error()})
|
||||||
body, _ := io.ReadAll(resp.Body)
|
return
|
||||||
if resp.StatusCode == http.StatusOK {
|
}
|
||||||
var parsed any
|
defer resp.Body.Close()
|
||||||
_ = json.Unmarshal(body, &parsed)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "status_code": resp.StatusCode, "response": parsed})
|
if resp.StatusCode == http.StatusOK {
|
||||||
return
|
var parsed any
|
||||||
}
|
_ = json.Unmarshal(body, &parsed)
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": string(body)})
|
writeJSON(w, http.StatusOK, map[string]any{"success": true, "status_code": resp.StatusCode, "response": parsed})
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": string(body)})
|
||||||
|
}
|
||||||
|
|||||||
76
internal/admin/handler_accounts_testing_test.go
Normal file
76
internal/admin/handler_accounts_testing_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testingDSMock struct {
|
||||||
|
loginCalls int
|
||||||
|
createSessionCalls int
|
||||||
|
getPowCalls int
|
||||||
|
callCompletionCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testingDSMock) Login(_ context.Context, _ config.Account) (string, error) {
|
||||||
|
m.loginCalls++
|
||||||
|
return "new-token", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testingDSMock) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||||
|
m.createSessionCalls++
|
||||||
|
return "session-id", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testingDSMock) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||||
|
m.getPowCalls++
|
||||||
|
return "", errors.New("should not call GetPow in this test")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testingDSMock) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
|
||||||
|
m.callCompletionCalls++
|
||||||
|
return nil, errors.New("should not call CallCompletion in this test")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTestAccount_BatchModeOnlyCreatesSession(t *testing.T) {
|
||||||
|
t.Setenv("DS2API_CONFIG_JSON", `{"accounts":[{"email":"batch@example.com","password":"pwd","token":""}]}`)
|
||||||
|
store := config.LoadStore()
|
||||||
|
ds := &testingDSMock{}
|
||||||
|
h := &Handler{Store: store, DS: ds}
|
||||||
|
acc, ok := store.FindAccount("batch@example.com")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected test account")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := h.testAccount(context.Background(), acc, "deepseek-chat", "")
|
||||||
|
|
||||||
|
if ok, _ := result["success"].(bool); !ok {
|
||||||
|
t.Fatalf("expected success=true, got %#v", result)
|
||||||
|
}
|
||||||
|
msg, _ := result["message"].(string)
|
||||||
|
if !strings.Contains(msg, "仅会话创建") {
|
||||||
|
t.Fatalf("expected session-only success message, got %q", msg)
|
||||||
|
}
|
||||||
|
if ds.loginCalls != 1 || ds.createSessionCalls != 1 {
|
||||||
|
t.Fatalf("unexpected Login/CreateSession calls: login=%d createSession=%d", ds.loginCalls, ds.createSessionCalls)
|
||||||
|
}
|
||||||
|
if ds.getPowCalls != 0 || ds.callCompletionCalls != 0 {
|
||||||
|
t.Fatalf("expected no completion flow calls, got getPow=%d callCompletion=%d", ds.getPowCalls, ds.callCompletionCalls)
|
||||||
|
}
|
||||||
|
updated, ok := store.FindAccount("batch@example.com")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected updated account")
|
||||||
|
}
|
||||||
|
if updated.Token != "new-token" {
|
||||||
|
t.Fatalf("expected refreshed token to be persisted, got %q", updated.Token)
|
||||||
|
}
|
||||||
|
if updated.TestStatus != "ok" {
|
||||||
|
t.Fatalf("expected test status ok, got %q", updated.TestStatus)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -49,6 +49,7 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
|||||||
next := c.Clone()
|
next := c.Clone()
|
||||||
if mode == "replace" {
|
if mode == "replace" {
|
||||||
next = incoming.Clone()
|
next = incoming.Clone()
|
||||||
|
next.Accounts = normalizeAndDedupeAccounts(next.Accounts)
|
||||||
next.VercelSyncHash = c.VercelSyncHash
|
next.VercelSyncHash = c.VercelSyncHash
|
||||||
next.VercelSyncTime = c.VercelSyncTime
|
next.VercelSyncTime = c.VercelSyncTime
|
||||||
importedKeys = len(next.Keys)
|
importedKeys = len(next.Keys)
|
||||||
@@ -73,17 +74,22 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
existingAccounts := map[string]struct{}{}
|
existingAccounts := map[string]struct{}{}
|
||||||
for _, acc := range next.Accounts {
|
for _, acc := range next.Accounts {
|
||||||
existingAccounts[acc.Identifier()] = struct{}{}
|
acc = normalizeAccountForStorage(acc)
|
||||||
|
key := accountDedupeKey(acc)
|
||||||
|
if key != "" {
|
||||||
|
existingAccounts[key] = struct{}{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for _, acc := range incoming.Accounts {
|
for _, acc := range incoming.Accounts {
|
||||||
id := acc.Identifier()
|
acc = normalizeAccountForStorage(acc)
|
||||||
if id == "" {
|
key := accountDedupeKey(acc)
|
||||||
|
if key == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if _, ok := existingAccounts[id]; ok {
|
if _, ok := existingAccounts[key]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
existingAccounts[id] = struct{}{}
|
existingAccounts[key] = struct{}{}
|
||||||
next.Accounts = append(next.Accounts, acc)
|
next.Accounts = append(next.Accounts, acc)
|
||||||
importedAccounts++
|
importedAccounts++
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,17 +25,28 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
|
|||||||
if accountsRaw, ok := req["accounts"].([]any); ok {
|
if accountsRaw, ok := req["accounts"].([]any); ok {
|
||||||
existing := map[string]config.Account{}
|
existing := map[string]config.Account{}
|
||||||
for _, a := range old.Accounts {
|
for _, a := range old.Accounts {
|
||||||
existing[a.Identifier()] = a
|
a = normalizeAccountForStorage(a)
|
||||||
|
key := accountDedupeKey(a)
|
||||||
|
if key != "" {
|
||||||
|
existing[key] = a
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
seen := map[string]struct{}{}
|
||||||
accounts := make([]config.Account, 0, len(accountsRaw))
|
accounts := make([]config.Account, 0, len(accountsRaw))
|
||||||
for _, item := range accountsRaw {
|
for _, item := range accountsRaw {
|
||||||
m, ok := item.(map[string]any)
|
m, ok := item.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
acc := toAccount(m)
|
acc := normalizeAccountForStorage(toAccount(m))
|
||||||
id := acc.Identifier()
|
key := accountDedupeKey(acc)
|
||||||
if prev, ok := existing[id]; ok {
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if prev, ok := existing[key]; ok {
|
||||||
if strings.TrimSpace(acc.Password) == "" {
|
if strings.TrimSpace(acc.Password) == "" {
|
||||||
acc.Password = prev.Password
|
acc.Password = prev.Password
|
||||||
}
|
}
|
||||||
@@ -43,6 +54,7 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
|
|||||||
acc.Token = prev.Token
|
acc.Token = prev.Token
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
accounts = append(accounts, acc)
|
accounts = append(accounts, acc)
|
||||||
}
|
}
|
||||||
c.Accounts = accounts
|
c.Accounts = accounts
|
||||||
@@ -138,20 +150,24 @@ func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
|
|||||||
if accounts, ok := req["accounts"].([]any); ok {
|
if accounts, ok := req["accounts"].([]any); ok {
|
||||||
existing := map[string]bool{}
|
existing := map[string]bool{}
|
||||||
for _, a := range c.Accounts {
|
for _, a := range c.Accounts {
|
||||||
existing[a.Identifier()] = true
|
a = normalizeAccountForStorage(a)
|
||||||
|
key := accountDedupeKey(a)
|
||||||
|
if key != "" {
|
||||||
|
existing[key] = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for _, item := range accounts {
|
for _, item := range accounts {
|
||||||
m, ok := item.(map[string]any)
|
m, ok := item.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
acc := toAccount(m)
|
acc := normalizeAccountForStorage(toAccount(m))
|
||||||
id := acc.Identifier()
|
key := accountDedupeKey(acc)
|
||||||
if id == "" || existing[id] {
|
if key == "" || existing[key] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
c.Accounts = append(c.Accounts, acc)
|
c.Accounts = append(c.Accounts, acc)
|
||||||
existing[id] = true
|
existing[key] = true
|
||||||
importedAccounts++
|
importedAccounts++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -265,3 +265,57 @@ func TestConfigImportRejectsMergedRuntimeConflict(t *testing.T) {
|
|||||||
t.Fatalf("runtime should remain unchanged, runtime=%+v", snap.Runtime)
|
t.Fatalf("runtime should remain unchanged, runtime=%+v", snap.Runtime)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfigImportMergeDedupesMobileAliases(t *testing.T) {
|
||||||
|
h := newAdminTestHandler(t, `{
|
||||||
|
"keys":["k1"],
|
||||||
|
"accounts":[{"mobile":"+8613800138000","password":"p1"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
merge := map[string]any{
|
||||||
|
"mode": "merge",
|
||||||
|
"config": map[string]any{
|
||||||
|
"accounts": []any{
|
||||||
|
map[string]any{"mobile": "13800138000", "password": "p2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(merge)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(b))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.configImport(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.Store.Accounts()); got != 1 {
|
||||||
|
t.Fatalf("expected merge dedupe by canonical mobile, got=%d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateConfigDedupesMobileAliases(t *testing.T) {
|
||||||
|
h := newAdminTestHandler(t, `{
|
||||||
|
"keys":["k1"],
|
||||||
|
"accounts":[{"mobile":"+8613800138000","password":"old"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"accounts": []any{
|
||||||
|
map[string]any{"mobile": "+8613800138000"},
|
||||||
|
map[string]any{"mobile": "13800138000"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(reqBody)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/admin/config", bytes.NewReader(b))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.updateConfig(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
accounts := h.Store.Accounts()
|
||||||
|
if len(accounts) != 1 {
|
||||||
|
t.Fatalf("expected update dedupe by canonical mobile, got=%d", len(accounts))
|
||||||
|
}
|
||||||
|
if accounts[0].Identifier() != "+8613800138000" {
|
||||||
|
t.Fatalf("unexpected identifier: %q", accounts[0].Identifier())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -59,9 +59,11 @@ func toStringSlice(v any) ([]string, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toAccount(m map[string]any) config.Account {
|
func toAccount(m map[string]any) config.Account {
|
||||||
|
email := fieldString(m, "email")
|
||||||
|
mobile := config.NormalizeMobileForStorage(fieldString(m, "mobile"))
|
||||||
return config.Account{
|
return config.Account{
|
||||||
Email: fieldString(m, "email"),
|
Email: email,
|
||||||
Mobile: fieldString(m, "mobile"),
|
Mobile: mobile,
|
||||||
Password: fieldString(m, "password"),
|
Password: fieldString(m, "password"),
|
||||||
Token: fieldString(m, "token"),
|
Token: fieldString(m, "token"),
|
||||||
}
|
}
|
||||||
@@ -90,12 +92,52 @@ func accountMatchesIdentifier(acc config.Account, identifier string) bool {
|
|||||||
if strings.TrimSpace(acc.Email) == id {
|
if strings.TrimSpace(acc.Email) == id {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(acc.Mobile) == id {
|
if mobileKey := config.CanonicalMobileKey(id); mobileKey != "" && mobileKey == config.CanonicalMobileKey(acc.Mobile) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return acc.Identifier() == id
|
return acc.Identifier() == id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeAccountForStorage(acc config.Account) config.Account {
|
||||||
|
acc.Email = strings.TrimSpace(acc.Email)
|
||||||
|
acc.Mobile = config.NormalizeMobileForStorage(acc.Mobile)
|
||||||
|
return acc
|
||||||
|
}
|
||||||
|
|
||||||
|
func accountDedupeKey(acc config.Account) string {
|
||||||
|
if email := strings.TrimSpace(acc.Email); email != "" {
|
||||||
|
return "email:" + email
|
||||||
|
}
|
||||||
|
if mobile := config.CanonicalMobileKey(acc.Mobile); mobile != "" {
|
||||||
|
return "mobile:" + mobile
|
||||||
|
}
|
||||||
|
if id := strings.TrimSpace(acc.Identifier()); id != "" {
|
||||||
|
return "id:" + id
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAndDedupeAccounts(accounts []config.Account) []config.Account {
|
||||||
|
if len(accounts) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]config.Account, 0, len(accounts))
|
||||||
|
seen := make(map[string]struct{}, len(accounts))
|
||||||
|
for _, acc := range accounts {
|
||||||
|
acc = normalizeAccountForStorage(acc)
|
||||||
|
key := accountDedupeKey(acc)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
out = append(out, acc)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func findAccountByIdentifier(store ConfigStore, identifier string) (config.Account, bool) {
|
func findAccountByIdentifier(store ConfigStore, identifier string) (config.Account, bool) {
|
||||||
id := strings.TrimSpace(identifier)
|
id := strings.TrimSpace(identifier)
|
||||||
if id == "" {
|
if id == "" {
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ func TestToAccountAllFields(t *testing.T) {
|
|||||||
if acc.Email != "user@test.com" {
|
if acc.Email != "user@test.com" {
|
||||||
t.Fatalf("unexpected email: %q", acc.Email)
|
t.Fatalf("unexpected email: %q", acc.Email)
|
||||||
}
|
}
|
||||||
if acc.Mobile != "13800138000" {
|
if acc.Mobile != "+8613800138000" {
|
||||||
t.Fatalf("unexpected mobile: %q", acc.Mobile)
|
t.Fatalf("unexpected mobile: %q", acc.Mobile)
|
||||||
}
|
}
|
||||||
if acc.Password != "secret" {
|
if acc.Password != "secret" {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"ds2api/internal/sse"
|
"ds2api/internal/sse"
|
||||||
@@ -67,20 +68,36 @@ func TestGoCompatToolcallFixtures(t *testing.T) {
|
|||||||
var fixture struct {
|
var fixture struct {
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
ToolNames []string `json:"tool_names"`
|
ToolNames []string `json:"tool_names"`
|
||||||
|
Mode string `json:"mode"`
|
||||||
}
|
}
|
||||||
mustLoadJSON(t, fixturePath, &fixture)
|
mustLoadJSON(t, fixturePath, &fixture)
|
||||||
|
|
||||||
var expected struct {
|
var expected struct {
|
||||||
Calls []util.ParsedToolCall `json:"calls"`
|
Calls []util.ParsedToolCall `json:"calls"`
|
||||||
|
SawToolCallSyntax bool `json:"sawToolCallSyntax"`
|
||||||
|
RejectedByPolicy bool `json:"rejectedByPolicy"`
|
||||||
|
RejectedToolNames []string `json:"rejectedToolNames"`
|
||||||
}
|
}
|
||||||
mustLoadJSON(t, expectedPath, &expected)
|
mustLoadJSON(t, expectedPath, &expected)
|
||||||
|
|
||||||
got := util.ParseToolCalls(fixture.Text, fixture.ToolNames)
|
var got util.ToolCallParseResult
|
||||||
if len(got) == 0 && len(expected.Calls) == 0 {
|
switch strings.ToLower(strings.TrimSpace(fixture.Mode)) {
|
||||||
continue
|
case "standalone":
|
||||||
|
got = util.ParseStandaloneToolCallsDetailed(fixture.Text, fixture.ToolNames)
|
||||||
|
default:
|
||||||
|
got = util.ParseToolCallsDetailed(fixture.Text, fixture.ToolNames)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, expected.Calls) {
|
if got.Calls == nil {
|
||||||
t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected.Calls)
|
got.Calls = []util.ParsedToolCall{}
|
||||||
|
}
|
||||||
|
if got.RejectedToolNames == nil {
|
||||||
|
got.RejectedToolNames = []string{}
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got.Calls, expected.Calls) ||
|
||||||
|
got.SawToolCallSyntax != expected.SawToolCallSyntax ||
|
||||||
|
got.RejectedByPolicy != expected.RejectedByPolicy ||
|
||||||
|
!reflect.DeepEqual(got.RejectedToolNames, expected.RejectedToolNames) {
|
||||||
|
t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ func (a Account) Identifier() string {
|
|||||||
if strings.TrimSpace(a.Email) != "" {
|
if strings.TrimSpace(a.Email) != "" {
|
||||||
return strings.TrimSpace(a.Email)
|
return strings.TrimSpace(a.Email)
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(a.Mobile) != "" {
|
if mobile := NormalizeMobileForStorage(a.Mobile); mobile != "" {
|
||||||
return strings.TrimSpace(a.Mobile)
|
return mobile
|
||||||
}
|
}
|
||||||
// Backward compatibility: old configs may contain token-only accounts.
|
// Backward compatibility: old configs may contain token-only accounts.
|
||||||
// Use a stable non-sensitive synthetic id so they can still join the pool.
|
// Use a stable non-sensitive synthetic id so they can still join the pool.
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ func TestConfigCloneNilMaps(t *testing.T) {
|
|||||||
|
|
||||||
func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) {
|
func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) {
|
||||||
acc := Account{Mobile: "13800138000", Token: "tok"}
|
acc := Account{Mobile: "13800138000", Token: "tok"}
|
||||||
if acc.Identifier() != "13800138000" {
|
if acc.Identifier() != "+8613800138000" {
|
||||||
t.Fatalf("expected mobile identifier, got %q", acc.Identifier())
|
t.Fatalf("expected mobile identifier, got %q", acc.Identifier())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
82
internal/config/mobile.go
Normal file
82
internal/config/mobile.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// NormalizeMobileForStorage normalizes user input to a stable storage format.
|
||||||
|
// It keeps existing country codes and auto-prefixes mainland China numbers with +86.
|
||||||
|
func NormalizeMobileForStorage(raw string) string {
|
||||||
|
digits, hasPlus := extractMobileDigits(raw)
|
||||||
|
if digits == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if hasPlus {
|
||||||
|
return "+" + digits
|
||||||
|
}
|
||||||
|
if isChinaMobileWithCountryCode(digits) {
|
||||||
|
return "+86" + digits[2:]
|
||||||
|
}
|
||||||
|
if isChinaMainlandMobileDigits(digits) {
|
||||||
|
return "+86" + digits
|
||||||
|
}
|
||||||
|
// For non-China numbers without a leading +, preserve semantics by adding it.
|
||||||
|
return "+" + digits
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanonicalMobileKey returns the comparison key used by dedupe/matching logic.
|
||||||
|
func CanonicalMobileKey(raw string) string {
|
||||||
|
return NormalizeMobileForStorage(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractMobileDigits(raw string) (digits string, hasPlus bool) {
|
||||||
|
s := strings.TrimSpace(raw)
|
||||||
|
if s == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range s {
|
||||||
|
switch {
|
||||||
|
case r >= '0' && r <= '9':
|
||||||
|
goto collect
|
||||||
|
case isMobileSeparator(r):
|
||||||
|
continue
|
||||||
|
case r == '+':
|
||||||
|
hasPlus = true
|
||||||
|
goto collect
|
||||||
|
default:
|
||||||
|
goto collect
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
collect:
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(s))
|
||||||
|
for _, r := range s {
|
||||||
|
if r >= '0' && r <= '9' {
|
||||||
|
b.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b.String(), hasPlus
|
||||||
|
}
|
||||||
|
|
||||||
|
func isChinaMainlandMobileDigits(digits string) bool {
|
||||||
|
if len(digits) != 11 || digits[0] != '1' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return digits[1] >= '3' && digits[1] <= '9'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isChinaMobileWithCountryCode(digits string) bool {
|
||||||
|
if len(digits) != 13 || !strings.HasPrefix(digits, "86") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return isChinaMainlandMobileDigits(digits[2:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMobileSeparator(r rune) bool {
|
||||||
|
switch r {
|
||||||
|
case ' ', '\t', '\n', '\r', '-', '(', ')', '.', '/':
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
36
internal/config/mobile_test.go
Normal file
36
internal/config/mobile_test.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestNormalizeMobileForStorageChinaMainlandAddsPlus86(t *testing.T) {
|
||||||
|
if got := NormalizeMobileForStorage("13800138000"); got != "+8613800138000" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeMobileForStorageChinaWithCountryCode(t *testing.T) {
|
||||||
|
if got := NormalizeMobileForStorage("8613800138000"); got != "+8613800138000" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeMobileForStorageKeepsExistingCountryCode(t *testing.T) {
|
||||||
|
if got := NormalizeMobileForStorage(" +1 (415) 555-2671 "); got != "+14155552671" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCanonicalMobileKeyMatchesChinaAliases(t *testing.T) {
|
||||||
|
a := CanonicalMobileKey("+8613800138000")
|
||||||
|
b := CanonicalMobileKey("13800138000")
|
||||||
|
c := CanonicalMobileKey("86 13800138000")
|
||||||
|
if a == "" || a != b || b != c {
|
||||||
|
t.Fatalf("alias mismatch: a=%q b=%q c=%q", a, b, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCanonicalMobileKeyEmptyForInvalidInput(t *testing.T) {
|
||||||
|
if got := CanonicalMobileKey("() --"); got != "" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
"ds2api/internal/auth"
|
"ds2api/internal/auth"
|
||||||
"ds2api/internal/config"
|
"ds2api/internal/config"
|
||||||
@@ -20,8 +21,9 @@ func (c *Client) Login(ctx context.Context, acc config.Account) (string, error)
|
|||||||
if email := strings.TrimSpace(acc.Email); email != "" {
|
if email := strings.TrimSpace(acc.Email); email != "" {
|
||||||
payload["email"] = email
|
payload["email"] = email
|
||||||
} else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" {
|
} else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" {
|
||||||
payload["mobile"] = mobile
|
loginMobile, areaCode := normalizeMobileForLogin(mobile)
|
||||||
payload["area_code"] = nil
|
payload["mobile"] = loginMobile
|
||||||
|
payload["area_code"] = areaCode
|
||||||
} else {
|
} else {
|
||||||
return "", errors.New("missing email/mobile")
|
return "", errors.New("missing email/mobile")
|
||||||
}
|
}
|
||||||
@@ -151,3 +153,26 @@ func isTokenInvalid(status int, code int, msg string) bool {
|
|||||||
}
|
}
|
||||||
return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
|
return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeMobileForLogin(raw string) (mobile string, areaCode any) {
|
||||||
|
s := strings.TrimSpace(raw)
|
||||||
|
if s == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
hasPlus := strings.HasPrefix(s, "+")
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(s))
|
||||||
|
for _, r := range s {
|
||||||
|
if unicode.IsDigit(r) {
|
||||||
|
b.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
digits := b.String()
|
||||||
|
if digits == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
if (hasPlus || strings.HasPrefix(digits, "86")) && strings.HasPrefix(digits, "86") && len(digits) == 13 {
|
||||||
|
return digits[2:], nil
|
||||||
|
}
|
||||||
|
return digits, nil
|
||||||
|
}
|
||||||
|
|||||||
33
internal/deepseek/client_auth_mobile_test.go
Normal file
33
internal/deepseek/client_auth_mobile_test.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package deepseek
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestNormalizeMobileForLogin_ChinaWithPlus86(t *testing.T) {
|
||||||
|
mobile, areaCode := normalizeMobileForLogin("+8613800138000")
|
||||||
|
if mobile != "13800138000" {
|
||||||
|
t.Fatalf("unexpected mobile: %q", mobile)
|
||||||
|
}
|
||||||
|
if areaCode != nil {
|
||||||
|
t.Fatalf("expected nil areaCode, got %#v", areaCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeMobileForLogin_ChinaWith86Prefix(t *testing.T) {
|
||||||
|
mobile, areaCode := normalizeMobileForLogin("8613800138000")
|
||||||
|
if mobile != "13800138000" {
|
||||||
|
t.Fatalf("unexpected mobile: %q", mobile)
|
||||||
|
}
|
||||||
|
if areaCode != nil {
|
||||||
|
t.Fatalf("expected nil areaCode, got %#v", areaCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeMobileForLogin_KeepPlainDigits(t *testing.T) {
|
||||||
|
mobile, areaCode := normalizeMobileForLogin("13800138000")
|
||||||
|
if mobile != "13800138000" {
|
||||||
|
t.Fatalf("unexpected mobile: %q", mobile)
|
||||||
|
}
|
||||||
|
if areaCode != nil {
|
||||||
|
t.Fatalf("expected nil areaCode, got %#v", areaCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||||
detected := util.ParseToolCalls(finalText, toolNames)
|
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
|
||||||
finishReason := "stop"
|
finishReason := "stop"
|
||||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||||
if strings.TrimSpace(finalThinking) != "" {
|
if strings.TrimSpace(finalThinking) != "" {
|
||||||
|
|||||||
@@ -11,12 +11,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||||
// Align responses tool-call semantics with chat/completions:
|
// Strict mode: only standalone, structured tool-call payloads are treated
|
||||||
// mixed prose + tool_call payloads should still be interpreted as tool calls.
|
// as executable tool calls.
|
||||||
detected := util.ParseToolCalls(finalText, toolNames)
|
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
|
||||||
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
|
|
||||||
detected = util.ParseToolCalls(finalThinking, toolNames)
|
|
||||||
}
|
|
||||||
exposedOutputText := finalText
|
exposedOutputText := finalText
|
||||||
output := make([]any, 0, 2)
|
output := make([]any, 0, 2)
|
||||||
if len(detected) > 0 {
|
if len(detected) > 0 {
|
||||||
|
|||||||
@@ -71,6 +71,19 @@ func BuildResponsesTextDeltaPayload(responseID, itemID string, outputIndex, cont
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func BuildResponsesTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"type": "response.output_text.done",
|
||||||
|
"id": responseID,
|
||||||
|
"response_id": responseID,
|
||||||
|
"item_id": itemID,
|
||||||
|
"output_index": outputIndex,
|
||||||
|
"content_index": contentIndex,
|
||||||
|
"text": text,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
|
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"type": "response.reasoning.delta",
|
"type": "response.reasoning.delta",
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) {
|
func TestBuildResponseObjectTreatsMixedProseToolPayloadAsText(t *testing.T) {
|
||||||
obj := BuildResponseObject(
|
obj := BuildResponseObject(
|
||||||
"resp_test",
|
"resp_test",
|
||||||
"gpt-4o",
|
"gpt-4o",
|
||||||
@@ -56,17 +56,16 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T)
|
|||||||
)
|
)
|
||||||
|
|
||||||
outputText, _ := obj["output_text"].(string)
|
outputText, _ := obj["output_text"].(string)
|
||||||
if outputText != "" {
|
if outputText == "" {
|
||||||
t.Fatalf("expected output_text hidden once tool calls are detected, got %q", outputText)
|
t.Fatalf("expected output_text preserved for mixed prose payload")
|
||||||
}
|
}
|
||||||
|
|
||||||
output, _ := obj["output"].([]any)
|
output, _ := obj["output"].([]any)
|
||||||
if len(output) != 1 {
|
if len(output) != 1 {
|
||||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
t.Fatalf("expected one message output item, got %#v", obj["output"])
|
||||||
}
|
}
|
||||||
first, _ := output[0].(map[string]any)
|
first, _ := output[0].(map[string]any)
|
||||||
if first["type"] != "function_call" {
|
if first["type"] != "message" {
|
||||||
t.Fatalf("expected first output type function_call, got %#v", first["type"])
|
t.Fatalf("expected message output type, got %#v", first["type"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,7 +126,7 @@ func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
func TestBuildResponseObjectIgnoresToolCallFromThinkingChannel(t *testing.T) {
|
||||||
obj := BuildResponseObject(
|
obj := BuildResponseObject(
|
||||||
"resp_test",
|
"resp_test",
|
||||||
"gpt-4o",
|
"gpt-4o",
|
||||||
@@ -139,10 +138,10 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
|||||||
|
|
||||||
output, _ := obj["output"].([]any)
|
output, _ := obj["output"].([]any)
|
||||||
if len(output) != 1 {
|
if len(output) != 1 {
|
||||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
t.Fatalf("expected one message output item, got %#v", obj["output"])
|
||||||
}
|
}
|
||||||
first, _ := output[0].(map[string]any)
|
first, _ := output[0].(map[string]any)
|
||||||
if first["type"] != "function_call" {
|
if first["type"] != "message" {
|
||||||
t.Fatalf("expected output function_call, got %#v", first["type"])
|
t.Fatalf("expected output message, got %#v", first["type"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ const {
|
|||||||
} = require('./sse_parse');
|
} = require('./sse_parse');
|
||||||
const {
|
const {
|
||||||
resolveToolcallPolicy,
|
resolveToolcallPolicy,
|
||||||
|
formatIncrementalToolCallDeltas,
|
||||||
normalizePreparedToolNames,
|
normalizePreparedToolNames,
|
||||||
boolDefaultTrue,
|
boolDefaultTrue,
|
||||||
|
filterIncrementalToolCallDeltasByAllowed,
|
||||||
} = require('./toolcall_policy');
|
} = require('./toolcall_policy');
|
||||||
const {
|
const {
|
||||||
estimateTokens,
|
estimateTokens,
|
||||||
@@ -82,7 +84,9 @@ module.exports.__test = {
|
|||||||
shouldSkipPath,
|
shouldSkipPath,
|
||||||
asString,
|
asString,
|
||||||
resolveToolcallPolicy,
|
resolveToolcallPolicy,
|
||||||
|
formatIncrementalToolCallDeltas,
|
||||||
normalizePreparedToolNames,
|
normalizePreparedToolNames,
|
||||||
boolDefaultTrue,
|
boolDefaultTrue,
|
||||||
|
filterIncrementalToolCallDeltasByAllowed,
|
||||||
estimateTokens,
|
estimateTokens,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -68,6 +68,47 @@ function formatIncrementalToolCallDeltas(deltas, idStore) {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function filterIncrementalToolCallDeltasByAllowed(deltas, allowedNames, seenNames) {
|
||||||
|
if (!Array.isArray(deltas) || deltas.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
const seen = seenNames instanceof Map ? seenNames : new Map();
|
||||||
|
const allowed = new Set((allowedNames || []).filter((name) => asString(name) !== ''));
|
||||||
|
if (allowed.size === 0) {
|
||||||
|
for (const d of deltas) {
|
||||||
|
if (d && typeof d === 'object' && asString(d.name)) {
|
||||||
|
const index = Number.isInteger(d.index) ? d.index : 0;
|
||||||
|
seen.set(index, '__blocked__');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const out = [];
|
||||||
|
for (const d of deltas) {
|
||||||
|
if (!d || typeof d !== 'object') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const index = Number.isInteger(d.index) ? d.index : 0;
|
||||||
|
const name = asString(d.name);
|
||||||
|
if (name) {
|
||||||
|
if (!allowed.has(name)) {
|
||||||
|
seen.set(index, '__blocked__');
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
seen.set(index, name);
|
||||||
|
out.push(d);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const existing = asString(seen.get(index));
|
||||||
|
if (!existing || existing === '__blocked__') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
out.push(d);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
function ensureStreamToolCallID(idStore, index) {
|
function ensureStreamToolCallID(idStore, index) {
|
||||||
const key = Number.isInteger(index) ? index : 0;
|
const key = Number.isInteger(index) ? index : 0;
|
||||||
const existing = idStore.get(key);
|
const existing = idStore.get(key);
|
||||||
@@ -104,4 +145,5 @@ module.exports = {
|
|||||||
normalizePreparedToolNames,
|
normalizePreparedToolNames,
|
||||||
boolDefaultTrue,
|
boolDefaultTrue,
|
||||||
formatIncrementalToolCallDeltas,
|
formatIncrementalToolCallDeltas,
|
||||||
|
filterIncrementalToolCallDeltasByAllowed,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ const {
|
|||||||
createToolSieveState,
|
createToolSieveState,
|
||||||
processToolSieveChunk,
|
processToolSieveChunk,
|
||||||
flushToolSieve,
|
flushToolSieve,
|
||||||
parseToolCalls,
|
parseStandaloneToolCalls,
|
||||||
formatOpenAIStreamToolCalls,
|
formatOpenAIStreamToolCalls,
|
||||||
} = require('../helpers/stream-tool-sieve');
|
} = require('../helpers/stream-tool-sieve');
|
||||||
const {
|
const {
|
||||||
@@ -24,7 +24,6 @@ const {
|
|||||||
} = require('./token_usage');
|
} = require('./token_usage');
|
||||||
const {
|
const {
|
||||||
resolveToolcallPolicy,
|
resolveToolcallPolicy,
|
||||||
formatIncrementalToolCallDeltas,
|
|
||||||
} = require('./toolcall_policy');
|
} = require('./toolcall_policy');
|
||||||
const {
|
const {
|
||||||
createChatCompletionEmitter,
|
createChatCompletionEmitter,
|
||||||
@@ -130,7 +129,6 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
|||||||
let thinkingText = '';
|
let thinkingText = '';
|
||||||
let outputText = '';
|
let outputText = '';
|
||||||
const toolSieveEnabled = toolPolicy.toolSieveEnabled;
|
const toolSieveEnabled = toolPolicy.toolSieveEnabled;
|
||||||
const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
|
|
||||||
const toolSieveState = createToolSieveState();
|
const toolSieveState = createToolSieveState();
|
||||||
let toolCallsEmitted = false;
|
let toolCallsEmitted = false;
|
||||||
const streamToolCallIDs = new Map();
|
const streamToolCallIDs = new Map();
|
||||||
@@ -155,13 +153,18 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
|||||||
await releaseLease();
|
await releaseLease();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const detected = parseToolCalls(outputText, toolNames);
|
const detected = parseStandaloneToolCalls(outputText, toolNames);
|
||||||
if (detected.length > 0 && !toolCallsEmitted) {
|
if (detected.length > 0 && !toolCallsEmitted) {
|
||||||
toolCallsEmitted = true;
|
toolCallsEmitted = true;
|
||||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) });
|
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected, streamToolCallIDs) });
|
||||||
} else if (toolSieveEnabled) {
|
} else if (toolSieveEnabled) {
|
||||||
const tailEvents = flushToolSieve(toolSieveState, toolNames);
|
const tailEvents = flushToolSieve(toolSieveState, toolNames);
|
||||||
for (const evt of tailEvents) {
|
for (const evt of tailEvents) {
|
||||||
|
if (evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0) {
|
||||||
|
toolCallsEmitted = true;
|
||||||
|
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (evt.text) {
|
if (evt.text) {
|
||||||
sendDeltaFrame({ content: evt.text });
|
sendDeltaFrame({ content: evt.text });
|
||||||
}
|
}
|
||||||
@@ -252,17 +255,9 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
|||||||
}
|
}
|
||||||
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
|
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
|
||||||
for (const evt of events) {
|
for (const evt of events) {
|
||||||
if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) {
|
|
||||||
if (!emitEarlyToolDeltas) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
toolCallsEmitted = true;
|
|
||||||
sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (evt.type === 'tool_calls') {
|
if (evt.type === 'tool_calls') {
|
||||||
toolCallsEmitted = true;
|
toolCallsEmitted = true;
|
||||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) });
|
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (evt.text) {
|
if (evt.text) {
|
||||||
|
|||||||
@@ -2,13 +2,13 @@
|
|||||||
|
|
||||||
const crypto = require('crypto');
|
const crypto = require('crypto');
|
||||||
|
|
||||||
function formatOpenAIStreamToolCalls(calls) {
|
function formatOpenAIStreamToolCalls(calls, idStore) {
|
||||||
if (!Array.isArray(calls) || calls.length === 0) {
|
if (!Array.isArray(calls) || calls.length === 0) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
return calls.map((c, idx) => ({
|
return calls.map((c, idx) => ({
|
||||||
index: idx,
|
index: idx,
|
||||||
id: `call_${newCallID()}`,
|
id: ensureStreamToolCallID(idStore, idx),
|
||||||
type: 'function',
|
type: 'function',
|
||||||
function: {
|
function: {
|
||||||
name: c.name,
|
name: c.name,
|
||||||
@@ -17,6 +17,20 @@ function formatOpenAIStreamToolCalls(calls) {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function ensureStreamToolCallID(idStore, index) {
|
||||||
|
if (!(idStore instanceof Map)) {
|
||||||
|
return `call_${newCallID()}`;
|
||||||
|
}
|
||||||
|
const key = Number.isInteger(index) ? index : 0;
|
||||||
|
const existing = idStore.get(key);
|
||||||
|
if (existing) {
|
||||||
|
return existing;
|
||||||
|
}
|
||||||
|
const next = `call_${newCallID()}`;
|
||||||
|
idStore.set(key, next);
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
|
||||||
function newCallID() {
|
function newCallID() {
|
||||||
if (typeof crypto.randomUUID === 'function') {
|
if (typeof crypto.randomUUID === 'function') {
|
||||||
return crypto.randomUUID().replace(/-/g, '');
|
return crypto.randomUUID().replace(/-/g, '');
|
||||||
|
|||||||
@@ -1,226 +0,0 @@
|
|||||||
'use strict';
|
|
||||||
|
|
||||||
const {
|
|
||||||
looksLikeToolExampleContext,
|
|
||||||
insideCodeFence,
|
|
||||||
} = require('./state');
|
|
||||||
const {
|
|
||||||
findObjectFieldValueStart,
|
|
||||||
parseJSONStringLiteral,
|
|
||||||
skipSpaces,
|
|
||||||
} = require('./jsonscan');
|
|
||||||
|
|
||||||
function buildIncrementalToolDeltas(state) {
|
|
||||||
const captured = state.capture || '';
|
|
||||||
if (!captured) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
if (looksLikeToolExampleContext(state.recentTextTail)) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const lower = captured.toLowerCase();
|
|
||||||
const keyIdx = lower.indexOf('tool_calls');
|
|
||||||
if (keyIdx < 0) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const start = captured.slice(0, keyIdx).lastIndexOf('{');
|
|
||||||
if (start < 0) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const callStart = findFirstToolCallObjectStart(captured, keyIdx);
|
|
||||||
if (callStart < 0) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
|
|
||||||
const deltas = [];
|
|
||||||
if (!state.toolName) {
|
|
||||||
const name = extractToolCallName(captured, callStart);
|
|
||||||
if (!name) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
state.toolName = name;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (state.toolArgsStart < 0) {
|
|
||||||
const args = findToolCallArgsStart(captured, callStart);
|
|
||||||
if (args) {
|
|
||||||
state.toolArgsString = Boolean(args.stringMode);
|
|
||||||
state.toolArgsStart = state.toolArgsString ? args.start + 1 : args.start;
|
|
||||||
state.toolArgsSent = state.toolArgsStart;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!state.toolNameSent) {
|
|
||||||
if (state.toolArgsStart < 0) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
state.toolNameSent = true;
|
|
||||||
deltas.push({ index: 0, name: state.toolName });
|
|
||||||
}
|
|
||||||
if (state.toolArgsStart < 0 || state.toolArgsDone) {
|
|
||||||
return deltas;
|
|
||||||
}
|
|
||||||
const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString);
|
|
||||||
if (!progress) {
|
|
||||||
return deltas;
|
|
||||||
}
|
|
||||||
if (progress.end > state.toolArgsSent) {
|
|
||||||
deltas.push({
|
|
||||||
index: 0,
|
|
||||||
arguments: captured.slice(state.toolArgsSent, progress.end),
|
|
||||||
});
|
|
||||||
state.toolArgsSent = progress.end;
|
|
||||||
}
|
|
||||||
if (progress.complete) {
|
|
||||||
state.toolArgsDone = true;
|
|
||||||
}
|
|
||||||
return deltas;
|
|
||||||
}
|
|
||||||
|
|
||||||
function findFirstToolCallObjectStart(text, keyIdx) {
|
|
||||||
const arrStart = findToolCallsArrayStart(text, keyIdx);
|
|
||||||
if (arrStart < 0) {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
const i = skipSpaces(text, arrStart + 1);
|
|
||||||
if (i >= text.length || text[i] !== '{') {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
|
|
||||||
function findToolCallsArrayStart(text, keyIdx) {
|
|
||||||
let i = keyIdx + 'tool_calls'.length;
|
|
||||||
while (i < text.length && text[i] !== ':') {
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
if (i >= text.length) {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
i = skipSpaces(text, i + 1);
|
|
||||||
if (i >= text.length || text[i] !== '[') {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
|
|
||||||
function extractToolCallName(text, callStart) {
|
|
||||||
let valueStart = findObjectFieldValueStart(text, callStart, ['name']);
|
|
||||||
if (valueStart < 0 || text[valueStart] !== '"') {
|
|
||||||
const fnStart = findFunctionObjectStart(text, callStart);
|
|
||||||
if (fnStart < 0) {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
valueStart = findObjectFieldValueStart(text, fnStart, ['name']);
|
|
||||||
if (valueStart < 0 || text[valueStart] !== '"') {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const parsed = parseJSONStringLiteral(text, valueStart);
|
|
||||||
if (!parsed) {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
return parsed.value;
|
|
||||||
}
|
|
||||||
|
|
||||||
function findToolCallArgsStart(text, callStart) {
|
|
||||||
const keys = ['input', 'arguments', 'args', 'parameters', 'params'];
|
|
||||||
let valueStart = findObjectFieldValueStart(text, callStart, keys);
|
|
||||||
if (valueStart < 0) {
|
|
||||||
const fnStart = findFunctionObjectStart(text, callStart);
|
|
||||||
if (fnStart < 0) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
valueStart = findObjectFieldValueStart(text, fnStart, keys);
|
|
||||||
if (valueStart < 0) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (valueStart >= text.length) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
const ch = text[valueStart];
|
|
||||||
if (ch === '{' || ch === '[') {
|
|
||||||
return { start: valueStart, stringMode: false };
|
|
||||||
}
|
|
||||||
if (ch === '"') {
|
|
||||||
return { start: valueStart, stringMode: true };
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
function scanToolCallArgsProgress(text, start, stringMode) {
|
|
||||||
if (start < 0 || start > text.length) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
if (stringMode) {
|
|
||||||
let escaped = false;
|
|
||||||
for (let i = start; i < text.length; i += 1) {
|
|
||||||
const ch = text[i];
|
|
||||||
if (escaped) {
|
|
||||||
escaped = false;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ch === '\\') {
|
|
||||||
escaped = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ch === '"') {
|
|
||||||
return { end: i, complete: true };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return { end: text.length, complete: false };
|
|
||||||
}
|
|
||||||
if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
let depth = 0;
|
|
||||||
let quote = '';
|
|
||||||
let escaped = false;
|
|
||||||
for (let i = start; i < text.length; i += 1) {
|
|
||||||
const ch = text[i];
|
|
||||||
if (quote) {
|
|
||||||
if (escaped) {
|
|
||||||
escaped = false;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ch === '\\') {
|
|
||||||
escaped = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ch === quote) {
|
|
||||||
quote = '';
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ch === '"' || ch === "'") {
|
|
||||||
quote = ch;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ch === '{' || ch === '[') {
|
|
||||||
depth += 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ch === '}' || ch === ']') {
|
|
||||||
depth -= 1;
|
|
||||||
if (depth === 0) {
|
|
||||||
return { end: i + 1, complete: true };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return { end: text.length, complete: false };
|
|
||||||
}
|
|
||||||
|
|
||||||
function findFunctionObjectStart(text, callStart) {
|
|
||||||
const valueStart = findObjectFieldValueStart(text, callStart, ['function']);
|
|
||||||
if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
return valueStart;
|
|
||||||
}
|
|
||||||
|
|
||||||
module.exports = {
|
|
||||||
buildIncrementalToolDeltas,
|
|
||||||
};
|
|
||||||
@@ -10,7 +10,9 @@ const {
|
|||||||
const {
|
const {
|
||||||
extractToolNames,
|
extractToolNames,
|
||||||
parseToolCalls,
|
parseToolCalls,
|
||||||
|
parseToolCallsDetailed,
|
||||||
parseStandaloneToolCalls,
|
parseStandaloneToolCalls,
|
||||||
|
parseStandaloneToolCallsDetailed,
|
||||||
} = require('./parse');
|
} = require('./parse');
|
||||||
const {
|
const {
|
||||||
formatOpenAIStreamToolCalls,
|
formatOpenAIStreamToolCalls,
|
||||||
@@ -22,6 +24,8 @@ module.exports = {
|
|||||||
processToolSieveChunk,
|
processToolSieveChunk,
|
||||||
flushToolSieve,
|
flushToolSieve,
|
||||||
parseToolCalls,
|
parseToolCalls,
|
||||||
|
parseToolCallsDetailed,
|
||||||
parseStandaloneToolCalls,
|
parseStandaloneToolCalls,
|
||||||
|
parseStandaloneToolCallsDetailed,
|
||||||
formatOpenAIStreamToolCalls,
|
formatOpenAIStreamToolCalls,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,14 +1,18 @@
|
|||||||
'use strict';
|
'use strict';
|
||||||
|
|
||||||
const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
|
|
||||||
|
|
||||||
const {
|
const {
|
||||||
toStringSafe,
|
toStringSafe,
|
||||||
looksLikeToolExampleContext,
|
looksLikeToolExampleContext,
|
||||||
} = require('./state');
|
} = require('./state');
|
||||||
const {
|
const {
|
||||||
extractJSONObjectFrom,
|
stripFencedCodeBlocks,
|
||||||
} = require('./jsonscan');
|
buildToolCallCandidates,
|
||||||
|
parseToolCallsPayload,
|
||||||
|
parseMarkupToolCalls,
|
||||||
|
parseTextKVToolCalls,
|
||||||
|
} = require('./parse_payload');
|
||||||
|
|
||||||
|
const TOOL_NAME_LOOSE_PATTERN = /[^a-z0-9]+/g;
|
||||||
|
|
||||||
function extractToolNames(tools) {
|
function extractToolNames(tools) {
|
||||||
if (!Array.isArray(tools) || tools.length === 0) {
|
if (!Array.isArray(tools) || tools.length === 0) {
|
||||||
@@ -29,245 +33,202 @@ function extractToolNames(tools) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function parseToolCalls(text, toolNames) {
|
function parseToolCalls(text, toolNames) {
|
||||||
|
return parseToolCallsDetailed(text, toolNames).calls;
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseToolCallsDetailed(text, toolNames) {
|
||||||
|
const result = emptyParseResult();
|
||||||
if (!toStringSafe(text)) {
|
if (!toStringSafe(text)) {
|
||||||
return [];
|
return result;
|
||||||
}
|
}
|
||||||
const sanitized = stripFencedCodeBlocks(text);
|
const sanitized = stripFencedCodeBlocks(text);
|
||||||
if (!toStringSafe(sanitized)) {
|
if (!toStringSafe(sanitized)) {
|
||||||
return [];
|
return result;
|
||||||
}
|
}
|
||||||
|
result.sawToolCallSyntax = looksLikeToolCallSyntax(sanitized);
|
||||||
|
|
||||||
const candidates = buildToolCallCandidates(sanitized);
|
const candidates = buildToolCallCandidates(sanitized);
|
||||||
let parsed = [];
|
let parsed = [];
|
||||||
for (const c of candidates) {
|
for (const c of candidates) {
|
||||||
parsed = parseToolCallsPayload(c);
|
parsed = parseToolCallsPayload(c);
|
||||||
|
if (parsed.length === 0) {
|
||||||
|
parsed = parseMarkupToolCalls(c);
|
||||||
|
}
|
||||||
|
if (parsed.length === 0) {
|
||||||
|
parsed = parseTextKVToolCalls(c);
|
||||||
|
}
|
||||||
if (parsed.length > 0) {
|
if (parsed.length > 0) {
|
||||||
|
result.sawToolCallSyntax = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (parsed.length === 0) {
|
if (parsed.length === 0) {
|
||||||
return [];
|
parsed = parseMarkupToolCalls(sanitized);
|
||||||
|
if (parsed.length === 0) {
|
||||||
|
parsed = parseTextKVToolCalls(sanitized);
|
||||||
|
if (parsed.length === 0) {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.sawToolCallSyntax = true;
|
||||||
}
|
}
|
||||||
return filterToolCalls(parsed, toolNames);
|
|
||||||
}
|
|
||||||
|
|
||||||
function stripFencedCodeBlocks(text) {
|
const filtered = filterToolCallsDetailed(parsed, toolNames);
|
||||||
const t = typeof text === 'string' ? text : '';
|
result.calls = filtered.calls;
|
||||||
if (!t) {
|
result.rejectedToolNames = filtered.rejectedToolNames;
|
||||||
return '';
|
result.rejectedByPolicy = filtered.rejectedToolNames.length > 0 && filtered.calls.length === 0;
|
||||||
}
|
return result;
|
||||||
return t.replace(/```[\s\S]*?```/g, ' ');
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function parseStandaloneToolCalls(text, toolNames) {
|
function parseStandaloneToolCalls(text, toolNames) {
|
||||||
|
return parseStandaloneToolCallsDetailed(text, toolNames).calls;
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseStandaloneToolCallsDetailed(text, toolNames) {
|
||||||
|
const result = emptyParseResult();
|
||||||
const trimmed = toStringSafe(text);
|
const trimmed = toStringSafe(text);
|
||||||
if (!trimmed) {
|
if (!trimmed) {
|
||||||
return [];
|
return result;
|
||||||
}
|
}
|
||||||
if ((trimmed.startsWith('```') && trimmed.endsWith('```')) || trimmed.includes('```')) {
|
if (trimmed.includes('```')) {
|
||||||
return [];
|
return result;
|
||||||
}
|
}
|
||||||
if (looksLikeToolExampleContext(trimmed)) {
|
if (looksLikeToolExampleContext(trimmed)) {
|
||||||
return [];
|
return result;
|
||||||
}
|
}
|
||||||
const candidates = [trimmed];
|
result.sawToolCallSyntax = looksLikeToolCallSyntax(trimmed);
|
||||||
if (trimmed.startsWith('```') && trimmed.endsWith('```')) {
|
let parsed = parseToolCallsPayload(trimmed);
|
||||||
const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
|
if (parsed.length === 0) {
|
||||||
if (m && m[1]) {
|
parsed = parseMarkupToolCalls(trimmed);
|
||||||
candidates.push(toStringSafe(m[1]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for (const candidate of candidates) {
|
if (parsed.length === 0) {
|
||||||
const c = toStringSafe(candidate);
|
parsed = parseTextKVToolCalls(trimmed);
|
||||||
if (!c) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (!c.startsWith('{') && !c.startsWith('[')) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const parsed = parseToolCallsPayload(c);
|
|
||||||
if (parsed.length > 0) {
|
|
||||||
return filterToolCalls(parsed, toolNames);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return [];
|
if (parsed.length === 0) {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.sawToolCallSyntax = true;
|
||||||
|
const filtered = filterToolCallsDetailed(parsed, toolNames);
|
||||||
|
result.calls = filtered.calls;
|
||||||
|
result.rejectedToolNames = filtered.rejectedToolNames;
|
||||||
|
result.rejectedByPolicy = filtered.rejectedToolNames.length > 0 && filtered.calls.length === 0;
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
function buildToolCallCandidates(text) {
|
function emptyParseResult() {
|
||||||
const trimmed = toStringSafe(text);
|
|
||||||
const candidates = [trimmed];
|
|
||||||
const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
|
|
||||||
for (const block of fenced) {
|
|
||||||
const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
|
|
||||||
if (m && m[1]) {
|
|
||||||
candidates.push(toStringSafe(m[1]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (const candidate of extractToolCallObjects(trimmed)) {
|
|
||||||
candidates.push(toStringSafe(candidate));
|
|
||||||
}
|
|
||||||
const first = trimmed.indexOf('{');
|
|
||||||
const last = trimmed.lastIndexOf('}');
|
|
||||||
if (first >= 0 && last > first) {
|
|
||||||
candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
|
|
||||||
}
|
|
||||||
const m = trimmed.match(TOOL_CALL_PATTERN);
|
|
||||||
if (m && m[1]) {
|
|
||||||
candidates.push(`{"tool_calls":[${m[1]}]}`);
|
|
||||||
}
|
|
||||||
return [...new Set(candidates.filter(Boolean))];
|
|
||||||
}
|
|
||||||
|
|
||||||
function extractToolCallObjects(text) {
|
|
||||||
const raw = toStringSafe(text);
|
|
||||||
if (!raw) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const lower = raw.toLowerCase();
|
|
||||||
const out = [];
|
|
||||||
let offset = 0;
|
|
||||||
// eslint-disable-next-line no-constant-condition
|
|
||||||
while (true) {
|
|
||||||
let idx = lower.indexOf('tool_calls', offset);
|
|
||||||
if (idx < 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let start = raw.slice(0, idx).lastIndexOf('{');
|
|
||||||
while (start >= 0) {
|
|
||||||
const obj = extractJSONObjectFrom(raw, start);
|
|
||||||
if (obj.ok) {
|
|
||||||
out.push(raw.slice(start, obj.end).trim());
|
|
||||||
offset = obj.end;
|
|
||||||
idx = -1;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
start = raw.slice(0, start).lastIndexOf('{');
|
|
||||||
}
|
|
||||||
if (idx >= 0) {
|
|
||||||
offset = idx + 'tool_calls'.length;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
function parseToolCallsPayload(payload) {
|
|
||||||
let decoded;
|
|
||||||
try {
|
|
||||||
decoded = JSON.parse(payload);
|
|
||||||
} catch (_err) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
if (Array.isArray(decoded)) {
|
|
||||||
return parseToolCallList(decoded);
|
|
||||||
}
|
|
||||||
if (!decoded || typeof decoded !== 'object') {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
if (decoded.tool_calls) {
|
|
||||||
return parseToolCallList(decoded.tool_calls);
|
|
||||||
}
|
|
||||||
const one = parseToolCallItem(decoded);
|
|
||||||
return one ? [one] : [];
|
|
||||||
}
|
|
||||||
|
|
||||||
function parseToolCallList(v) {
|
|
||||||
if (!Array.isArray(v)) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const out = [];
|
|
||||||
for (const item of v) {
|
|
||||||
if (!item || typeof item !== 'object') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const one = parseToolCallItem(item);
|
|
||||||
if (one) {
|
|
||||||
out.push(one);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
function parseToolCallItem(m) {
|
|
||||||
let name = toStringSafe(m.name);
|
|
||||||
let inputRaw = m.input;
|
|
||||||
let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
|
|
||||||
const fn = m.function && typeof m.function === 'object' ? m.function : null;
|
|
||||||
if (fn) {
|
|
||||||
if (!name) {
|
|
||||||
name = toStringSafe(fn.name);
|
|
||||||
}
|
|
||||||
if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
|
|
||||||
inputRaw = fn.arguments;
|
|
||||||
hasInput = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!hasInput) {
|
|
||||||
for (const k of ['arguments', 'args', 'parameters', 'params']) {
|
|
||||||
if (Object.prototype.hasOwnProperty.call(m, k)) {
|
|
||||||
inputRaw = m[k];
|
|
||||||
hasInput = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!name) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return {
|
return {
|
||||||
name,
|
calls: [],
|
||||||
input: parseToolCallInput(inputRaw),
|
sawToolCallSyntax: false,
|
||||||
|
rejectedByPolicy: false,
|
||||||
|
rejectedToolNames: [],
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
function parseToolCallInput(v) {
|
function filterToolCallsDetailed(parsed, toolNames) {
|
||||||
if (v == null) {
|
const sourceNames = Array.isArray(toolNames) ? toolNames : [];
|
||||||
return {};
|
const allowed = new Set();
|
||||||
}
|
const allowedCanonical = new Map();
|
||||||
if (typeof v === 'string') {
|
for (const item of sourceNames) {
|
||||||
const raw = toStringSafe(v);
|
const name = toStringSafe(item);
|
||||||
if (!raw) {
|
if (!name) {
|
||||||
return {};
|
continue;
|
||||||
}
|
}
|
||||||
try {
|
allowed.add(name);
|
||||||
const parsed = JSON.parse(raw);
|
const lower = name.toLowerCase();
|
||||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
if (!allowedCanonical.has(lower)) {
|
||||||
return parsed;
|
allowedCanonical.set(lower, name);
|
||||||
}
|
|
||||||
return { _raw: raw };
|
|
||||||
} catch (_err) {
|
|
||||||
return { _raw: raw };
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (typeof v === 'object' && !Array.isArray(v)) {
|
|
||||||
return v;
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
const parsed = JSON.parse(JSON.stringify(v));
|
|
||||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
|
||||||
return parsed;
|
|
||||||
}
|
|
||||||
} catch (_err) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
function filterToolCalls(parsed, toolNames) {
|
if (allowed.size === 0) {
|
||||||
const allowed = new Set((toolNames || []).filter(Boolean));
|
const rejected = [];
|
||||||
const out = [];
|
const seen = new Set();
|
||||||
|
for (const tc of parsed) {
|
||||||
|
if (!tc || !tc.name) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (seen.has(tc.name)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
seen.add(tc.name);
|
||||||
|
rejected.push(tc.name);
|
||||||
|
}
|
||||||
|
return { calls: [], rejectedToolNames: rejected };
|
||||||
|
}
|
||||||
|
|
||||||
|
const calls = [];
|
||||||
|
const rejected = [];
|
||||||
|
const seenRejected = new Set();
|
||||||
for (const tc of parsed) {
|
for (const tc of parsed) {
|
||||||
if (!tc || !tc.name) {
|
if (!tc || !tc.name) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (allowed.size > 0 && !allowed.has(tc.name)) {
|
let matchedName = '';
|
||||||
|
if (allowed.has(tc.name)) {
|
||||||
|
matchedName = tc.name;
|
||||||
|
} else {
|
||||||
|
matchedName = resolveAllowedToolName(tc.name, allowed, allowedCanonical);
|
||||||
|
}
|
||||||
|
if (!matchedName) {
|
||||||
|
if (!seenRejected.has(tc.name)) {
|
||||||
|
seenRejected.add(tc.name);
|
||||||
|
rejected.push(tc.name);
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
out.push({ name: tc.name, input: tc.input || {} });
|
calls.push({
|
||||||
|
name: matchedName,
|
||||||
|
input: tc.input && typeof tc.input === 'object' && !Array.isArray(tc.input) ? tc.input : {},
|
||||||
|
});
|
||||||
}
|
}
|
||||||
return out;
|
return { calls, rejectedToolNames: rejected };
|
||||||
|
}
|
||||||
|
|
||||||
|
function resolveAllowedToolName(name, allowed, allowedCanonical) {
|
||||||
|
const normalizedName = toStringSafe(name).trim();
|
||||||
|
if (!normalizedName) {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
if (allowed.has(normalizedName)) {
|
||||||
|
return normalizedName;
|
||||||
|
}
|
||||||
|
const lower = normalizedName.toLowerCase();
|
||||||
|
if (allowedCanonical.has(lower)) {
|
||||||
|
return allowedCanonical.get(lower);
|
||||||
|
}
|
||||||
|
const idx = lower.lastIndexOf('.');
|
||||||
|
if (idx >= 0 && idx < lower.length - 1) {
|
||||||
|
const tail = lower.slice(idx + 1);
|
||||||
|
if (allowedCanonical.has(tail)) {
|
||||||
|
return allowedCanonical.get(tail);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const loose = lower.replace(TOOL_NAME_LOOSE_PATTERN, '');
|
||||||
|
if (!loose) {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
for (const [candidateLower, canonical] of allowedCanonical.entries()) {
|
||||||
|
if (candidateLower.replace(TOOL_NAME_LOOSE_PATTERN, '') === loose) {
|
||||||
|
return canonical;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
|
||||||
|
function looksLikeToolCallSyntax(text) {
|
||||||
|
const lower = toStringSafe(text).toLowerCase();
|
||||||
|
return lower.includes('tool_calls')
|
||||||
|
|| lower.includes('<tool_call')
|
||||||
|
|| lower.includes('<function_call')
|
||||||
|
|| lower.includes('<invoke')
|
||||||
|
|| lower.includes('function.name:');
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
extractToolNames,
|
extractToolNames,
|
||||||
parseToolCalls,
|
parseToolCalls,
|
||||||
|
parseToolCallsDetailed,
|
||||||
parseStandaloneToolCalls,
|
parseStandaloneToolCalls,
|
||||||
|
parseStandaloneToolCallsDetailed,
|
||||||
};
|
};
|
||||||
|
|||||||
363
internal/js/helpers/stream-tool-sieve/parse_payload.js
Normal file
363
internal/js/helpers/stream-tool-sieve/parse_payload.js
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
'use strict';
|
||||||
|
|
||||||
|
const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
|
||||||
|
const TOOL_CALL_MARKUP_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?(tool_call|function_call|invoke)\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
|
||||||
|
const TOOL_CALL_MARKUP_SELFCLOSE_PATTERN = /<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)\/>/gi;
|
||||||
|
const TOOL_CALL_MARKUP_KV_PATTERN = /<(?:[a-z0-9_:-]+:)?([a-z0-9_.-]+)\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
|
||||||
|
const TOOL_CALL_MARKUP_ATTR_PATTERN = /(name|function|tool)\s*=\s*"([^"]+)"/i;
|
||||||
|
const TOOL_CALL_MARKUP_NAME_PATTERNS = [
|
||||||
|
/<(?:[a-z0-9_:-]+:)?name\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?name>/i,
|
||||||
|
/<(?:[a-z0-9_:-]+:)?function\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?function>/i,
|
||||||
|
];
|
||||||
|
const TOOL_CALL_MARKUP_ARGS_PATTERNS = [
|
||||||
|
/<(?:[a-z0-9_:-]+:)?input\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?input>/i,
|
||||||
|
/<(?:[a-z0-9_:-]+:)?arguments\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?arguments>/i,
|
||||||
|
/<(?:[a-z0-9_:-]+:)?argument\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?argument>/i,
|
||||||
|
/<(?:[a-z0-9_:-]+:)?parameters\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameters>/i,
|
||||||
|
/<(?:[a-z0-9_:-]+:)?parameter\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameter>/i,
|
||||||
|
/<(?:[a-z0-9_:-]+:)?args\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?args>/i,
|
||||||
|
/<(?:[a-z0-9_:-]+:)?params\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?params>/i,
|
||||||
|
];
|
||||||
|
const TEXT_KV_NAME_PATTERN = /function\.name:\s*([a-zA-Z0-9_.-]+)/gi;
|
||||||
|
|
||||||
|
const {
|
||||||
|
toStringSafe,
|
||||||
|
} = require('./state');
|
||||||
|
const {
|
||||||
|
extractJSONObjectFrom,
|
||||||
|
} = require('./jsonscan');
|
||||||
|
|
||||||
|
function stripFencedCodeBlocks(text) {
|
||||||
|
const t = typeof text === 'string' ? text : '';
|
||||||
|
if (!t) {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
return t.replace(/```[\s\S]*?```/g, ' ');
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildToolCallCandidates(text) {
|
||||||
|
const trimmed = toStringSafe(text);
|
||||||
|
const candidates = [trimmed];
|
||||||
|
|
||||||
|
const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
|
||||||
|
for (const block of fenced) {
|
||||||
|
const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
|
||||||
|
if (m && m[1]) {
|
||||||
|
candidates.push(toStringSafe(m[1]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const candidate of extractToolCallObjects(trimmed)) {
|
||||||
|
candidates.push(toStringSafe(candidate));
|
||||||
|
}
|
||||||
|
|
||||||
|
const first = trimmed.indexOf('{');
|
||||||
|
const last = trimmed.lastIndexOf('}');
|
||||||
|
if (first >= 0 && last > first) {
|
||||||
|
candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
const m = trimmed.match(TOOL_CALL_PATTERN);
|
||||||
|
if (m && m[1]) {
|
||||||
|
candidates.push(`{"tool_calls":[${m[1]}]}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return [...new Set(candidates.filter(Boolean))];
|
||||||
|
}
|
||||||
|
|
||||||
|
function extractToolCallObjects(text) {
|
||||||
|
const raw = toStringSafe(text);
|
||||||
|
if (!raw) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
const lower = raw.toLowerCase();
|
||||||
|
const out = [];
|
||||||
|
let offset = 0;
|
||||||
|
|
||||||
|
// eslint-disable-next-line no-constant-condition
|
||||||
|
while (true) {
|
||||||
|
let idx = lower.indexOf('tool_calls', offset);
|
||||||
|
if (idx < 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let start = raw.slice(0, idx).lastIndexOf('{');
|
||||||
|
while (start >= 0) {
|
||||||
|
const obj = extractJSONObjectFrom(raw, start);
|
||||||
|
if (obj.ok) {
|
||||||
|
out.push(raw.slice(start, obj.end).trim());
|
||||||
|
offset = obj.end;
|
||||||
|
idx = -1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
start = raw.slice(0, start).lastIndexOf('{');
|
||||||
|
}
|
||||||
|
if (idx >= 0) {
|
||||||
|
offset = idx + 'tool_calls'.length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseToolCallsPayload(payload) {
|
||||||
|
let decoded;
|
||||||
|
try {
|
||||||
|
decoded = JSON.parse(payload);
|
||||||
|
} catch (_err) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Array.isArray(decoded)) {
|
||||||
|
return parseToolCallList(decoded);
|
||||||
|
}
|
||||||
|
if (!decoded || typeof decoded !== 'object') {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
if (decoded.tool_calls) {
|
||||||
|
return parseToolCallList(decoded.tool_calls);
|
||||||
|
}
|
||||||
|
|
||||||
|
const one = parseToolCallItem(decoded);
|
||||||
|
return one ? [one] : [];
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseMarkupToolCalls(text) {
|
||||||
|
const raw = toStringSafe(text).trim();
|
||||||
|
if (!raw) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
const out = [];
|
||||||
|
for (const m of raw.matchAll(TOOL_CALL_MARKUP_BLOCK_PATTERN)) {
|
||||||
|
const parsed = parseMarkupSingleToolCall(toStringSafe(m[2]).trim(), toStringSafe(m[3]).trim());
|
||||||
|
if (parsed) {
|
||||||
|
out.push(parsed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const m of raw.matchAll(TOOL_CALL_MARKUP_SELFCLOSE_PATTERN)) {
|
||||||
|
const parsed = parseMarkupSingleToolCall(toStringSafe(m[1]).trim(), '');
|
||||||
|
if (parsed) {
|
||||||
|
out.push(parsed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseTextKVToolCalls(text) {
|
||||||
|
const raw = toStringSafe(text);
|
||||||
|
if (!raw) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
const out = [];
|
||||||
|
const matches = [...raw.matchAll(TEXT_KV_NAME_PATTERN)];
|
||||||
|
if (matches.length === 0) {
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
for (let i = 0; i < matches.length; i += 1) {
|
||||||
|
const match = matches[i];
|
||||||
|
const name = toStringSafe(match[1]).trim();
|
||||||
|
if (!name) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const nameEnd = match.index + toStringSafe(match[0]).length;
|
||||||
|
const searchEnd = i + 1 < matches.length ? matches[i + 1].index : raw.length;
|
||||||
|
const searchArea = raw.slice(nameEnd, searchEnd);
|
||||||
|
const argIdx = searchArea.indexOf('function.arguments:');
|
||||||
|
if (argIdx < 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const argStart = nameEnd + argIdx + 'function.arguments:'.length;
|
||||||
|
const bracePos = raw.slice(argStart, searchEnd).indexOf('{');
|
||||||
|
if (bracePos < 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const objStart = argStart + bracePos;
|
||||||
|
const obj = extractJSONObjectFrom(raw, objStart);
|
||||||
|
if (!obj.ok) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
out.push({
|
||||||
|
name,
|
||||||
|
input: parseToolCallInput(raw.slice(objStart, obj.end)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseMarkupSingleToolCall(attrs, inner) {
|
||||||
|
const embedded = parseToolCallsPayload(inner);
|
||||||
|
if (embedded.length > 0) {
|
||||||
|
return embedded[0];
|
||||||
|
}
|
||||||
|
let name = '';
|
||||||
|
const attrMatch = attrs.match(TOOL_CALL_MARKUP_ATTR_PATTERN);
|
||||||
|
if (attrMatch && attrMatch[2]) {
|
||||||
|
name = toStringSafe(attrMatch[2]).trim();
|
||||||
|
}
|
||||||
|
if (!name) {
|
||||||
|
name = stripTagText(findMarkupTagValue(inner, TOOL_CALL_MARKUP_NAME_PATTERNS));
|
||||||
|
}
|
||||||
|
if (!name) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
let input = {};
|
||||||
|
const argsRaw = findMarkupTagValue(inner, TOOL_CALL_MARKUP_ARGS_PATTERNS);
|
||||||
|
if (argsRaw) {
|
||||||
|
input = parseMarkupInput(argsRaw);
|
||||||
|
} else {
|
||||||
|
const kv = parseMarkupKVObject(inner);
|
||||||
|
if (Object.keys(kv).length > 0) {
|
||||||
|
input = kv;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return { name, input };
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseMarkupInput(raw) {
|
||||||
|
const s = toStringSafe(raw).trim();
|
||||||
|
if (!s) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
const parsed = parseToolCallInput(s);
|
||||||
|
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed) && Object.keys(parsed).length > 0) {
|
||||||
|
return parsed;
|
||||||
|
}
|
||||||
|
const kv = parseMarkupKVObject(s);
|
||||||
|
if (Object.keys(kv).length > 0) {
|
||||||
|
return kv;
|
||||||
|
}
|
||||||
|
return { _raw: stripTagText(s) };
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseMarkupKVObject(text) {
|
||||||
|
const raw = toStringSafe(text).trim();
|
||||||
|
if (!raw) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
const out = {};
|
||||||
|
for (const m of raw.matchAll(TOOL_CALL_MARKUP_KV_PATTERN)) {
|
||||||
|
const key = toStringSafe(m[1]).trim();
|
||||||
|
if (!key) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const valueRaw = stripTagText(m[2]);
|
||||||
|
if (!valueRaw) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
out[key] = JSON.parse(valueRaw);
|
||||||
|
} catch (_err) {
|
||||||
|
out[key] = valueRaw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
function stripTagText(text) {
|
||||||
|
return toStringSafe(text).replace(/<[^>]+>/g, ' ').trim();
|
||||||
|
}
|
||||||
|
|
||||||
|
function findMarkupTagValue(text, patterns) {
|
||||||
|
const source = toStringSafe(text);
|
||||||
|
for (const p of patterns) {
|
||||||
|
const m = source.match(p);
|
||||||
|
if (m && m[1]) {
|
||||||
|
return toStringSafe(m[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseToolCallList(v) {
|
||||||
|
if (!Array.isArray(v)) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
const out = [];
|
||||||
|
for (const item of v) {
|
||||||
|
if (!item || typeof item !== 'object') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const one = parseToolCallItem(item);
|
||||||
|
if (one) {
|
||||||
|
out.push(one);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseToolCallItem(m) {
|
||||||
|
let name = toStringSafe(m.name);
|
||||||
|
let inputRaw = m.input;
|
||||||
|
let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
|
||||||
|
const fn = m.function && typeof m.function === 'object' ? m.function : null;
|
||||||
|
|
||||||
|
if (fn) {
|
||||||
|
if (!name) {
|
||||||
|
name = toStringSafe(fn.name);
|
||||||
|
}
|
||||||
|
if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
|
||||||
|
inputRaw = fn.arguments;
|
||||||
|
hasInput = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hasInput) {
|
||||||
|
for (const k of ['arguments', 'args', 'parameters', 'params']) {
|
||||||
|
if (Object.prototype.hasOwnProperty.call(m, k)) {
|
||||||
|
inputRaw = m[k];
|
||||||
|
hasInput = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!name) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
name,
|
||||||
|
input: parseToolCallInput(inputRaw),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseToolCallInput(v) {
|
||||||
|
if (v == null) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
if (typeof v === 'string') {
|
||||||
|
const raw = toStringSafe(v);
|
||||||
|
if (!raw) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(raw);
|
||||||
|
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
||||||
|
return parsed;
|
||||||
|
}
|
||||||
|
return { _raw: raw };
|
||||||
|
} catch (_err) {
|
||||||
|
return { _raw: raw };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (typeof v === 'object' && !Array.isArray(v)) {
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(JSON.stringify(v));
|
||||||
|
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
||||||
|
return parsed;
|
||||||
|
}
|
||||||
|
} catch (_err) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
stripFencedCodeBlocks,
|
||||||
|
buildToolCallCandidates,
|
||||||
|
parseToolCallsPayload,
|
||||||
|
parseMarkupToolCalls,
|
||||||
|
parseTextKVToolCalls,
|
||||||
|
};
|
||||||
@@ -1,16 +1,12 @@
|
|||||||
'use strict';
|
'use strict';
|
||||||
|
|
||||||
const {
|
const {
|
||||||
TOOL_SIEVE_CAPTURE_LIMIT,
|
|
||||||
resetIncrementalToolState,
|
resetIncrementalToolState,
|
||||||
noteText,
|
noteText,
|
||||||
insideCodeFence,
|
insideCodeFence,
|
||||||
} = require('./state');
|
} = require('./state');
|
||||||
const {
|
const {
|
||||||
buildIncrementalToolDeltas,
|
parseStandaloneToolCallsDetailed,
|
||||||
} = require('./incremental');
|
|
||||||
const {
|
|
||||||
parseStandaloneToolCalls,
|
|
||||||
} = require('./parse');
|
} = require('./parse');
|
||||||
const {
|
const {
|
||||||
extractJSONObjectFrom,
|
extractJSONObjectFrom,
|
||||||
@@ -24,64 +20,64 @@ function processToolSieveChunk(state, chunk, toolNames) {
|
|||||||
state.pending += chunk;
|
state.pending += chunk;
|
||||||
}
|
}
|
||||||
const events = [];
|
const events = [];
|
||||||
|
|
||||||
// eslint-disable-next-line no-constant-condition
|
// eslint-disable-next-line no-constant-condition
|
||||||
while (true) {
|
while (true) {
|
||||||
|
if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
|
||||||
|
events.push({ type: 'tool_calls', calls: state.pendingToolCalls });
|
||||||
|
state.pendingToolRaw = '';
|
||||||
|
state.pendingToolCalls = [];
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (state.capturing) {
|
if (state.capturing) {
|
||||||
if (state.pending) {
|
if (state.pending) {
|
||||||
state.capture += state.pending;
|
state.capture += state.pending;
|
||||||
state.pending = '';
|
state.pending = '';
|
||||||
}
|
}
|
||||||
const deltas = buildIncrementalToolDeltas(state);
|
|
||||||
if (deltas.length > 0) {
|
|
||||||
events.push({ type: 'tool_call_deltas', deltas });
|
|
||||||
}
|
|
||||||
const consumed = consumeToolCapture(state, toolNames);
|
const consumed = consumeToolCapture(state, toolNames);
|
||||||
if (!consumed.ready) {
|
if (!consumed.ready) {
|
||||||
if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) {
|
|
||||||
noteText(state, state.capture);
|
|
||||||
events.push({ type: 'text', text: state.capture });
|
|
||||||
state.capture = '';
|
|
||||||
state.capturing = false;
|
|
||||||
resetIncrementalToolState(state);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
const captured = state.capture;
|
||||||
state.capture = '';
|
state.capture = '';
|
||||||
state.capturing = false;
|
state.capturing = false;
|
||||||
resetIncrementalToolState(state);
|
resetIncrementalToolState(state);
|
||||||
|
|
||||||
|
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
|
||||||
|
state.pendingToolRaw = captured;
|
||||||
|
state.pendingToolCalls = consumed.calls;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (consumed.prefix) {
|
if (consumed.prefix) {
|
||||||
noteText(state, consumed.prefix);
|
noteText(state, consumed.prefix);
|
||||||
events.push({ type: 'text', text: consumed.prefix });
|
events.push({ type: 'text', text: consumed.prefix });
|
||||||
}
|
}
|
||||||
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
|
|
||||||
events.push({ type: 'tool_calls', calls: consumed.calls });
|
|
||||||
}
|
|
||||||
if (consumed.suffix) {
|
if (consumed.suffix) {
|
||||||
state.pending += consumed.suffix;
|
state.pending += consumed.suffix;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!state.pending) {
|
const pending = state.pending || '';
|
||||||
|
if (!pending) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
const start = findToolSegmentStart(state.pending);
|
const start = findToolSegmentStart(pending);
|
||||||
if (start >= 0) {
|
if (start >= 0) {
|
||||||
const prefix = state.pending.slice(0, start);
|
const prefix = pending.slice(0, start);
|
||||||
if (prefix) {
|
if (prefix) {
|
||||||
noteText(state, prefix);
|
noteText(state, prefix);
|
||||||
events.push({ type: 'text', text: prefix });
|
events.push({ type: 'text', text: prefix });
|
||||||
}
|
}
|
||||||
state.capture = state.pending.slice(start);
|
|
||||||
state.pending = '';
|
state.pending = '';
|
||||||
|
state.capture += pending.slice(start);
|
||||||
state.capturing = true;
|
state.capturing = true;
|
||||||
resetIncrementalToolState(state);
|
resetIncrementalToolState(state);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const [safe, hold] = splitSafeContentForToolDetection(state.pending);
|
const [safe, hold] = splitSafeContentForToolDetection(pending);
|
||||||
if (!safe) {
|
if (!safe) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -97,6 +93,13 @@ function flushToolSieve(state, toolNames) {
|
|||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
const events = processToolSieveChunk(state, '', toolNames);
|
const events = processToolSieveChunk(state, '', toolNames);
|
||||||
|
|
||||||
|
if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
|
||||||
|
events.push({ type: 'tool_calls', calls: state.pendingToolCalls });
|
||||||
|
state.pendingToolRaw = '';
|
||||||
|
state.pendingToolCalls = [];
|
||||||
|
}
|
||||||
|
|
||||||
if (state.capturing) {
|
if (state.capturing) {
|
||||||
const consumed = consumeToolCapture(state, toolNames);
|
const consumed = consumeToolCapture(state, toolNames);
|
||||||
if (consumed.ready) {
|
if (consumed.ready) {
|
||||||
@@ -119,11 +122,13 @@ function flushToolSieve(state, toolNames) {
|
|||||||
state.capturing = false;
|
state.capturing = false;
|
||||||
resetIncrementalToolState(state);
|
resetIncrementalToolState(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (state.pending) {
|
if (state.pending) {
|
||||||
noteText(state, state.pending);
|
noteText(state, state.pending);
|
||||||
events.push({ type: 'text', text: state.pending });
|
events.push({ type: 'text', text: state.pending });
|
||||||
state.pending = '';
|
state.pending = '';
|
||||||
}
|
}
|
||||||
|
|
||||||
return events;
|
return events;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,11 +168,10 @@ function findToolSegmentStart(s) {
|
|||||||
let offset = 0;
|
let offset = 0;
|
||||||
// eslint-disable-next-line no-constant-condition
|
// eslint-disable-next-line no-constant-condition
|
||||||
while (true) {
|
while (true) {
|
||||||
const keyRel = lower.indexOf('tool_calls', offset);
|
const keyIdx = lower.indexOf('tool_calls', offset);
|
||||||
if (keyRel < 0) {
|
if (keyIdx < 0) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
const keyIdx = keyRel;
|
|
||||||
const start = s.slice(0, keyIdx).lastIndexOf('{');
|
const start = s.slice(0, keyIdx).lastIndexOf('{');
|
||||||
const candidateStart = start >= 0 ? start : keyIdx;
|
const candidateStart = start >= 0 ? start : keyIdx;
|
||||||
if (!insideCodeFence(s.slice(0, candidateStart))) {
|
if (!insideCodeFence(s.slice(0, candidateStart))) {
|
||||||
@@ -178,7 +182,7 @@ function findToolSegmentStart(s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function consumeToolCapture(state, toolNames) {
|
function consumeToolCapture(state, toolNames) {
|
||||||
const captured = state.capture;
|
const captured = state.capture || '';
|
||||||
if (!captured) {
|
if (!captured) {
|
||||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||||
}
|
}
|
||||||
@@ -195,8 +199,10 @@ function consumeToolCapture(state, toolNames) {
|
|||||||
if (!obj.ok) {
|
if (!obj.ok) {
|
||||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||||
}
|
}
|
||||||
|
|
||||||
const prefixPart = captured.slice(0, start);
|
const prefixPart = captured.slice(0, start);
|
||||||
const suffixPart = captured.slice(obj.end);
|
const suffixPart = captured.slice(obj.end);
|
||||||
|
|
||||||
if (insideCodeFence((state.recentTextTail || '') + prefixPart)) {
|
if (insideCodeFence((state.recentTextTail || '') + prefixPart)) {
|
||||||
return {
|
return {
|
||||||
ready: true,
|
ready: true,
|
||||||
@@ -205,18 +211,19 @@ function consumeToolCapture(state, toolNames) {
|
|||||||
suffix: '',
|
suffix: '',
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
const rawParsed = parseStandaloneToolCalls(captured.slice(start, obj.end), []);
|
|
||||||
const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames);
|
if ((state.recentTextTail || '').trim() !== '' || prefixPart.trim() !== '' || suffixPart.trim() !== '') {
|
||||||
if (parsed.length === 0) {
|
return {
|
||||||
if (rawParsed.length > 0 && Array.isArray(toolNames) && toolNames.length > 0) {
|
ready: true,
|
||||||
return {
|
prefix: captured,
|
||||||
ready: true,
|
calls: [],
|
||||||
prefix: prefixPart,
|
suffix: '',
|
||||||
calls: [],
|
};
|
||||||
suffix: suffixPart,
|
}
|
||||||
};
|
|
||||||
}
|
const parsed = parseStandaloneToolCallsDetailed(captured.slice(start, obj.end), toolNames);
|
||||||
if (state.toolNameSent) {
|
if (!Array.isArray(parsed.calls) || parsed.calls.length === 0) {
|
||||||
|
if (parsed.sawToolCallSyntax && parsed.rejectedByPolicy) {
|
||||||
return {
|
return {
|
||||||
ready: true,
|
ready: true,
|
||||||
prefix: prefixPart,
|
prefix: prefixPart,
|
||||||
@@ -231,26 +238,11 @@ function consumeToolCapture(state, toolNames) {
|
|||||||
suffix: '',
|
suffix: '',
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
if (state.toolNameSent) {
|
|
||||||
if (parsed.length > 1) {
|
|
||||||
return {
|
|
||||||
ready: true,
|
|
||||||
prefix: prefixPart,
|
|
||||||
calls: parsed.slice(1),
|
|
||||||
suffix: suffixPart,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
ready: true,
|
|
||||||
prefix: prefixPart,
|
|
||||||
calls: [],
|
|
||||||
suffix: suffixPart,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
return {
|
return {
|
||||||
ready: true,
|
ready: true,
|
||||||
prefix: prefixPart,
|
prefix: prefixPart,
|
||||||
calls: parsed,
|
calls: parsed.calls,
|
||||||
suffix: suffixPart,
|
suffix: suffixPart,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
'use strict';
|
'use strict';
|
||||||
|
|
||||||
const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024;
|
|
||||||
const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256;
|
const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256;
|
||||||
|
|
||||||
function createToolSieveState() {
|
function createToolSieveState() {
|
||||||
@@ -9,6 +8,9 @@ function createToolSieveState() {
|
|||||||
capture: '',
|
capture: '',
|
||||||
capturing: false,
|
capturing: false,
|
||||||
recentTextTail: '',
|
recentTextTail: '',
|
||||||
|
pendingToolRaw: '',
|
||||||
|
pendingToolCalls: [],
|
||||||
|
disableDeltas: false,
|
||||||
toolNameSent: false,
|
toolNameSent: false,
|
||||||
toolName: '',
|
toolName: '',
|
||||||
toolArgsStart: -1,
|
toolArgsStart: -1,
|
||||||
@@ -19,6 +21,7 @@ function createToolSieveState() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function resetIncrementalToolState(state) {
|
function resetIncrementalToolState(state) {
|
||||||
|
state.disableDeltas = false;
|
||||||
state.toolNameSent = false;
|
state.toolNameSent = false;
|
||||||
state.toolName = '';
|
state.toolName = '';
|
||||||
state.toolArgsStart = -1;
|
state.toolArgsStart = -1;
|
||||||
@@ -78,7 +81,6 @@ function toStringSafe(v) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
TOOL_SIEVE_CAPTURE_LIMIT,
|
|
||||||
TOOL_SIEVE_CONTEXT_TAIL_LIMIT,
|
TOOL_SIEVE_CONTEXT_TAIL_LIMIT,
|
||||||
createToolSieveState,
|
createToolSieveState,
|
||||||
resetIncrementalToolState,
|
resetIncrementalToolState,
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ func MessagesPrepare(messages []map[string]any) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NormalizeContent(v any) string {
|
func NormalizeContent(v any) string {
|
||||||
|
if v == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
switch x := v.(type) {
|
switch x := v.(type) {
|
||||||
case string:
|
case string:
|
||||||
return x
|
return x
|
||||||
@@ -64,11 +67,11 @@ func NormalizeContent(v any) string {
|
|||||||
typeStr, _ := m["type"].(string)
|
typeStr, _ := m["type"].(string)
|
||||||
typeStr = strings.ToLower(strings.TrimSpace(typeStr))
|
typeStr = strings.ToLower(strings.TrimSpace(typeStr))
|
||||||
if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" {
|
if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" {
|
||||||
if txt, ok := m["text"].(string); ok {
|
if txt, ok := m["text"].(string); ok && txt != "" {
|
||||||
parts = append(parts, txt)
|
parts = append(parts, txt)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if txt, ok := m["content"].(string); ok {
|
if txt, ok := m["content"].(string); ok && txt != "" {
|
||||||
parts = append(parts, txt)
|
parts = append(parts, txt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
32
internal/prompt/messages_test.go
Normal file
32
internal/prompt/messages_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package prompt
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestNormalizeContentNilReturnsEmpty(t *testing.T) {
|
||||||
|
if got := NormalizeContent(nil); got != "" {
|
||||||
|
t.Fatalf("expected empty string for nil content, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessagesPrepareNilContentNoNullLiteral(t *testing.T) {
|
||||||
|
messages := []map[string]any{
|
||||||
|
{"role": "assistant", "content": nil},
|
||||||
|
{"role": "user", "content": "ok"},
|
||||||
|
}
|
||||||
|
got := MessagesPrepare(messages)
|
||||||
|
if got == "" {
|
||||||
|
t.Fatalf("expected non-empty output")
|
||||||
|
}
|
||||||
|
if got == "null" {
|
||||||
|
t.Fatalf("expected no null literal output, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
|
||||||
|
got := NormalizeContent([]any{
|
||||||
|
map[string]any{"type": "text", "text": "", "content": "from-content"},
|
||||||
|
})
|
||||||
|
if got != "from-content" {
|
||||||
|
t.Fatalf("expected fallback to content when text is empty, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -57,16 +57,20 @@ func NewApp() *App {
|
|||||||
r.Use(cors)
|
r.Use(cors)
|
||||||
r.Use(timeout(0))
|
r.Use(timeout(0))
|
||||||
|
|
||||||
r.Get("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
healthzHandler := func(w http.ResponseWriter, _ *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
_, _ = w.Write([]byte(`{"status":"ok"}`))
|
_, _ = w.Write([]byte(`{"status":"ok"}`))
|
||||||
})
|
}
|
||||||
r.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) {
|
readyzHandler := func(w http.ResponseWriter, _ *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
_, _ = w.Write([]byte(`{"status":"ready"}`))
|
_, _ = w.Write([]byte(`{"status":"ready"}`))
|
||||||
})
|
}
|
||||||
|
r.Get("/healthz", healthzHandler)
|
||||||
|
r.Head("/healthz", healthzHandler)
|
||||||
|
r.Get("/readyz", readyzHandler)
|
||||||
|
r.Head("/readyz", readyzHandler)
|
||||||
openai.RegisterRoutes(r, openaiHandler)
|
openai.RegisterRoutes(r, openaiHandler)
|
||||||
claude.RegisterRoutes(r, claudeHandler)
|
claude.RegisterRoutes(r, claudeHandler)
|
||||||
gemini.RegisterRoutes(r, geminiHandler)
|
gemini.RegisterRoutes(r, geminiHandler)
|
||||||
|
|||||||
20
internal/server/router_health_test.go
Normal file
20
internal/server/router_health_test.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHealthEndpointsSupportHEAD(t *testing.T) {
|
||||||
|
app := NewApp()
|
||||||
|
|
||||||
|
for _, path := range []string{"/healthz", "/readyz"} {
|
||||||
|
req := httptest.NewRequest(http.MethodHead, path, nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
app.Router.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected %s HEAD status 200, got %d", path, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -17,6 +17,12 @@ func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error {
|
|||||||
var m map[string]any
|
var m map[string]any
|
||||||
_ = json.Unmarshal(resp.Body, &m)
|
_ = json.Unmarshal(resp.Body, &m)
|
||||||
cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body)))
|
cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||||
|
|
||||||
|
headResp, headErr := cc.request(ctx, requestSpec{Method: http.MethodHead, Path: "/healthz", Retryable: true})
|
||||||
|
if headErr != nil {
|
||||||
|
return headErr
|
||||||
|
}
|
||||||
|
cc.assert("head_status_200", headResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", headResp.StatusCode))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,6 +35,12 @@ func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error {
|
|||||||
var m map[string]any
|
var m map[string]any
|
||||||
_ = json.Unmarshal(resp.Body, &m)
|
_ = json.Unmarshal(resp.Body, &m)
|
||||||
cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body)))
|
cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||||
|
|
||||||
|
headResp, headErr := cc.request(ctx, requestSpec{Method: http.MethodHead, Path: "/readyz", Retryable: true})
|
||||||
|
if headErr != nil {
|
||||||
|
return headErr
|
||||||
|
}
|
||||||
|
cc.assert("head_status_200", headResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", headResp.StatusCode))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
161
internal/util/toolcalls_markup.go
Normal file
161
internal/util/toolcalls_markup.go
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var toolCallMarkupTagNames = []string{"tool_call", "function_call", "invoke"}
|
||||||
|
var toolCallMarkupTagPatternByName = map[string]*regexp.Regexp{
|
||||||
|
"tool_call": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?tool_call\b([^>]*)>(.*?)</(?:[a-z0-9_:-]+:)?tool_call>`),
|
||||||
|
"function_call": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function_call\b([^>]*)>(.*?)</(?:[a-z0-9_:-]+:)?function_call>`),
|
||||||
|
"invoke": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)>(.*?)</(?:[a-z0-9_:-]+:)?invoke>`),
|
||||||
|
}
|
||||||
|
var toolCallMarkupSelfClosingPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)/>`)
|
||||||
|
var toolCallMarkupKVPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?([a-z0-9_\-.]+)\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?([a-z0-9_\-.]+)>`)
|
||||||
|
var toolCallMarkupAttrPattern = regexp.MustCompile(`(?is)(name|function|tool)\s*=\s*"([^"]+)"`)
|
||||||
|
var anyTagPattern = regexp.MustCompile(`(?is)<[^>]+>`)
|
||||||
|
var toolCallMarkupNameTagNames = []string{"name", "function"}
|
||||||
|
var toolCallMarkupNamePatternByTag = map[string]*regexp.Regexp{
|
||||||
|
"name": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?name\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?name>`),
|
||||||
|
"function": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?function>`),
|
||||||
|
}
|
||||||
|
var toolCallMarkupArgsTagNames = []string{"input", "arguments", "argument", "parameters", "parameter", "args", "params"}
|
||||||
|
var toolCallMarkupArgsPatternByTag = map[string]*regexp.Regexp{
|
||||||
|
"input": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?input\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?input>`),
|
||||||
|
"arguments": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?arguments\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?arguments>`),
|
||||||
|
"argument": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?argument\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?argument>`),
|
||||||
|
"parameters": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?parameters\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?parameters>`),
|
||||||
|
"parameter": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?parameter\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?parameter>`),
|
||||||
|
"args": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?args\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?args>`),
|
||||||
|
"params": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?params\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?params>`),
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMarkupToolCalls(text string) []ParsedToolCall {
|
||||||
|
trimmed := strings.TrimSpace(text)
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]ParsedToolCall, 0)
|
||||||
|
for _, tagName := range toolCallMarkupTagNames {
|
||||||
|
pattern := toolCallMarkupTagPatternByName[tagName]
|
||||||
|
for _, m := range pattern.FindAllStringSubmatch(trimmed, -1) {
|
||||||
|
if len(m) < 3 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
attrs := strings.TrimSpace(m[1])
|
||||||
|
inner := strings.TrimSpace(m[2])
|
||||||
|
if parsed := parseMarkupSingleToolCall(attrs, inner); parsed.Name != "" {
|
||||||
|
out = append(out, parsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, m := range toolCallMarkupSelfClosingPattern.FindAllStringSubmatch(trimmed, -1) {
|
||||||
|
if len(m) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if parsed := parseMarkupSingleToolCall(strings.TrimSpace(m[1]), ""); parsed.Name != "" {
|
||||||
|
out = append(out, parsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall {
|
||||||
|
if parsed := parseToolCallsPayload(inner); len(parsed) > 0 {
|
||||||
|
return parsed[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
name := ""
|
||||||
|
if m := toolCallMarkupAttrPattern.FindStringSubmatch(attrs); len(m) >= 3 {
|
||||||
|
name = strings.TrimSpace(m[2])
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
name = findMarkupTagValue(inner, toolCallMarkupNameTagNames, toolCallMarkupNamePatternByTag)
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
return ParsedToolCall{}
|
||||||
|
}
|
||||||
|
|
||||||
|
input := map[string]any{}
|
||||||
|
if argsRaw := findMarkupTagValue(inner, toolCallMarkupArgsTagNames, toolCallMarkupArgsPatternByTag); argsRaw != "" {
|
||||||
|
input = parseMarkupInput(argsRaw)
|
||||||
|
} else if kv := parseMarkupKVObject(inner); len(kv) > 0 {
|
||||||
|
input = kv
|
||||||
|
}
|
||||||
|
return ParsedToolCall{Name: name, Input: input}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMarkupInput(raw string) map[string]any {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
if kv := parseMarkupKVObject(raw); len(kv) > 0 {
|
||||||
|
return kv
|
||||||
|
}
|
||||||
|
return map[string]any{"_raw": stripTagText(raw)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMarkupKVObject(text string) map[string]any {
|
||||||
|
matches := toolCallMarkupKVPattern.FindAllStringSubmatch(strings.TrimSpace(text), -1)
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := map[string]any{}
|
||||||
|
for _, m := range matches {
|
||||||
|
if len(m) < 4 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := strings.TrimSpace(m[1])
|
||||||
|
endKey := strings.TrimSpace(m[3])
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(key, endKey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
value := strings.TrimSpace(stripTagText(m[2]))
|
||||||
|
if value == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var jsonValue any
|
||||||
|
if json.Unmarshal([]byte(value), &jsonValue) == nil {
|
||||||
|
out[key] = jsonValue
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripTagText(text string) string {
|
||||||
|
return strings.TrimSpace(anyTagPattern.ReplaceAllString(text, ""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func findMarkupTagValue(text string, tagNames []string, patternByTag map[string]*regexp.Regexp) string {
|
||||||
|
for _, tag := range tagNames {
|
||||||
|
pattern := patternByTag[tag]
|
||||||
|
if pattern == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if m := pattern.FindStringSubmatch(text); len(m) >= 2 {
|
||||||
|
value := strings.TrimSpace(m[1])
|
||||||
|
if value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
33
internal/util/toolcalls_name_match.go
Normal file
33
internal/util/toolcalls_name_match.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`)
|
||||||
|
|
||||||
|
func resolveAllowedToolNameWithLooseMatch(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
|
||||||
|
if _, ok := allowed[name]; ok {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(strings.TrimSpace(name))
|
||||||
|
if canonical, ok := allowedCanonical[lower]; ok {
|
||||||
|
return canonical
|
||||||
|
}
|
||||||
|
if idx := strings.LastIndex(lower, "."); idx >= 0 && idx < len(lower)-1 {
|
||||||
|
if canonical, ok := allowedCanonical[lower[idx+1:]]; ok {
|
||||||
|
return canonical
|
||||||
|
}
|
||||||
|
}
|
||||||
|
loose := toolNameLoosePattern.ReplaceAllString(lower, "")
|
||||||
|
if loose == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for candidateLower, canonical := range allowedCanonical {
|
||||||
|
if toolNameLoosePattern.ReplaceAllString(candidateLower, "") == loose {
|
||||||
|
return canonical
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -30,19 +30,36 @@ func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallPa
|
|||||||
if strings.TrimSpace(text) == "" {
|
if strings.TrimSpace(text) == "" {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
result.SawToolCallSyntax = strings.Contains(strings.ToLower(text), "tool_calls")
|
result.SawToolCallSyntax = looksLikeToolCallSyntax(text)
|
||||||
|
|
||||||
candidates := buildToolCallCandidates(text)
|
candidates := buildToolCallCandidates(text)
|
||||||
var parsed []ParsedToolCall
|
var parsed []ParsedToolCall
|
||||||
for _, candidate := range candidates {
|
for _, candidate := range candidates {
|
||||||
if tc := parseToolCallsPayload(candidate); len(tc) > 0 {
|
tc := parseToolCallsPayload(candidate)
|
||||||
|
if len(tc) == 0 {
|
||||||
|
tc = parseXMLToolCalls(candidate)
|
||||||
|
}
|
||||||
|
if len(tc) == 0 {
|
||||||
|
tc = parseMarkupToolCalls(candidate)
|
||||||
|
}
|
||||||
|
if len(tc) == 0 {
|
||||||
|
tc = parseTextKVToolCalls(candidate)
|
||||||
|
}
|
||||||
|
if len(tc) > 0 {
|
||||||
parsed = tc
|
parsed = tc
|
||||||
result.SawToolCallSyntax = true
|
result.SawToolCallSyntax = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(parsed) == 0 {
|
if len(parsed) == 0 {
|
||||||
return result
|
parsed = parseXMLToolCalls(text)
|
||||||
|
if len(parsed) == 0 {
|
||||||
|
parsed = parseTextKVToolCalls(text)
|
||||||
|
if len(parsed) == 0 {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.SawToolCallSyntax = true
|
||||||
}
|
}
|
||||||
|
|
||||||
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
||||||
@@ -65,17 +82,24 @@ func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string)
|
|||||||
if looksLikeToolExampleContext(trimmed) {
|
if looksLikeToolExampleContext(trimmed) {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
result.SawToolCallSyntax = strings.Contains(strings.ToLower(trimmed), "tool_calls")
|
result.SawToolCallSyntax = looksLikeToolCallSyntax(trimmed)
|
||||||
candidates := []string{trimmed}
|
candidates := []string{trimmed}
|
||||||
for _, candidate := range candidates {
|
for _, candidate := range candidates {
|
||||||
candidate = strings.TrimSpace(candidate)
|
candidate = strings.TrimSpace(candidate)
|
||||||
if candidate == "" {
|
if candidate == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !strings.HasPrefix(candidate, "{") && !strings.HasPrefix(candidate, "[") {
|
parsed := parseToolCallsPayload(candidate)
|
||||||
continue
|
if len(parsed) == 0 {
|
||||||
|
parsed = parseXMLToolCalls(candidate)
|
||||||
}
|
}
|
||||||
if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 {
|
if len(parsed) == 0 {
|
||||||
|
parsed = parseMarkupToolCalls(candidate)
|
||||||
|
}
|
||||||
|
if len(parsed) == 0 {
|
||||||
|
parsed = parseTextKVToolCalls(candidate)
|
||||||
|
}
|
||||||
|
if len(parsed) > 0 {
|
||||||
result.SawToolCallSyntax = true
|
result.SawToolCallSyntax = true
|
||||||
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
||||||
result.Calls = calls
|
result.Calls = calls
|
||||||
@@ -103,32 +127,32 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
|
|||||||
}
|
}
|
||||||
if len(allowed) == 0 {
|
if len(allowed) == 0 {
|
||||||
rejectedSet := map[string]struct{}{}
|
rejectedSet := map[string]struct{}{}
|
||||||
|
rejected := make([]string, 0, len(parsed))
|
||||||
for _, tc := range parsed {
|
for _, tc := range parsed {
|
||||||
if tc.Name == "" {
|
if tc.Name == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if _, ok := rejectedSet[tc.Name]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
rejectedSet[tc.Name] = struct{}{}
|
rejectedSet[tc.Name] = struct{}{}
|
||||||
}
|
rejected = append(rejected, tc.Name)
|
||||||
rejected := make([]string, 0, len(rejectedSet))
|
|
||||||
for name := range rejectedSet {
|
|
||||||
rejected = append(rejected, name)
|
|
||||||
}
|
}
|
||||||
return nil, rejected
|
return nil, rejected
|
||||||
}
|
}
|
||||||
out := make([]ParsedToolCall, 0, len(parsed))
|
out := make([]ParsedToolCall, 0, len(parsed))
|
||||||
rejectedSet := map[string]struct{}{}
|
rejectedSet := map[string]struct{}{}
|
||||||
|
rejected := make([]string, 0)
|
||||||
for _, tc := range parsed {
|
for _, tc := range parsed {
|
||||||
if tc.Name == "" {
|
if tc.Name == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
matchedName := ""
|
matchedName := resolveAllowedToolName(tc.Name, allowed, allowedCanonical)
|
||||||
if _, ok := allowed[tc.Name]; ok {
|
|
||||||
matchedName = tc.Name
|
|
||||||
} else if canonical, ok := allowedCanonical[strings.ToLower(tc.Name)]; ok {
|
|
||||||
matchedName = canonical
|
|
||||||
}
|
|
||||||
if matchedName == "" {
|
if matchedName == "" {
|
||||||
rejectedSet[tc.Name] = struct{}{}
|
if _, ok := rejectedSet[tc.Name]; !ok {
|
||||||
|
rejectedSet[tc.Name] = struct{}{}
|
||||||
|
rejected = append(rejected, tc.Name)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
tc.Name = matchedName
|
tc.Name = matchedName
|
||||||
@@ -137,13 +161,13 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
|
|||||||
}
|
}
|
||||||
out = append(out, tc)
|
out = append(out, tc)
|
||||||
}
|
}
|
||||||
rejected := make([]string, 0, len(rejectedSet))
|
|
||||||
for name := range rejectedSet {
|
|
||||||
rejected = append(rejected, name)
|
|
||||||
}
|
|
||||||
return out, rejected
|
return out, rejected
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
|
||||||
|
return resolveAllowedToolNameWithLooseMatch(name, allowed, allowedCanonical)
|
||||||
|
}
|
||||||
|
|
||||||
func parseToolCallsPayload(payload string) []ParsedToolCall {
|
func parseToolCallsPayload(payload string) []ParsedToolCall {
|
||||||
var decoded any
|
var decoded any
|
||||||
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
|
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
|
||||||
@@ -163,6 +187,15 @@ func parseToolCallsPayload(payload string) []ParsedToolCall {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func looksLikeToolCallSyntax(text string) bool {
|
||||||
|
lower := strings.ToLower(text)
|
||||||
|
return strings.Contains(lower, "tool_calls") ||
|
||||||
|
strings.Contains(lower, "<tool_call") ||
|
||||||
|
strings.Contains(lower, "<function_call") ||
|
||||||
|
strings.Contains(lower, "<invoke") ||
|
||||||
|
strings.Contains(lower, "function.name:")
|
||||||
|
}
|
||||||
|
|
||||||
func parseToolCallList(v any) []ParsedToolCall {
|
func parseToolCallList(v any) []ParsedToolCall {
|
||||||
items, ok := v.([]any)
|
items, ok := v.([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
235
internal/util/toolcalls_parse_markup.go
Normal file
235
internal/util/toolcalls_parse_markup.go
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/xml"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var xmlToolCallPattern = regexp.MustCompile(`(?is)<tool_call>\s*(.*?)\s*</tool_call>`)
|
||||||
|
var functionCallPattern = regexp.MustCompile(`(?is)<function_call>\s*([^<]+?)\s*</function_call>`)
|
||||||
|
var functionParamPattern = regexp.MustCompile(`(?is)<function\s+parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</function\s+parameter>`)
|
||||||
|
var antmlFunctionCallPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?function_call[^>]*(?:name|function)="([^"]+)"[^>]*>\s*(.*?)\s*</(?:[a-z0-9_]+:)?function_call>`)
|
||||||
|
var antmlArgumentPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?argument\s+name="([^"]+)"\s*>\s*(.*?)\s*</(?:[a-z0-9_]+:)?argument>`)
|
||||||
|
var antmlParametersPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?parameters\s*>\s*(\{.*?\})\s*</(?:[a-z0-9_]+:)?parameters>`)
|
||||||
|
var invokeCallPattern = regexp.MustCompile(`(?is)<invoke\s+name="([^"]+)"\s*>(.*?)</invoke>`)
|
||||||
|
var invokeParamPattern = regexp.MustCompile(`(?is)<parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</parameter>`)
|
||||||
|
|
||||||
|
func parseXMLToolCalls(text string) []ParsedToolCall {
|
||||||
|
matches := xmlToolCallPattern.FindAllString(text, -1)
|
||||||
|
out := make([]ParsedToolCall, 0, len(matches)+1)
|
||||||
|
for _, block := range matches {
|
||||||
|
call, ok := parseSingleXMLToolCall(block)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, call)
|
||||||
|
}
|
||||||
|
if len(out) > 0 {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
if call, ok := parseFunctionCallTagStyle(text); ok {
|
||||||
|
return []ParsedToolCall{call}
|
||||||
|
}
|
||||||
|
if calls := parseAntmlFunctionCallStyles(text); len(calls) > 0 {
|
||||||
|
return calls
|
||||||
|
}
|
||||||
|
if call, ok := parseInvokeFunctionCallStyle(text); ok {
|
||||||
|
return []ParsedToolCall{call}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||||
|
inner := strings.TrimSpace(block)
|
||||||
|
inner = strings.TrimPrefix(inner, "<tool_call>")
|
||||||
|
inner = strings.TrimSuffix(inner, "</tool_call>")
|
||||||
|
inner = strings.TrimSpace(inner)
|
||||||
|
if strings.HasPrefix(inner, "{") {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(inner), &payload); err == nil {
|
||||||
|
name := strings.TrimSpace(asString(payload["tool"]))
|
||||||
|
if name == "" {
|
||||||
|
name = strings.TrimSpace(asString(payload["tool_name"]))
|
||||||
|
}
|
||||||
|
if name != "" {
|
||||||
|
input := map[string]any{}
|
||||||
|
if params, ok := payload["params"].(map[string]any); ok {
|
||||||
|
input = params
|
||||||
|
} else if params, ok := payload["parameters"].(map[string]any); ok {
|
||||||
|
input = params
|
||||||
|
}
|
||||||
|
return ParsedToolCall{Name: name, Input: input}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dec := xml.NewDecoder(strings.NewReader(block))
|
||||||
|
name := ""
|
||||||
|
params := map[string]any{}
|
||||||
|
inParams := false
|
||||||
|
inTool := false
|
||||||
|
for {
|
||||||
|
tok, err := dec.Token()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
switch t := tok.(type) {
|
||||||
|
case xml.StartElement:
|
||||||
|
tag := strings.ToLower(t.Name.Local)
|
||||||
|
switch tag {
|
||||||
|
case "tool":
|
||||||
|
inTool = true
|
||||||
|
for _, attr := range t.Attr {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(attr.Name.Local), "name") && strings.TrimSpace(name) == "" {
|
||||||
|
name = strings.TrimSpace(attr.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "parameters":
|
||||||
|
inParams = true
|
||||||
|
case "tool_name", "name":
|
||||||
|
var v string
|
||||||
|
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
|
||||||
|
name = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
case "input", "arguments", "argument", "args", "params":
|
||||||
|
var v string
|
||||||
|
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
|
||||||
|
if parsed := parseToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 {
|
||||||
|
for k, vv := range parsed {
|
||||||
|
params[k] = vv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if inParams || inTool {
|
||||||
|
var v string
|
||||||
|
if err := dec.DecodeElement(&v, &t); err == nil {
|
||||||
|
params[t.Name.Local] = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case xml.EndElement:
|
||||||
|
tag := strings.ToLower(t.Name.Local)
|
||||||
|
if tag == "parameters" {
|
||||||
|
inParams = false
|
||||||
|
}
|
||||||
|
if tag == "tool" {
|
||||||
|
inTool = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(name) == "" {
|
||||||
|
return ParsedToolCall{}, false
|
||||||
|
}
|
||||||
|
return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
|
||||||
|
m := functionCallPattern.FindStringSubmatch(text)
|
||||||
|
if len(m) < 2 {
|
||||||
|
return ParsedToolCall{}, false
|
||||||
|
}
|
||||||
|
name := strings.TrimSpace(m[1])
|
||||||
|
if name == "" {
|
||||||
|
return ParsedToolCall{}, false
|
||||||
|
}
|
||||||
|
input := map[string]any{}
|
||||||
|
for _, pm := range functionParamPattern.FindAllStringSubmatch(text, -1) {
|
||||||
|
if len(pm) < 3 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := strings.TrimSpace(pm[1])
|
||||||
|
val := strings.TrimSpace(pm[2])
|
||||||
|
if key != "" {
|
||||||
|
input[key] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ParsedToolCall{Name: name, Input: input}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAntmlFunctionCallStyles(text string) []ParsedToolCall {
|
||||||
|
matches := antmlFunctionCallPattern.FindAllStringSubmatch(text, -1)
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]ParsedToolCall, 0, len(matches))
|
||||||
|
for _, m := range matches {
|
||||||
|
if call, ok := parseSingleAntmlFunctionCallMatch(m); ok {
|
||||||
|
out = append(out, call)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
|
||||||
|
if len(m) < 3 {
|
||||||
|
return ParsedToolCall{}, false
|
||||||
|
}
|
||||||
|
name := strings.TrimSpace(m[1])
|
||||||
|
if name == "" {
|
||||||
|
return ParsedToolCall{}, false
|
||||||
|
}
|
||||||
|
body := strings.TrimSpace(m[2])
|
||||||
|
input := map[string]any{}
|
||||||
|
if strings.HasPrefix(body, "{") {
|
||||||
|
if err := json.Unmarshal([]byte(body), &input); err == nil {
|
||||||
|
return ParsedToolCall{Name: name, Input: input}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if pm := antmlParametersPattern.FindStringSubmatch(body); len(pm) >= 2 {
|
||||||
|
if err := json.Unmarshal([]byte(strings.TrimSpace(pm[1])), &input); err == nil {
|
||||||
|
return ParsedToolCall{Name: name, Input: input}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, am := range antmlArgumentPattern.FindAllStringSubmatch(body, -1) {
|
||||||
|
if len(am) < 3 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
k := strings.TrimSpace(am[1])
|
||||||
|
v := strings.TrimSpace(am[2])
|
||||||
|
if k != "" {
|
||||||
|
input[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ParsedToolCall{Name: name, Input: input}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
|
||||||
|
m := invokeCallPattern.FindStringSubmatch(text)
|
||||||
|
if len(m) < 3 {
|
||||||
|
return ParsedToolCall{}, false
|
||||||
|
}
|
||||||
|
name := strings.TrimSpace(m[1])
|
||||||
|
if name == "" {
|
||||||
|
return ParsedToolCall{}, false
|
||||||
|
}
|
||||||
|
input := map[string]any{}
|
||||||
|
for _, pm := range invokeParamPattern.FindAllStringSubmatch(m[2], -1) {
|
||||||
|
if len(pm) < 3 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
k := strings.TrimSpace(pm[1])
|
||||||
|
v := strings.TrimSpace(pm[2])
|
||||||
|
if k != "" {
|
||||||
|
input[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(input) == 0 {
|
||||||
|
if argsRaw := findMarkupTagValue(m[2], toolCallMarkupArgsTagNames, toolCallMarkupArgsPatternByTag); argsRaw != "" {
|
||||||
|
input = parseMarkupInput(argsRaw)
|
||||||
|
} else if kv := parseMarkupKVObject(m[2]); len(kv) > 0 {
|
||||||
|
input = kv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ParsedToolCall{Name: name, Input: input}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func asString(v any) string {
|
||||||
|
s, _ := v.(string)
|
||||||
|
return s
|
||||||
|
}
|
||||||
@@ -115,3 +115,167 @@ func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) {
|
|||||||
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
|
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsAllowsQualifiedToolName(t *testing.T) {
|
||||||
|
text := `{"tool_calls":[{"name":"mcp.search_web","input":{"q":"golang"}}]}`
|
||||||
|
calls := ParseToolCalls(text, []string{"search_web"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "search_web" {
|
||||||
|
t.Fatalf("expected canonical tool name search_web, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsAllowsPunctuationVariantToolName(t *testing.T) {
|
||||||
|
text := `{"tool_calls":[{"name":"read-file","input":{"path":"README.md"}}]}`
|
||||||
|
calls := ParseToolCalls(text, []string{"read_file"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "read_file" {
|
||||||
|
t.Fatalf("expected canonical tool name read_file, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsClaudeXMLToolCall(t *testing.T) {
|
||||||
|
text := `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command><description>show cwd</description></parameters></tool_call>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsDetailedMarksXMLToolCallSyntax(t *testing.T) {
|
||||||
|
text := `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command></parameters></tool_call>`
|
||||||
|
res := ParseToolCallsDetailed(text, []string{"bash"})
|
||||||
|
if !res.SawToolCallSyntax {
|
||||||
|
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
|
||||||
|
}
|
||||||
|
if len(res.Calls) != 1 {
|
||||||
|
t.Fatalf("expected one parsed call, got %#v", res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsClaudeXMLJSONToolCall(t *testing.T) {
|
||||||
|
text := `<tool_call>{"tool":"Bash","params":{"command":"pwd","description":"show cwd"}}</tool_call>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsFunctionCallTagStyle(t *testing.T) {
|
||||||
|
text := `<function_call>Bash</function_call><function parameter name="command">ls -la</function parameter><function parameter name="description">list</function parameter>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "ls -la" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsAntmlFunctionCallStyle(t *testing.T) {
|
||||||
|
text := `<antml:function_calls><antml:function_call name="Bash">{"command":"pwd","description":"x"}</antml:function_call></antml:function_calls>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsAntmlArgumentStyle(t *testing.T) {
|
||||||
|
text := `<antml:function_calls><antml:function_call id="1" name="Bash"><antml:argument name="command">pwd</antml:argument><antml:argument name="description">x</antml:argument></antml:function_call></antml:function_calls>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsInvokeFunctionCallStyle(t *testing.T) {
|
||||||
|
text := `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter><parameter name="description">d</parameter></invoke></function_calls>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsNestedToolTagStyle(t *testing.T) {
|
||||||
|
text := `<tool_call><tool name="Bash"><command>pwd</command><description>show cwd</description></tool></tool_call>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsAntmlFunctionAttributeWithParametersTag(t *testing.T) {
|
||||||
|
text := `<antml:function_calls><antml:function_call id="x" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call></antml:function_calls>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" {
|
||||||
|
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "pwd" {
|
||||||
|
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsMultipleAntmlFunctionCalls(t *testing.T) {
|
||||||
|
text := `<antml:function_calls><antml:function_call id="1" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call><antml:function_call id="2" function="Read"><antml:parameters>{"file_path":"README.md"}</antml:parameters></antml:function_call></antml:function_calls>`
|
||||||
|
calls := ParseToolCalls(text, []string{"bash", "read"})
|
||||||
|
if len(calls) != 2 {
|
||||||
|
t.Fatalf("expected 2 calls, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "bash" || calls[1].Name != "read" {
|
||||||
|
t.Fatalf("expected canonical names [bash read], got %#v", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsDoesNotAcceptMismatchedMarkupTags(t *testing.T) {
|
||||||
|
text := `<tool_call><name>read_file</function><arguments>{"path":"README.md"}</arguments></tool_call>`
|
||||||
|
calls := ParseToolCalls(text, []string{"read_file"})
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected mismatched tags to be rejected, got %#v", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
55
internal/util/toolcalls_textkv.go
Normal file
55
internal/util/toolcalls_textkv.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var textKVNamePattern = regexp.MustCompile(`(?is)function\.name:\s*([a-zA-Z0-9_\-.]+)`)
|
||||||
|
|
||||||
|
func parseTextKVToolCalls(text string) []ParsedToolCall {
|
||||||
|
var out []ParsedToolCall
|
||||||
|
matches := textKVNamePattern.FindAllStringSubmatchIndex(text, -1)
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, match := range matches {
|
||||||
|
name := text[match[2]:match[3]]
|
||||||
|
|
||||||
|
offset := match[1]
|
||||||
|
endSearch := len(text)
|
||||||
|
if i+1 < len(matches) {
|
||||||
|
endSearch = matches[i+1][0]
|
||||||
|
}
|
||||||
|
|
||||||
|
searchArea := text[offset:endSearch]
|
||||||
|
argIdx := strings.Index(searchArea, "function.arguments:")
|
||||||
|
if argIdx < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
startIdx := offset + argIdx + len("function.arguments:")
|
||||||
|
braceIdx := strings.IndexByte(text[startIdx:endSearch], '{')
|
||||||
|
if braceIdx < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
actualStart := startIdx + braceIdx
|
||||||
|
objJson, _, ok := extractJSONObject(text, actualStart)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
input := parseToolCallInput(objJson)
|
||||||
|
out = append(out, ParsedToolCall{
|
||||||
|
Name: name,
|
||||||
|
Input: input,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
63
internal/util/toolcalls_textkv_test.go
Normal file
63
internal/util/toolcalls_textkv_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseTextKVToolCalls_Basic(t *testing.T) {
|
||||||
|
text := `
|
||||||
|
[TOOL_CALL_HISTORY]
|
||||||
|
status: already_called
|
||||||
|
origin: assistant
|
||||||
|
not_user_input: true
|
||||||
|
tool_call_id: call_3fcd15235eb94f7eae3a8de5a9cfa36b
|
||||||
|
function.name: execute_command
|
||||||
|
function.arguments: {"command":"cd scripts && python check_syntax.py example.py","cwd":null,"timeout":30}
|
||||||
|
[/TOOL_CALL_HISTORY]
|
||||||
|
|
||||||
|
Some other text thinking...
|
||||||
|
`
|
||||||
|
calls := ParseToolCalls(text, []string{"execute_command"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if calls[0].Name != "execute_command" {
|
||||||
|
t.Fatalf("unexpected name: %s", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["command"] != "cd scripts && python check_syntax.py example.py" {
|
||||||
|
t.Fatalf("unexpected command arg: %v", calls[0].Input["command"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTextKVToolCalls_Multiple(t *testing.T) {
|
||||||
|
text := `
|
||||||
|
function.name: read_file
|
||||||
|
function.arguments: {
|
||||||
|
"path": "abc.txt"
|
||||||
|
}
|
||||||
|
|
||||||
|
function.name: bash
|
||||||
|
function.arguments: {"command": "ls"}
|
||||||
|
`
|
||||||
|
calls := ParseToolCalls(text, []string{"read_file", "bash"})
|
||||||
|
if len(calls) != 2 {
|
||||||
|
t.Fatalf("expected 2 calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if calls[0].Name != "read_file" {
|
||||||
|
t.Fatalf("unexpected 1st name: %s", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[1].Name != "bash" {
|
||||||
|
t.Fatalf("unexpected 2nd name: %s", calls[1].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTextKVToolCalls_Standalone(t *testing.T) {
|
||||||
|
text := "function.name: read_file\nfunction.arguments: {\"path\":\"README.md\"}"
|
||||||
|
calls := ParseStandaloneToolCalls(text, []string{"read_file"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if calls[0].Name != "read_file" {
|
||||||
|
t.Fatalf("unexpected name: %s", calls[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
101
misc/deepseek_functioncalling_bug/report.md
Normal file
101
misc/deepseek_functioncalling_bug/report.md
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
# DeepSeek Function Calling 缺陷分析与 ds2api 的增强修复策略
|
||||||
|
|
||||||
|
> **相关 PR**: #74 (代码核心实现) 与 #75 (Merge to dev)
|
||||||
|
> **问题背景**: 解决因包括 DeepSeek 在内的部分模型在函数调用(Function Calling/Tool Call)表现不够“规范”,从而导致工具调用失败的问题。
|
||||||
|
|
||||||
|
## 一、底层架构对比:为什么会产生 Function Calling 缺陷?
|
||||||
|
|
||||||
|
在探讨缺陷前,我们需要理解两种 Function Calling 的底层结构差异:
|
||||||
|
|
||||||
|
### 1. OpenAI 的原生结构化返回 (API 级分离)
|
||||||
|
在 OpenAI 的规范中,**聊天文字与工具调用是在底层的 JSON 结构中被硬性拆分的**:
|
||||||
|
* 聊天废话存放在 `response.choices[0].message.content` 里。
|
||||||
|
* 工具请求存放在单独的数组 `response.choices[0].message.tool_calls` 里。
|
||||||
|
|
||||||
|
**优势:** 这种设计对客户端极其友好。客户端只需判断 `tool_calls` 是否为空,就能决定是执行代码还是渲染文字。它支持同时并发多个工具请求,且底层的生成殷勤被严格训练和约束,极少抛出语法错误的 JSON。
|
||||||
|
|
||||||
|
### 2. DeepSeek 等模型的“单文本流”机制
|
||||||
|
相比之下,部分未经深度专门微调的模型(或者在特定的通信适配层中),它们依然倾向于把一切内容打包成一个纯文本流吐出。这就是为什么它们的输出往往不仅包含了本该属于 `tool_calls` 结构里的 JSON,还会像个“老实人”一样夹杂了属于 `content` 里的散文。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 二、DeepSeek 在 Function Calling 上的特定缺陷表现
|
||||||
|
|
||||||
|
相比于 OpenAI 严格遵循 API 约定的原生结构,DeepSeek 等开源/国产推理模型在工具调用时,经常会暴露出以下三种典型的“不守规矩”的输出行为:
|
||||||
|
|
||||||
|
### 1. 混合输出:散文文本与工具 JSON 混杂 (Mixed Prose Streams)
|
||||||
|
当应用要求模型直接返回工具请求时,DeepSeek 有时候会**“忍不住想和用户搭话”**。
|
||||||
|
它常常前置一段解释性废话,中间插入工具调用的 JSON 参数,并在末尾再补上一句总结:
|
||||||
|
```text
|
||||||
|
好的,我这就帮你读取 README.md 的内容:
|
||||||
|
{"tool_calls":[{"name":"read_file","input":{"path":"README.md"}}]}
|
||||||
|
请稍等片刻,我马上把它读出来。
|
||||||
|
```
|
||||||
|
**旧版系统痛点:**
|
||||||
|
原有的代码存在**严格模式(Strict Mode)**校验:
|
||||||
|
```go
|
||||||
|
// 如果解析到的 JSON 块前后存在任何非空字符串,就放弃当作工具调用!
|
||||||
|
if strings.TrimSpace(state.recentTextTail) != "" || strings.TrimSpace(prefixPart) != "" ... {
|
||||||
|
return captured, nil, "", true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
这直接导致上述结构被网关认定是一段“普通聊天”,直接原封不动地返回给用户,这直接干挂了后续的工具自动执行流程。
|
||||||
|
|
||||||
|
### 2. 工具名格式幻觉:擅自修改或前缀化工具名称
|
||||||
|
由于 DeepSeek 的预训练数据中有大量的代码和不同的平台结构,它在回复工具名称时,常常无法忠实于 System Prompt 中提供的纯命名(也就是 `name: "read_file"`),而是加上前缀或者拼写变形,例如:
|
||||||
|
* `{"name": "mcp.search_web"}` (自带命名空间)
|
||||||
|
* `{"name": "tools.read_file"}`
|
||||||
|
* `{"name": "search-web"}` (下划线变成了中划线)
|
||||||
|
|
||||||
|
**旧版系统痛点:**
|
||||||
|
旧版系统对于工具名的匹配几乎只有“绝对相等”的字典级比对,只要差了一个字符或加了前缀,就会由于找不到合法工具而直接失败。
|
||||||
|
|
||||||
|
### 3. Role 角色的非标准返回
|
||||||
|
在部分工具通信流的响应中,返回的内容其所属的 `role` 没有被标准化处理,可能携带意料之外的属性,或是与下游严格比对出现冲突。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 二、PR #74 的代码增强修复方案
|
||||||
|
|
||||||
|
为了解决大模型这种自身的不规范行为,PR #74 在系统的中间层网关联入了一个**极其包容的容错引擎**。它并不强制要求模型“改过自新”,而是主动做了以下三块增强:
|
||||||
|
|
||||||
|
### 1. 从流中分离混合内容(废除 Strict Mode)
|
||||||
|
修改了 `internal/adapter/openai/tool_sieve_core.go`。
|
||||||
|
取消了前后包裹文本的拦截逻辑。当系统扫描到流式结构中有完整的 `{"tool_calls":...}` 时,它会将废话和 JSON 分发到不同的事件流中:
|
||||||
|
```go
|
||||||
|
if prefix != "" {
|
||||||
|
// 将前面的“好的,帮你读文件”剥离出来作为常规文本输出
|
||||||
|
state.noteText(prefix)
|
||||||
|
events = append(events, toolStreamEvent{Content: prefix})
|
||||||
|
}
|
||||||
|
// 捕获并拦截中间的工具请求,进行背后执行
|
||||||
|
state.pendingToolCalls = calls
|
||||||
|
```
|
||||||
|
**效果:** 用户的屏幕上只能看到正常的文字交流,而后端的工具也会立刻挂载。
|
||||||
|
|
||||||
|
### 2. 多级宽容匹配引擎 (Resolve Allowed Tool Name)
|
||||||
|
在 `internal/util/toolcalls_parse.go` 中,新增了一个由严到松降级匹配的强大漏斗策略函数 `resolveAllowedToolName`:
|
||||||
|
|
||||||
|
1. **绝对匹配**:和以前一样,`read_file` == `read_file`。
|
||||||
|
2. **忽略大小写**:`Read_File` 算作合法。
|
||||||
|
3. **命名空间抹除**:通过寻找最后一个 `.` 来剥离前缀,强制将 `mcp.search_web` 还原出真实的 `search_web`。
|
||||||
|
4. **终极正则清洗**:
|
||||||
|
引入 `var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`)`。
|
||||||
|
这个正则剥离了字符串里所有的符号、空格、格式符。
|
||||||
|
将传入的 `read-file` 洗除符号成为 `readfile`,并去和系统中所有合法工具同样清洗后的版本进行比较。只要核心字母一致,即算作匹配成功。
|
||||||
|
|
||||||
|
### 3. Role 归一化 (Normalize OpenAIRoleForPrompt)
|
||||||
|
在 `internal/adapter/openai/responses_input_items.go` 等处,引入了特定的 `normalizeOpenAIRoleForPrompt(role)` 清洗,保证输入和传递给上游的 Role 枚举始终受控,消除了因为意外的身份字段传参崩溃。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 报告总结与 tool_sieve 的本质作用
|
||||||
|
|
||||||
|
PR #74 / #75 并没有从模型本身开刀,而是基于**网关应足够健壮**的设计哲学。
|
||||||
|
|
||||||
|
**其实整个增强实现,本质上实现了一个名为 `tool_sieve` (工具筛子) 的中间层网关。**
|
||||||
|
面对 DeepSeek 这种吐出一团混合了聊天文字与 JSON 面团的“不标准”数据流,`tool_sieve` 就像一个勤劳的高精度筛子,不仅人工揉开了面团:
|
||||||
|
1. 它把散文分拣出来,塞回标准结构的 `content` 字段去展示;
|
||||||
|
2. 剥离并清洗出有瑕疵的 JSON 块,按照 OpenAI 的标准格式小心翼翼地放进 `tool_calls` 结构里去等待执行。
|
||||||
|
|
||||||
|
这意味着,即便 AI 被配置了奇怪的回复设定、加粗了强调语言,甚至是犯了标点符号拼写小失误,**只要它输出了可以拼凑成工具指令的 JSON 核心单元,整个中继层就能将其挽救,并把正确的工具结果呈现给模型和用户**。 这不仅修复了缺陷,更极大地增强了工具网关的通用性和鲁棒性。
|
||||||
@@ -16,7 +16,6 @@ internal/js/helpers/stream-tool-sieve.js
|
|||||||
internal/js/helpers/stream-tool-sieve/index.js
|
internal/js/helpers/stream-tool-sieve/index.js
|
||||||
internal/js/helpers/stream-tool-sieve/state.js
|
internal/js/helpers/stream-tool-sieve/state.js
|
||||||
internal/js/helpers/stream-tool-sieve/sieve.js
|
internal/js/helpers/stream-tool-sieve/sieve.js
|
||||||
internal/js/helpers/stream-tool-sieve/incremental.js
|
|
||||||
internal/js/helpers/stream-tool-sieve/jsonscan.js
|
internal/js/helpers/stream-tool-sieve/jsonscan.js
|
||||||
internal/js/helpers/stream-tool-sieve/parse.js
|
internal/js/helpers/stream-tool-sieve/parse.js
|
||||||
internal/js/helpers/stream-tool-sieve/format.js
|
internal/js/helpers/stream-tool-sieve/format.js
|
||||||
|
|||||||
@@ -105,7 +105,6 @@ internal/js/helpers/stream-tool-sieve.js
|
|||||||
internal/js/helpers/stream-tool-sieve/index.js
|
internal/js/helpers/stream-tool-sieve/index.js
|
||||||
internal/js/helpers/stream-tool-sieve/state.js
|
internal/js/helpers/stream-tool-sieve/state.js
|
||||||
internal/js/helpers/stream-tool-sieve/sieve.js
|
internal/js/helpers/stream-tool-sieve/sieve.js
|
||||||
internal/js/helpers/stream-tool-sieve/incremental.js
|
|
||||||
internal/js/helpers/stream-tool-sieve/jsonscan.js
|
internal/js/helpers/stream-tool-sieve/jsonscan.js
|
||||||
internal/js/helpers/stream-tool-sieve/parse.js
|
internal/js/helpers/stream-tool-sieve/parse.js
|
||||||
internal/js/helpers/stream-tool-sieve/format.js
|
internal/js/helpers/stream-tool-sieve/format.js
|
||||||
|
|||||||
8
tests/compat/expected/toolcalls_allowlist_empty.json
Normal file
8
tests/compat/expected/toolcalls_allowlist_empty.json
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"calls": [],
|
||||||
|
"sawToolCallSyntax": true,
|
||||||
|
"rejectedByPolicy": true,
|
||||||
|
"rejectedToolNames": [
|
||||||
|
"unknown_tool"
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"calls": [
|
||||||
|
{
|
||||||
|
"name": "read_file",
|
||||||
|
"input": {
|
||||||
|
"path": "README.MD"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sawToolCallSyntax": true,
|
||||||
|
"rejectedByPolicy": false,
|
||||||
|
"rejectedToolNames": []
|
||||||
|
}
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
{
|
{
|
||||||
"calls": []
|
"calls": [],
|
||||||
}
|
"sawToolCallSyntax": false,
|
||||||
|
"rejectedByPolicy": false,
|
||||||
|
"rejectedToolNames": []
|
||||||
|
}
|
||||||
13
tests/compat/expected/toolcalls_function_call_tag.json
Normal file
13
tests/compat/expected/toolcalls_function_call_tag.json
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"calls": [
|
||||||
|
{
|
||||||
|
"name": "read_file",
|
||||||
|
"input": {
|
||||||
|
"path": "README.MD"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sawToolCallSyntax": true,
|
||||||
|
"rejectedByPolicy": false,
|
||||||
|
"rejectedToolNames": []
|
||||||
|
}
|
||||||
13
tests/compat/expected/toolcalls_invoke_attr.json
Normal file
13
tests/compat/expected/toolcalls_invoke_attr.json
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"calls": [
|
||||||
|
{
|
||||||
|
"name": "read_file",
|
||||||
|
"input": {
|
||||||
|
"path": "README.MD"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sawToolCallSyntax": true,
|
||||||
|
"rejectedByPolicy": false,
|
||||||
|
"rejectedToolNames": []
|
||||||
|
}
|
||||||
13
tests/compat/expected/toolcalls_loose_normalize.json
Normal file
13
tests/compat/expected/toolcalls_loose_normalize.json
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"calls": [
|
||||||
|
{
|
||||||
|
"name": "read_file",
|
||||||
|
"input": {
|
||||||
|
"path": "README.MD"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sawToolCallSyntax": true,
|
||||||
|
"rejectedByPolicy": false,
|
||||||
|
"rejectedToolNames": []
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"calls": [
|
||||||
|
{
|
||||||
|
"name": "read_file",
|
||||||
|
"input": {
|
||||||
|
"path": "README.MD"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sawToolCallSyntax": true,
|
||||||
|
"rejectedByPolicy": false,
|
||||||
|
"rejectedToolNames": []
|
||||||
|
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"calls": [],
|
||||||
|
"sawToolCallSyntax": false,
|
||||||
|
"rejectedByPolicy": false,
|
||||||
|
"rejectedToolNames": []
|
||||||
|
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"calls": [],
|
||||||
|
"sawToolCallSyntax": true,
|
||||||
|
"rejectedByPolicy": false,
|
||||||
|
"rejectedToolNames": []
|
||||||
|
}
|
||||||
13
tests/compat/expected/toolcalls_standalone_pure.json
Normal file
13
tests/compat/expected/toolcalls_standalone_pure.json
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"calls": [
|
||||||
|
{
|
||||||
|
"name": "read_file",
|
||||||
|
"input": {
|
||||||
|
"path": "README.MD"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sawToolCallSyntax": true,
|
||||||
|
"rejectedByPolicy": false,
|
||||||
|
"rejectedToolNames": []
|
||||||
|
}
|
||||||
@@ -1,3 +1,8 @@
|
|||||||
{
|
{
|
||||||
"calls": []
|
"calls": [],
|
||||||
}
|
"sawToolCallSyntax": true,
|
||||||
|
"rejectedByPolicy": true,
|
||||||
|
"rejectedToolNames": [
|
||||||
|
"unknown_tool"
|
||||||
|
]
|
||||||
|
}
|
||||||
13
tests/compat/expected/toolcalls_xml_tool_call.json
Normal file
13
tests/compat/expected/toolcalls_xml_tool_call.json
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"calls": [
|
||||||
|
{
|
||||||
|
"name": "read_file",
|
||||||
|
"input": {
|
||||||
|
"path": "README.MD"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sawToolCallSyntax": true,
|
||||||
|
"rejectedByPolicy": false,
|
||||||
|
"rejectedToolNames": []
|
||||||
|
}
|
||||||
4
tests/compat/fixtures/toolcalls/allowlist_empty.json
Normal file
4
tests/compat/fixtures/toolcalls/allowlist_empty.json
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"text": "{\"tool_calls\":[{\"name\":\"unknown_tool\",\"input\":{\"x\":1}}]}",
|
||||||
|
"tool_names": []
|
||||||
|
}
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"text": "{\"tool_calls\":[{\"name\":\"Read_File\",\"input\":{\"path\":\"README.MD\"}}]}",
|
||||||
|
"tool_names": ["read_file"]
|
||||||
|
}
|
||||||
6
tests/compat/fixtures/toolcalls/function_call_tag.json
Normal file
6
tests/compat/fixtures/toolcalls/function_call_tag.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"text": "<function_call><function>read_file</function><parameters>{\"path\":\"README.MD\"}</parameters></function_call>",
|
||||||
|
"tool_names": [
|
||||||
|
"read_file"
|
||||||
|
]
|
||||||
|
}
|
||||||
6
tests/compat/fixtures/toolcalls/invoke_attr.json
Normal file
6
tests/compat/fixtures/toolcalls/invoke_attr.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"text": "<invoke name=\"read_file\"><argument>{\"path\":\"README.MD\"}</argument></invoke>",
|
||||||
|
"tool_names": [
|
||||||
|
"read_file"
|
||||||
|
]
|
||||||
|
}
|
||||||
6
tests/compat/fixtures/toolcalls/loose_normalize.json
Normal file
6
tests/compat/fixtures/toolcalls/loose_normalize.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"text": "{\"tool_calls\":[{\"name\":\"read-file\",\"input\":{\"path\":\"README.MD\"}}]}",
|
||||||
|
"tool_names": [
|
||||||
|
"read_file"
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"text": "{\"tool_calls\":[{\"name\":\"company.fs.read_file\",\"input\":{\"path\":\"README.MD\"}}]}",
|
||||||
|
"tool_names": [
|
||||||
|
"read_file"
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"mode": "standalone",
|
||||||
|
"text": "```json\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}\n```",
|
||||||
|
"tool_names": ["read_file"]
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"mode": "standalone",
|
||||||
|
"text": "下面是示例:{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}请勿执行。",
|
||||||
|
"tool_names": ["read_file"]
|
||||||
|
}
|
||||||
5
tests/compat/fixtures/toolcalls/standalone_pure.json
Normal file
5
tests/compat/fixtures/toolcalls/standalone_pure.json
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"mode": "standalone",
|
||||||
|
"text": "{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}",
|
||||||
|
"tool_names": ["read_file"]
|
||||||
|
}
|
||||||
6
tests/compat/fixtures/toolcalls/xml_tool_call.json
Normal file
6
tests/compat/fixtures/toolcalls/xml_tool_call.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"text": "<tool_call><name>read_file</name><arguments>{\"path\":\"README.MD\"}</arguments></tool_call>",
|
||||||
|
"tool_names": [
|
||||||
|
"read_file"
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -13,8 +13,10 @@ const {
|
|||||||
const {
|
const {
|
||||||
parseChunkForContent,
|
parseChunkForContent,
|
||||||
resolveToolcallPolicy,
|
resolveToolcallPolicy,
|
||||||
|
formatIncrementalToolCallDeltas,
|
||||||
normalizePreparedToolNames,
|
normalizePreparedToolNames,
|
||||||
boolDefaultTrue,
|
boolDefaultTrue,
|
||||||
|
filterIncrementalToolCallDeltasByAllowed,
|
||||||
} = handler.__test;
|
} = handler.__test;
|
||||||
|
|
||||||
test('chat-stream exposes parser test hooks', () => {
|
test('chat-stream exposes parser test hooks', () => {
|
||||||
@@ -56,6 +58,46 @@ test('boolDefaultTrue keeps false only when explicitly false', () => {
|
|||||||
assert.equal(boolDefaultTrue(undefined), true);
|
assert.equal(boolDefaultTrue(undefined), true);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('filterIncrementalToolCallDeltasByAllowed blocks unknown name and follow-up args', () => {
|
||||||
|
const seen = new Map();
|
||||||
|
const filtered = filterIncrementalToolCallDeltasByAllowed(
|
||||||
|
[
|
||||||
|
{ index: 0, name: 'not_in_schema' },
|
||||||
|
{ index: 0, arguments: '{"x":1}' },
|
||||||
|
],
|
||||||
|
['read_file'],
|
||||||
|
seen,
|
||||||
|
);
|
||||||
|
assert.deepEqual(filtered, []);
|
||||||
|
assert.equal(seen.get(0), '__blocked__');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('filterIncrementalToolCallDeltasByAllowed keeps allowed name and args', () => {
|
||||||
|
const seen = new Map();
|
||||||
|
const filtered = filterIncrementalToolCallDeltasByAllowed(
|
||||||
|
[
|
||||||
|
{ index: 0, name: 'read_file' },
|
||||||
|
{ index: 0, arguments: '{"path":"README.MD"}' },
|
||||||
|
],
|
||||||
|
['read_file'],
|
||||||
|
seen,
|
||||||
|
);
|
||||||
|
assert.deepEqual(filtered, [
|
||||||
|
{ index: 0, name: 'read_file' },
|
||||||
|
{ index: 0, arguments: '{"path":"README.MD"}' },
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('incremental and final tool formatting share stable id via idStore', () => {
|
||||||
|
const idStore = new Map();
|
||||||
|
const incremental = formatIncrementalToolCallDeltas([{ index: 0, name: 'read_file' }], idStore);
|
||||||
|
const { formatOpenAIStreamToolCalls } = require('../../internal/js/helpers/stream-tool-sieve.js');
|
||||||
|
const finalCalls = formatOpenAIStreamToolCalls([{ name: 'read_file', input: { path: 'README.MD' } }], idStore);
|
||||||
|
assert.equal(incremental.length, 1);
|
||||||
|
assert.equal(finalCalls.length, 1);
|
||||||
|
assert.equal(incremental[0].id, finalCalls[0].id);
|
||||||
|
});
|
||||||
|
|
||||||
test('parseChunkForContent keeps split response/content fragments inside response array', () => {
|
test('parseChunkForContent keeps split response/content fragments inside response array', () => {
|
||||||
const chunk = {
|
const chunk = {
|
||||||
p: 'response',
|
p: 'response',
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ const fs = require('node:fs');
|
|||||||
const path = require('node:path');
|
const path = require('node:path');
|
||||||
|
|
||||||
const chatStream = require('../../api/chat-stream.js');
|
const chatStream = require('../../api/chat-stream.js');
|
||||||
const { parseToolCalls } = require('../../internal/js/helpers/stream-tool-sieve.js');
|
const { parseToolCallsDetailed, parseStandaloneToolCallsDetailed } = require('../../internal/js/helpers/stream-tool-sieve.js');
|
||||||
|
|
||||||
const { parseChunkForContent, estimateTokens } = chatStream.__test;
|
const { parseChunkForContent, estimateTokens } = chatStream.__test;
|
||||||
|
|
||||||
@@ -41,12 +41,17 @@ test('js compat: toolcall fixtures', () => {
|
|||||||
|
|
||||||
for (const file of files) {
|
for (const file of files) {
|
||||||
const name = file.replace(/\.json$/i, '');
|
const name = file.replace(/\.json$/i, '');
|
||||||
const fixture = readJSON(path.join(fixtureDir, file));
|
const fixture = readJSON(path.join(fixtureDir, file));
|
||||||
const expected = readJSON(path.join(expectedDir, `toolcalls_${name}.json`));
|
const expected = readJSON(path.join(expectedDir, `toolcalls_${name}.json`));
|
||||||
const got = parseToolCalls(fixture.text, fixture.tool_names || []);
|
const mode = typeof fixture.mode === 'string' ? fixture.mode.trim().toLowerCase() : '';
|
||||||
assert.deepEqual(got, expected.calls, `${name}: calls mismatch`);
|
const parser = mode === 'standalone' ? parseStandaloneToolCallsDetailed : parseToolCallsDetailed;
|
||||||
}
|
const got = parser(fixture.text, fixture.tool_names || []);
|
||||||
});
|
assert.deepEqual(got.calls, expected.calls, `${name}: calls mismatch`);
|
||||||
|
assert.equal(got.sawToolCallSyntax, expected.sawToolCallSyntax, `${name}: sawToolCallSyntax mismatch`);
|
||||||
|
assert.equal(got.rejectedByPolicy, expected.rejectedByPolicy, `${name}: rejectedByPolicy mismatch`);
|
||||||
|
assert.deepEqual(got.rejectedToolNames, expected.rejectedToolNames, `${name}: rejectedToolNames mismatch`);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
test('js compat: token fixtures', () => {
|
test('js compat: token fixtures', () => {
|
||||||
const fixture = readJSON(path.join(compatRoot, 'fixtures', 'token_cases.json'));
|
const fixture = readJSON(path.join(compatRoot, 'fixtures', 'token_cases.json'));
|
||||||
|
|||||||
@@ -9,7 +9,9 @@ const {
|
|||||||
processToolSieveChunk,
|
processToolSieveChunk,
|
||||||
flushToolSieve,
|
flushToolSieve,
|
||||||
parseToolCalls,
|
parseToolCalls,
|
||||||
|
parseToolCallsDetailed,
|
||||||
parseStandaloneToolCalls,
|
parseStandaloneToolCalls,
|
||||||
|
formatOpenAIStreamToolCalls,
|
||||||
} = require('../../internal/js/helpers/stream-tool-sieve.js');
|
} = require('../../internal/js/helpers/stream-tool-sieve.js');
|
||||||
|
|
||||||
function runSieve(chunks, toolNames) {
|
function runSieve(chunks, toolNames) {
|
||||||
@@ -60,13 +62,25 @@ test('parseToolCalls drops unknown schema names when toolNames is provided', ()
|
|||||||
assert.equal(calls.length, 0);
|
assert.equal(calls.length, 0);
|
||||||
});
|
});
|
||||||
|
|
||||||
test('parseToolCalls keeps unknown names when toolNames is empty', () => {
|
test('parseToolCalls matches tool name case-insensitively and canonicalizes', () => {
|
||||||
|
const payload = JSON.stringify({
|
||||||
|
tool_calls: [{ name: 'Read_File', input: { path: 'README.MD' } }],
|
||||||
|
});
|
||||||
|
const calls = parseToolCalls(payload, ['read_file']);
|
||||||
|
assert.deepEqual(calls, [{ name: 'read_file', input: { path: 'README.MD' } }]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('parseToolCalls rejects all names when toolNames is empty (Go strict parity)', () => {
|
||||||
const payload = JSON.stringify({
|
const payload = JSON.stringify({
|
||||||
tool_calls: [{ name: 'not_in_schema', input: { q: 'go' } }],
|
tool_calls: [{ name: 'not_in_schema', input: { q: 'go' } }],
|
||||||
});
|
});
|
||||||
const calls = parseToolCalls(payload, []);
|
const calls = parseToolCalls(payload, []);
|
||||||
assert.equal(calls.length, 1);
|
assert.equal(calls.length, 0);
|
||||||
assert.equal(calls[0].name, 'not_in_schema');
|
|
||||||
|
const detailed = parseToolCallsDetailed(payload, []);
|
||||||
|
assert.equal(detailed.sawToolCallSyntax, true);
|
||||||
|
assert.equal(detailed.rejectedByPolicy, true);
|
||||||
|
assert.deepEqual(detailed.rejectedToolNames, ['not_in_schema']);
|
||||||
});
|
});
|
||||||
|
|
||||||
test('parseToolCalls supports fenced json and function.arguments string payload', () => {
|
test('parseToolCalls supports fenced json and function.arguments string payload', () => {
|
||||||
@@ -80,6 +94,34 @@ test('parseToolCalls supports fenced json and function.arguments string payload'
|
|||||||
assert.equal(calls.length, 0);
|
assert.equal(calls.length, 0);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('parseToolCalls parses text-kv fallback payload', () => {
|
||||||
|
const text = [
|
||||||
|
'[TOOL_CALL_HISTORY]',
|
||||||
|
'function.name: execute_command',
|
||||||
|
'function.arguments: {"command":"cd scripts && python check_syntax.py example.py","cwd":null,"timeout":30}',
|
||||||
|
'[/TOOL_CALL_HISTORY]',
|
||||||
|
'Some other text thinking...',
|
||||||
|
].join('\n');
|
||||||
|
const calls = parseToolCalls(text, ['execute_command']);
|
||||||
|
assert.equal(calls.length, 1);
|
||||||
|
assert.equal(calls[0].name, 'execute_command');
|
||||||
|
assert.equal(calls[0].input.command, 'cd scripts && python check_syntax.py example.py');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('parseToolCalls parses multiple text-kv fallback payloads', () => {
|
||||||
|
const text = [
|
||||||
|
'function.name: read_file',
|
||||||
|
'function.arguments: {"path":"abc.txt"}',
|
||||||
|
'',
|
||||||
|
'function.name: bash',
|
||||||
|
'function.arguments: {"command":"ls"}',
|
||||||
|
].join('\n');
|
||||||
|
const calls = parseToolCalls(text, ['read_file', 'bash']);
|
||||||
|
assert.equal(calls.length, 2);
|
||||||
|
assert.equal(calls[0].name, 'read_file');
|
||||||
|
assert.equal(calls[1].name, 'bash');
|
||||||
|
});
|
||||||
|
|
||||||
test('parseStandaloneToolCalls only matches standalone payload and ignores mixed prose', () => {
|
test('parseStandaloneToolCalls only matches standalone payload and ignores mixed prose', () => {
|
||||||
const mixed = '这里是示例:{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]},请勿执行。';
|
const mixed = '这里是示例:{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]},请勿执行。';
|
||||||
const standalone = '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}';
|
const standalone = '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}';
|
||||||
@@ -95,7 +137,23 @@ test('parseStandaloneToolCalls ignores fenced code block tool_call examples', ()
|
|||||||
assert.equal(calls.length, 0);
|
assert.equal(calls.length, 0);
|
||||||
});
|
});
|
||||||
|
|
||||||
test('sieve emits tool_calls and does not leak suspicious prefix on late key convergence', () => {
|
|
||||||
|
test('sieve emits tool_calls in the same chunk processing tick once payload is complete', () => {
|
||||||
|
const state = createToolSieveState();
|
||||||
|
const first = processToolSieveChunk(state, '{"', ['read_file']);
|
||||||
|
const second = processToolSieveChunk(
|
||||||
|
state,
|
||||||
|
'tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}',
|
||||||
|
['read_file'],
|
||||||
|
);
|
||||||
|
const firstCalls = first.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []);
|
||||||
|
const secondCalls = second.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []);
|
||||||
|
assert.equal(firstCalls.length, 0);
|
||||||
|
assert.equal(secondCalls.length, 1);
|
||||||
|
assert.equal(secondCalls[0].name, 'read_file');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('sieve emits tool_calls when late key convergence forms a complete payload', () => {
|
||||||
const events = runSieve(
|
const events = runSieve(
|
||||||
[
|
[
|
||||||
'{"',
|
'{"',
|
||||||
@@ -105,12 +163,11 @@ test('sieve emits tool_calls and does not leak suspicious prefix on late key con
|
|||||||
['read_file'],
|
['read_file'],
|
||||||
);
|
);
|
||||||
const leakedText = collectText(events);
|
const leakedText = collectText(events);
|
||||||
const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0);
|
const finalCalls = events.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []);
|
||||||
const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0);
|
assert.equal(finalCalls.length, 1);
|
||||||
assert.equal(hasToolCall || hasToolDelta, true);
|
assert.equal(finalCalls[0].name, 'read_file');
|
||||||
assert.equal(leakedText.includes('{'), false);
|
|
||||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
|
||||||
assert.equal(leakedText.includes('后置正文C。'), true);
|
assert.equal(leakedText.includes('后置正文C。'), true);
|
||||||
|
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||||
});
|
});
|
||||||
|
|
||||||
test('sieve keeps embedded invalid tool-like json as normal text to avoid stream stalls', () => {
|
test('sieve keeps embedded invalid tool-like json as normal text to avoid stream stalls', () => {
|
||||||
@@ -141,6 +198,20 @@ test('sieve flushes incomplete captured tool json as text on stream finalize', (
|
|||||||
assert.equal(leakedText.includes('{'), true);
|
assert.equal(leakedText.includes('{'), true);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('sieve still intercepts large tool json payloads over previous capture limit', () => {
|
||||||
|
const large = 'a'.repeat(9000);
|
||||||
|
const payload = `{"tool_calls":[{"name":"read_file","input":{"path":"${large}"}}]}`;
|
||||||
|
const events = runSieve(
|
||||||
|
[payload.slice(0, 3000), payload.slice(3000, 7000), payload.slice(7000)],
|
||||||
|
['read_file'],
|
||||||
|
);
|
||||||
|
const leakedText = collectText(events);
|
||||||
|
const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && evt.calls?.length > 0);
|
||||||
|
const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && evt.deltas?.length > 0);
|
||||||
|
assert.equal(hasToolCall || hasToolDelta, true);
|
||||||
|
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||||
|
});
|
||||||
|
|
||||||
test('sieve keeps plain text intact in tool mode when no tool call appears', () => {
|
test('sieve keeps plain text intact in tool mode when no tool call appears', () => {
|
||||||
const events = runSieve(
|
const events = runSieve(
|
||||||
['你好,', '这是普通文本回复。', '请继续。'],
|
['你好,', '这是普通文本回复。', '请继续。'],
|
||||||
@@ -166,7 +237,7 @@ test('sieve intercepts rejected unknown tool payload (no args) without raw leak'
|
|||||||
assert.equal(leakedText.includes('后置正文G。'), true);
|
assert.equal(leakedText.includes('后置正文G。'), true);
|
||||||
});
|
});
|
||||||
|
|
||||||
test('sieve emits incremental tool_call_deltas for split arguments payload', () => {
|
test('sieve emits final tool_calls for split arguments payload without incremental deltas', () => {
|
||||||
const state = createToolSieveState();
|
const state = createToolSieveState();
|
||||||
const first = processToolSieveChunk(
|
const first = processToolSieveChunk(
|
||||||
state,
|
state,
|
||||||
@@ -181,37 +252,49 @@ test('sieve emits incremental tool_call_deltas for split arguments payload', ()
|
|||||||
const tail = flushToolSieve(state, ['read_file']);
|
const tail = flushToolSieve(state, ['read_file']);
|
||||||
const events = [...first, ...second, ...tail];
|
const events = [...first, ...second, ...tail];
|
||||||
const deltaEvents = events.filter((evt) => evt.type === 'tool_call_deltas');
|
const deltaEvents = events.filter((evt) => evt.type === 'tool_call_deltas');
|
||||||
assert.equal(deltaEvents.length > 0, true);
|
assert.equal(deltaEvents.length, 0);
|
||||||
const merged = deltaEvents.flatMap((evt) => evt.deltas || []);
|
const finalCalls = events.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []);
|
||||||
const hasName = merged.some((d) => d.name === 'read_file');
|
assert.equal(finalCalls.length, 1);
|
||||||
const argsJoined = merged
|
assert.equal(finalCalls[0].name, 'read_file');
|
||||||
.map((d) => d.arguments || '')
|
assert.deepEqual(finalCalls[0].input, { path: 'README.MD', mode: 'head' });
|
||||||
.join('');
|
|
||||||
assert.equal(hasName, true);
|
|
||||||
assert.equal(argsJoined.includes('"path":"README.MD"'), true);
|
|
||||||
assert.equal(argsJoined.includes('"mode":"head"'), true);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
test('sieve still intercepts tool call after leading plain text without suffix', () => {
|
test('sieve keeps tool json as text when leading prose exists (strict mode)', () => {
|
||||||
const events = runSieve(
|
const events = runSieve(
|
||||||
['我将调用工具。', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'],
|
['我将调用工具。', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'],
|
||||||
['read_file'],
|
['read_file'],
|
||||||
);
|
);
|
||||||
const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0));
|
const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0));
|
||||||
const leakedText = collectText(events);
|
const leakedText = collectText(events);
|
||||||
assert.equal(hasTool, true);
|
assert.equal(hasTool, false);
|
||||||
assert.equal(leakedText.includes('我将调用工具。'), true);
|
assert.equal(leakedText.includes('我将调用工具。'), true);
|
||||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
assert.equal(leakedText.toLowerCase().includes('tool_calls'), true);
|
||||||
});
|
});
|
||||||
|
|
||||||
test('sieve intercepts tool call and preserves trailing same-chunk text', () => {
|
test('sieve keeps same-chunk trailing prose payload as text in strict mode', () => {
|
||||||
const events = runSieve(
|
const events = runSieve(
|
||||||
['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}然后继续解释。'],
|
['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}然后继续解释。'],
|
||||||
['read_file'],
|
['read_file'],
|
||||||
);
|
);
|
||||||
const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0));
|
const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0));
|
||||||
const leakedText = collectText(events);
|
const leakedText = collectText(events);
|
||||||
assert.equal(hasTool, true);
|
assert.equal(hasTool, false);
|
||||||
assert.equal(leakedText.includes('然后继续解释。'), true);
|
assert.equal(leakedText.includes('然后继续解释。'), true);
|
||||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
assert.equal(leakedText.toLowerCase().includes('tool_calls'), true);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('formatOpenAIStreamToolCalls reuses ids with the same idStore', () => {
|
||||||
|
const idStore = new Map();
|
||||||
|
const calls = [{ name: 'read_file', input: { path: 'README.MD' } }];
|
||||||
|
const first = formatOpenAIStreamToolCalls(calls, idStore);
|
||||||
|
const second = formatOpenAIStreamToolCalls(calls, idStore);
|
||||||
|
assert.equal(first.length, 1);
|
||||||
|
assert.equal(second.length, 1);
|
||||||
|
assert.equal(first[0].id, second[0].id);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('parseToolCalls rejects mismatched markup tags', () => {
|
||||||
|
const payload = '<tool_call><name>read_file</function><arguments>{"path":"README.md"}</arguments></tool_call>';
|
||||||
|
const calls = parseToolCalls(payload, ['read_file']);
|
||||||
|
assert.equal(calls.length, 0);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -24,9 +24,8 @@
|
|||||||
<meta name="apple-mobile-web-app-status-bar-style" content="black-translucent" />
|
<meta name="apple-mobile-web-app-status-bar-style" content="black-translucent" />
|
||||||
<meta name="apple-mobile-web-app-title" content="DS2API" />
|
<meta name="apple-mobile-web-app-title" content="DS2API" />
|
||||||
|
|
||||||
<!-- Favicon - using data URI for orange-yellow gradient icon -->
|
<!-- Favicon -->
|
||||||
<link rel="icon" type="image/svg+xml"
|
<link rel="icon" type="image/svg+xml" href="/ds2api-favicon.svg" />
|
||||||
href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'%3E%3Cdefs%3E%3ClinearGradient id='g' x1='0%25' y1='0%25' x2='100%25' y2='100%25'%3E%3Cstop offset='0%25' stop-color='%23f59e0b'/%3E%3Cstop offset='100%25' stop-color='%23ef4444'/%3E%3C/linearGradient%3E%3C/defs%3E%3Crect rx='20' width='100' height='100' fill='url(%23g)'/%3E%3Ctext x='50' y='68' font-family='Arial,sans-serif' font-size='48' font-weight='bold' fill='white' text-anchor='middle'%3EDS%3C/text%3E%3C/svg%3E" />
|
|
||||||
|
|
||||||
<!-- Fonts -->
|
<!-- Fonts -->
|
||||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||||
|
|||||||
20
webui/public/ds2api-favicon.svg
Normal file
20
webui/public/ds2api-favicon.svg
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100" role="img" aria-label="DS2API icon">
|
||||||
|
<defs>
|
||||||
|
<linearGradient id="g" x1="0%" y1="0%" x2="100%" y2="100%">
|
||||||
|
<stop offset="0%" stop-color="#f59e0b" />
|
||||||
|
<stop offset="100%" stop-color="#ef4444" />
|
||||||
|
</linearGradient>
|
||||||
|
</defs>
|
||||||
|
<rect width="100" height="100" rx="20" fill="url(#g)" />
|
||||||
|
<text
|
||||||
|
x="50"
|
||||||
|
y="68"
|
||||||
|
text-anchor="middle"
|
||||||
|
font-family="Arial,sans-serif"
|
||||||
|
font-size="48"
|
||||||
|
font-weight="700"
|
||||||
|
fill="#ffffff"
|
||||||
|
>
|
||||||
|
DS
|
||||||
|
</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 539 B |
Reference in New Issue
Block a user