Compare commits

..

122 Commits

Author SHA1 Message Date
CJACK.
b3eae22cef Merge pull request #111 from CJackHwang/dev
Merge pull request #110 from CJackHwang/codex/align-js-runtime-with-go-runtime-logic

Align Vercel JS stream tool-call delta handling with Go runtime
2026-03-20 10:05:25 +08:00
CJACK.
7af0098d1b Merge pull request #110 from CJackHwang/codex/align-js-runtime-with-go-runtime-logic
Align Vercel JS stream tool-call delta handling with Go runtime
2026-03-20 09:49:08 +08:00
CJACK.
17405be300 shrink vercel stream module under line gate limit 2026-03-20 09:47:22 +08:00
CJACK.
5bc03e5de6 align vercel js stream toolcall delta behavior with go runtime 2026-03-20 09:36:45 +08:00
CJACK.
5a5f93148d Merge pull request #109 from CJackHwang/dev
Merge pull request #108 from CJackHwang/codex/clean-up-unused-files-and-update-documentation-uiip50

docs: refresh deployment/testing guides and remove stale investigation report
2026-03-20 03:12:25 +08:00
CJACK.
32dc5b6099 Merge pull request #108 from CJackHwang/codex/clean-up-unused-files-and-update-documentation-uiip50
docs: refresh deployment/testing guides and remove stale investigation report
2026-03-20 03:08:09 +08:00
CJACK.
7936d4675f Merge pull request #107 from CJackHwang/codex/clean-up-unused-files-and-update-documentation
docs: prune stale files and refresh docs, add .env.example, align READMEs/DEPLOY/CONTRIBUTING
2026-03-20 03:07:21 +08:00
CJACK.
808eafa7c6 docs: refresh deployment/testing guides and prune stale report 2026-03-20 03:05:36 +08:00
CJACK.
bcb8ed6df2 docs: prune stale docs and refresh project documentation 2026-03-20 03:05:22 +08:00
CJACK.
8ec5dcc0cc Merge pull request #106 from CJackHwang/dev
Merge pull request #105 from CJackHwang/codex/fix-issues-found-in-review

Merge pull request #104 from CJackHwang/codex/revert-to-commit-efb484b

Restore tool-call parsing and repair logic; remove accidental split files
2026-03-20 02:53:30 +08:00
CJACK.
88a79f212d Fix path control-char repair on JSON fallback parses 2026-03-20 02:52:27 +08:00
CJACK.
b1f8d6192f Merge pull request #105 from CJackHwang/codex/fix-issues-found-in-review
Merge pull request #104 from CJackHwang/codex/revert-to-commit-efb484b

Restore tool-call parsing and repair logic; remove accidental split files
2026-03-20 02:38:35 +08:00
CJACK.
acfb3b225d Split toolcall input parsing to satisfy line gate 2026-03-20 02:37:23 +08:00
CJACK.
99a6164000 Fix path corruption when parsing tool call JSON strings 2026-03-20 02:31:37 +08:00
CJACK.
e49d9d33e2 Merge pull request #104 from CJackHwang/codex/revert-to-commit-efb484b
Restore tool-call parsing and repair logic; remove accidental split files
2026-03-20 02:17:52 +08:00
CJACK.
184a3d1e4e Sync Node tool-call parsing with aggressive fenced/mixed policy 2026-03-20 02:16:37 +08:00
CJACK.
c4ec14f49a Fix refactor line gate for toolcalls_parse 2026-03-20 02:12:34 +08:00
CJACK.
fb5fc0e885 Default to aggressive tool-call interception in mixed/fenced text 2026-03-20 02:03:46 +08:00
CJACK.
20b603666d Allow standalone parser to detect mixed prose tool JSON 2026-03-20 02:03:32 +08:00
CJACK.
4d549b7102 Revert "Merge branch 'dev' into codex/fix-issues-found-in-review"
This reverts commit 33b0d1d144, reversing
changes made to efb484ba4f.
2026-03-20 01:38:11 +08:00
CJACK.
33b0d1d144 Merge branch 'dev' into codex/fix-issues-found-in-review 2026-03-20 01:23:00 +08:00
CJACK.
41c0f7ce28 Merge pull request #102 from CJackHwang/dev
Merge pull request #99 from CJackHwang/codex/refactor-toolcalls_parse.go-for-line-limits

Codex-generated pull request
2026-03-20 01:18:05 +08:00
CJACK.
efb484ba4f Merge pull request #103 from CJackHwang/codex/fix-threshold-issue-and-audit-pr
fix: unblock PR #101 line gate and improve PoW/token retry handling
2026-03-20 01:16:46 +08:00
CJACK.
145501d4a5 fix(tool-sieve): allow mixed prose + tool json interception 2026-03-20 01:15:32 +08:00
CJACK.
2d5103997b fix(tool-sieve): keep mixed prose tool json in strict text mode 2026-03-20 01:15:15 +08:00
CJACK.
52e7e7aae8 fix: unblock line gate and harden pow token recovery 2026-03-20 00:50:05 +08:00
CJACK.
5b5a4000d7 Merge pull request #99 from CJackHwang/codex/refactor-toolcalls_parse.go-for-line-limits
Codex-generated pull request
2026-03-19 21:06:45 +08:00
CJACK.
2bbf603148 fix: address PR #97 review findings 2026-03-18 00:52:24 +08:00
CJACK.
d14b8a0664 Stabilize tool-call parsing and pass refactor gate 2026-03-18 00:45:28 +08:00
CJACK.
f16e0b579e Merge pull request #92 from valkryhx/main
fix(toolcall): fix deepseek function calling bug and add json repair
2026-03-18 00:15:47 +08:00
CJACK.
43cbc4aac0 Merge pull request #97 from CJackHwang/dev
Merge pull request #96 from CJackHwang/codex/update-ci-line-count-limits-cihke3

ci: ignore test files in line gate and raise frontend limit to 500
2026-03-18 00:15:03 +08:00
huangxun
cf569f4749 docs: add testing documentation for tool call debugging
- Add targeted test commands to TESTING.md for debugging tool call issues
- Add quick test commands reference in README.md
- Document specific test cases for DeepSeek tool call parsing
2026-03-17 16:41:16 +08:00
huangxun
c9c59f2490 refactor(toolcall): enhance tool call extraction with multiple keywords and safety limits
- Add support for multiple keywords: tool_calls, function.name:, [tool_call_history]
- Add OOM protection with search limits in extractToolCallObjects
- Add max scan length limit in extractJSONObject to prevent OOM on unclosed objects
- Update tool_sieve to handle more tool call patterns
- Add loose JSON repair in parseToolCallPayload for better error recovery

This improves DeepSeek tool call parsing robustness.
2026-03-17 16:28:27 +08:00
huangxun
16216cc2ca fix(toolcalls): support nested objects in missing array brackets repair
- Upgrade missingArrayBracketsPattern regex to support single-level nested {} objects
- This fixes DeepSeek's list hallucination where tool call JSON objects contain nested fields like {"input": {"q": "value"}}
- Add comprehensive test cases covering 2-5 nested objects, mixed nested/primitive fields, and real DeepSeek 8-queen output patterns
- Add RepairLooseJSON function to repair unquoted keys and missing array brackets

Fixes: DeepSeek tool call parsing with nested JSON objects
2026-03-17 16:24:16 +08:00
CJACK.
de50fd3954 Merge pull request #96 from CJackHwang/codex/update-ci-line-count-limits-cihke3
ci: ignore test files in line gate and raise frontend limit to 500
2026-03-16 23:16:22 +08:00
CJACK.
7648d5f192 ci: keep entry line cap precedence over frontend cap 2026-03-16 23:06:58 +08:00
CJACK.
d35e5eab25 ci: ignore tests in line gate and raise frontend limit 2026-03-16 22:58:13 +08:00
CJACK.
90610a52ce Merge pull request #93 from latticeon/feature/session-management
feat: 添加会话管理功能
2026-03-16 22:12:00 +08:00
latticeon
f6296d506f fix: 修改批量删除会话方式
- 从逐条单个删除改为官方的批量删除接口
- 单个删除函数保留备用
2026-03-16 16:23:39 +08:00
latticeon
dfea092583 fix: 更新测试 mock 结构体以实现新增的接口方法
会话管理功能新增接口方法后,同步更新测试 mock 结构体:
- mockOpenAIConfig: 添加 AutoDeleteSessions() 方法
- streamStatusDSStub: 添加 DeleteAllSessionsForToken() 方法
- testingDSMock: 添加 DeleteAllSessionsForToken() 和 GetSessionCountForToken() 方法

同时修复 client_session_delete.go 中 fmt.Errorf 使用非常量格式字符串的编译错误,改用 errors.New()
2026-03-16 11:58:07 +08:00
latticeon
af7dc134bb fix: 修复会话管理相关问题并拆分文件
1. 修复无限循环问题
   - DeleteAllSessions/DeleteAllSessionsForToken 添加无进度检测
   - 连续 3 轮删除失败则退出循环
   - DeleteAllSessionsForToken 添加 cursor 推进逻辑

2. 修复字段语义不准确
   - TotalCount 重命名为 FirstPageCount
   - 明确该值仅统计第一页,多页账户需关注 HasMore

3. 修复 defer 执行顺序问题
   - 合并两个 defer,确保先删除会话再释放账号
   - 使用同步删除避免并发截断风险

4. 文件拆分
   - 新建 client_session_delete.go 处理会话删除
   - client_session.go 专注于会话查询
2026-03-16 01:44:21 +08:00
latticeon
2657d37f76 添加会话数量显示与清除功能
添加会话清除功能,增强安全性,避免账号被盗等情况泄露源代码
账号列表点击测试后显示账号的会话数量
设置页添加自动清除开关,每次调用后清除被调用账号的所有会话
2026-03-16 00:50:31 +08:00
huangxun
7318d1f4a8 fix(toolcall): fix deepseek function calling bug and add json repair
- Fix: Expand stream sieve keywords to support function.name: and [TOOL_CALL_HISTORY]

- Fix: Add repairInvalidJSONBackslashes to handle unescaped backslashes in Windows paths

- Sync: Update JS stream sieve to match Go implementation

- Test: Add unit tests for backslash repair and deepseek format parsing

- Tool: Move repair json test tool to tests/repair_json_tool.go
2026-03-13 13:47:40 +08:00
CJACK.
f2674487c7 Merge pull request #90 from CJackHwang/dev
Merge pull request #89 from CJackHwang/codex/review-changes-in-pull-request-#88

Support text-kv `function.name`/`function.arguments` fallback and looser name matching
2026-03-09 21:42:28 +08:00
CJACK.
71cdcb43e8 Merge pull request #89 from CJackHwang/codex/review-changes-in-pull-request-#88
Support text-kv `function.name`/`function.arguments` fallback and looser name matching
2026-03-09 19:21:24 +08:00
CJACK.
9c46c3a874 Merge branch 'dev' into codex/review-changes-in-pull-request-#88 2026-03-09 19:20:32 +08:00
CJACK.
12d5f136d5 fix(toolcall): pass gates and align go/js multi-layer parser 2026-03-09 19:16:28 +08:00
CJACK.
00c37d8d2f Merge pull request #88 from valkryhx/main
update openai function calling 成功率高 是因为chat内容和tool内容分开保存,而ds则混合了
2026-03-09 19:04:41 +08:00
huangxun
0f1985af4a feat(util): 增加对混杂文本中 Tool Call 的 fallback 解析支持
- 引入 parseTextKVToolCalls 解析器以处理混杂文本或带历史记录套壳(如 [TOOL_CALL_HISTORY])输出的函数调用提取。
- 将其作为 JSON 和 XML 的 fallback 解析手段集成到主流程。
- 添加单元测试用例且更新相关语义说明文档。
2026-03-09 15:00:16 +08:00
huangxun
fa8affe1b7 Merge remote-tracking branch 'upstream/main' 2026-03-09 14:29:09 +08:00
CJACK.
c59a0b7799 Merge pull request #87 from CJackHwang/dev
Merge pull request #82 from CJackHwang/codex/linear-mention-cja-10-ds2api-go-runtime-js

Align Go/JS tool-call parsing semantics and expand compat fixtures
2026-03-08 13:21:22 +08:00
CJACK.
bd72b91f27 Merge pull request #82 from CJackHwang/codex/linear-mention-cja-10-ds2api-go-runtime-js
Align Go/JS tool-call parsing semantics and expand compat fixtures
2026-03-08 13:19:09 +08:00
CJACK.
9240f85246 Merge pull request #86 from CJackHwang/codex/fix
fix: parse invoke/tool_call arguments in xml compatibility paths
2026-03-08 13:17:29 +08:00
CJACK.
ea4bd1e483 fix: parse invoke/tool_call arguments in xml compatibility paths 2026-03-08 13:16:12 +08:00
CJACK.
9e0de62707 Merge branch 'dev' into codex/linear-mention-cja-10-ds2api-go-runtime-js 2026-03-08 02:40:35 +08:00
CJACK.
128de290db Merge pull request #85 from CJackHwang/revert-84-codex/fix-code-conflicts-in-pr-#82
Revert "Resolve PR #82 merge conflicts and restore tool-call parsing (invoke/argument and XML arguments)"
2026-03-08 02:38:57 +08:00
CJACK.
286d266723 Revert "Resolve PR #82 merge conflicts and restore tool-call parsing (invoke/argument and XML arguments)" 2026-03-08 02:38:29 +08:00
CJACK.
8aad1005b2 Merge pull request #84 from CJackHwang/codex/fix-code-conflicts-in-pr-#82
Resolve PR #82 merge conflicts and restore tool-call parsing (invoke/argument and XML arguments)
2026-03-08 02:31:21 +08:00
CJACK.
11b2f24fc2 Merge origin/dev into PR branch and resolve toolcall parser conflicts 2026-03-08 02:30:12 +08:00
CJACK.
d1f08cbb89 Merge pull request #83 from CJackHwang/dev
Merge pull request #81 from CJackHwang/codex/linear-mention-cja-8

Drop nameless assistant tool_calls and emit parsed tool_calls atomically in sieve
2026-03-08 01:36:38 +08:00
CJACK.
60e9d707d4 Merge origin/dev into PR branch and resolve toolcall test conflicts 2026-03-08 01:10:53 +08:00
CJACK.
9b93badb57 Harden markup tag parsing to avoid mismatched-tag false positives 2026-03-08 00:55:32 +08:00
CJACK.
892213071a Align Go/JS tool-call parsing semantics and compat fixtures 2026-03-08 00:12:43 +08:00
CJACK.
5484d6e59d Merge pull request #81 from CJackHwang/codex/linear-mention-cja-8
Drop nameless assistant tool_calls and emit parsed tool_calls atomically in sieve
2026-03-07 23:15:54 +08:00
CJACK.
0ce3fd22a7 Address PR review: fenced-stream guard and multi ANTML calls 2026-03-07 17:45:43 +08:00
CJACK.
25e40cc3a6 Fix quality gate and expand Claude tool-call format compatibility 2026-03-07 17:27:29 +08:00
CJACK.
af68d21095 Improve Claude Code tool-call compatibility across mixed formats 2026-03-07 16:53:05 +08:00
CJACK.
1fafd25e86 add output_text.done event and remove transient stability report 2026-03-07 16:00:53 +08:00
CJACK.
5f8f28a943 add codex and claude-cli ds2api stability test report 2026-03-07 16:00:36 +08:00
CJACK.
94cf1bfcc7 drop nameless assistant tool history entries 2026-03-07 14:45:10 +08:00
CJACK.
13562cf521 Merge pull request #80 from CJackHwang/dev
Merge pull request #79 from CJackHwang/codex/analyze-and-optimize-issue-#77

fix: 避免 assistant.content=nil 注入 "null" 导致工具历史混杂
2026-03-07 02:13:46 +08:00
valkryhx
d27e700c4f update openai function calling 成功率高 是因为chat内容和tool内容分开保存,而ds则混合了 2026-03-06 23:22:11 +08:00
valkryhx
d6bce5af93 Merge branch 'dev' 2026-03-06 22:49:56 +08:00
CJACK.
75969e710d Merge pull request #79 from CJackHwang/codex/analyze-and-optimize-issue-#77
fix: 避免 assistant.content=nil 注入 "null" 导致工具历史混杂
2026-03-06 22:20:47 +08:00
CJACK.
6c39c8e191 fix: 修复 text 为空时 content 回退丢失问题 2026-03-06 21:24:26 +08:00
CJACK.
0e261ff0a0 refactor: 统一内容归一化逻辑并补充 nil 回归测试 2026-03-06 18:25:27 +08:00
CJACK.
fab326eca1 fix: 修复工具历史注入 null 导致调用格式混乱 2026-03-05 18:20:42 +08:00
CJACK.
c033eceee7 Merge pull request #75 from CJackHwang/dev
Merge pull request #74 from CJackHwang/codex/fix-toolcall-whitelist-issue

Recognize and emit executable tool_calls in mixed prose streams; normalize roles and loosen tool-name matching
2026-03-03 01:30:44 +08:00
CJACK.
a10e03ebe0 Merge pull request #74 from CJackHwang/codex/fix-toolcall-whitelist-issue
Recognize and emit executable tool_calls in mixed prose streams; normalize roles and loosen tool-name matching
2026-03-03 00:40:41 +08:00
CJACK.
a6aa4a1839 补充工具调用行为说明并修正测试文档过时命令 2026-03-03 00:39:02 +08:00
CJACK.
1c749b6803 Merge pull request #73 from CJackHwang/dev
Merge pull request #72 from CJackHwang/codex/review-changes-to-test-account-logic

Normalize mobile login numbers, skip completion flow for session-only account tests, and add tests
2026-03-03 00:07:57 +08:00
CJACK.
c329bf26b6 Merge pull request #72 from CJackHwang/codex/review-changes-to-test-account-logic
Normalize mobile login numbers, skip completion flow for session-only account tests, and add tests
2026-03-02 23:56:27 +08:00
CJACK.
3ae5b57ebe fix(deepseek): normalize mobile before login token refresh 2026-03-02 23:48:54 +08:00
CJACK.
0bf5d5440c Merge pull request #69 from CJackHwang/dev
js对齐
2026-03-01 07:22:42 +08:00
CJACK
d731a1fd4f 门禁 2026-03-01 07:20:24 +08:00
CJACK
93e9fb531d js对齐 2026-03-01 07:15:35 +08:00
CJACK.
6daeb2553d Merge pull request #68 from CJackHwang/dev
修复严重问题
2026-03-01 06:53:23 +08:00
CJACK
321b8a89ee 优化 2026-03-01 06:42:07 +08:00
CJACK
d84875e466 工具调用优化 2026-03-01 06:33:49 +08:00
CJACK
ea8c9a28a9 更新readme和icon 2026-03-01 06:22:41 +08:00
CJACK
a302fb3c25 修复 2026-03-01 05:55:46 +08:00
CJACK.
958bd124cc Merge pull request #64 from CJackHwang/dev
修复已知问题
2026-02-28 18:58:46 +08:00
CJACK.
b89e154e43 Merge pull request #63 from CJackHwang/codex/fix-issues-in-image-analysis
Use repository root Dockerfile, make Go cross-build robust, and fix process wait logic
2026-02-28 18:51:57 +08:00
CJACK.
01924f4a69 fix(docker): auto-detect target arch for local ARM builds 2026-02-28 18:39:33 +08:00
CJACK.
3725694bdf Merge pull request #61 from ronghuaxueleng/main
feat(webui): 账号列表添加搜索过滤功能
2026-02-28 18:16:41 +08:00
root
21b12f583a fix(admin): 账号测试始终发送默认消息以验证完整链路
测试接口不再仅验证会话创建,改为始终发送「你是谁?」
走完整 completion 路径,确保被封禁账号能被正确识别为失败。
2026-02-28 10:18:26 +08:00
root
d97b86e0ee feat(webui): 账号列表添加搜索过滤功能
- 后端 GET /admin/accounts 支持 ?q= 参数,大小写不敏感匹配 identifier/email/mobile
- 前端搜索框内嵌于标题栏按钮行(测试全部按钮前)
- 搜索时重置到第 1 页,分页 total 反映过滤后数量
- 无匹配结果时显示专属提示文案(中英文)
2026-02-28 09:57:19 +08:00
qiangcao
0869ea56cd Merge branch 'CJackHwang:main' into main 2026-02-28 09:18:20 +08:00
CJACK.
4768440627 Merge pull request #60 from CJackHwang/main
同步
2026-02-27 23:18:44 +08:00
CJACK.
9f91da403f Merge pull request #59 from ronghuaxueleng/feature/account-improvements
feat: 账号测试状态持久化、分页选择器、点击账号名复制
2026-02-27 23:16:05 +08:00
CJACK.
89e5ad24b9 Merge pull request #57 from jacob-sheng/feat/zeabur-oneclick
feat(zeabur): 一键部署模板
2026-02-27 23:12:13 +08:00
CJACK.
3f106ac112 Merge pull request #55 from BigUncle/fix/claude-toolcall
fix(claude): 修复工具调用兼容与解析回退
2026-02-27 23:11:46 +08:00
root
f6f6a651fd feat: 账号测试状态持久化、分页选择器、点击账号名复制
- Account 结构加 TestStatus 字段,测试后写入 config.json
- listAccounts 接口返回 test_status,前端根据结果显示红/绿/黄状态点
- 分页选择器支持 10/20/50/100/500/1000/2000/5000
- 点击账号名自动复制到剪贴板,hover 显示复制图标,复制后显示绿色对勾
2026-02-27 21:30:43 +08:00
root
37b867c7ad Merge branch 'docker' 2026-02-27 20:59:16 +08:00
root
25ea28a277 feat: 账号测试状态持久化、分页选择器、点击账号名复制
- Account 结构加 TestStatus 字段,测试后写入 config.json
- listAccounts 接口返回 test_status,前端根据结果显示红/绿/黄状态点
- 分页选择器支持 10/20/50/100/500/1000/2000/5000
- 点击账号名自动复制到剪贴板,hover 显示复制图标,复制后显示绿色对勾
2026-02-27 20:58:18 +08:00
root
0ac49ab32b merge: 合并 main 分支到 docker,保留 docker-compose.yml 和 start.mjs 2026-02-27 20:21:20 +08:00
root
70c59eb71d chore: 将 .claude/ 和 CLAUDE.local.md 从 git 跟踪中排除 2026-02-27 20:19:00 +08:00
AYANGarch
f60a3ea501 docs(readme): add ds2api whale icon 2026-02-26 23:18:57 +08:00
AYANGarch
3f09d60cdc feat(zeabur): add one-click deploy template 2026-02-26 22:54:50 +08:00
BigUncle
d3b5493d2e fix(claude): guard thinking tool-call fallback when final text exists
- only parse tool_calls from thinking when finalText is empty

- apply the same guard in stream runtime finalizer

- add regression tests for non-stream and stream paths
2026-02-26 00:41:39 +08:00
BigUncle
255feb2e65 fix(claude): 修复工具调用兼容与解析回退
- Claude 工具定义兼容 input_schema 与 function.parameters

- tool_calls 解析增加 thinking 回退与大小写无关工具名匹配

- 补充 claude/util 相关回归测试
2026-02-25 18:03:25 +08:00
root
962700f525 chore: 删除无用文件,清理 .gitignore Python 残留规则 2026-02-18 21:06:02 +08:00
root
e143d13ff6 feat: 编译和安装依赖使用国内镜像 2026-02-18 20:57:23 +08:00
root
2f853d7364 feat: 重写 start.mjs 适配 Go 运行时 2026-02-18 20:53:10 +08:00
root
36099a4ada chore: 删除 Python 残留文件(项目已迁移至 Go) 2026-02-18 20:50:07 +08:00
root
73bdb55cee merge: 合并 main 分支到 docker,保留 docker-compose.yml 和分页接口 2026-02-18 20:38:53 +08:00
root
3f3198c959 feat: 账号管理界面优化
- 账号列表支持分页(每页10条,倒序显示)
- API 密钥列表支持展开/关闭
2026-02-07 13:40:14 +08:00
root
6b8f7f8821 feat: 启动脚本显示所有环境变量 2026-02-07 10:55:34 +08:00
root
ac9a1ae742 merge: 合并 main 分支到 docker 2026-02-07 10:28:18 +08:00
root
bd4c2bacbc merge: 合并 main 分支到 docker 2026-02-02 20:31:42 +08:00
root
6cfc7051c4 Merge remote-tracking branch 'origin/main' into docker 2026-02-02 20:29:11 +08:00
root
22a2a97a76 feat: 添加 Docker 和 GitHub Actions 支持
- 添加 docker/Dockerfile 多阶段构建(前端+后端)
- 添加 docker-compose.yml 支持阿里云镜像部署
- 添加 .github/workflows/release.yml 自动发布到阿里云
- 添加 .dockerignore 优化构建
- 添加 VERSION 版本管理文件
- 添加 start.mjs 本地开发启动脚本
2026-02-02 20:23:33 +08:00
146 changed files with 6013 additions and 1322 deletions

View File

@@ -1,93 +1,15 @@
# DS2API environment template (Go runtime) # DS2API runtime
# Copy this file to .env and adjust values.
# Updated: 2026-02
# ---------------------------------------------------------------
# Runtime
# ---------------------------------------------------------------
# HTTP listen port (default: 5001)
PORT=5001 PORT=5001
# Log level: DEBUG | INFO | WARN | ERROR
LOG_LEVEL=INFO LOG_LEVEL=INFO
# Max concurrent inflight requests per account in managed-key mode. # Admin authentication
# Default: 2 DS2API_ADMIN_KEY=change-me
# Recommended client concurrency is calculated dynamically as:
# account_count * DS2API_ACCOUNT_MAX_INFLIGHT
# So by default it is account_count * 2.
# Requests beyond inflight slots enter a waiting queue first.
# Default queue size equals recommended concurrency, so 429 starts after:
# account_count * DS2API_ACCOUNT_MAX_INFLIGHT * 2
# Alias: DS2API_ACCOUNT_CONCURRENCY
# DS2API_ACCOUNT_MAX_INFLIGHT=2
# Optional waiting queue size override for managed-key mode. # Config loading (choose one)
# Default: recommended_concurrency (same as account_count * inflight_limit) # 1) file-based config
# Alias: DS2API_ACCOUNT_QUEUE_SIZE DS2API_CONFIG_PATH=/app/config.json
# DS2API_ACCOUNT_MAX_QUEUE=10 # 2) inline JSON or Base64 JSON
# DS2API_CONFIG_JSON=
# --------------------------------------------------------------- # Optional: static admin assets path
# Admin auth # DS2API_STATIC_ADMIN_DIR=/app/static/admin
# ---------------------------------------------------------------
# Admin key for /admin login and protected admin APIs.
# Default is "admin" when unset, but setting it explicitly is recommended.
DS2API_ADMIN_KEY=admin
# Optional JWT signing secret for admin token.
# Defaults to DS2API_ADMIN_KEY when unset.
# DS2API_JWT_SECRET=change-me
# Optional admin JWT validity in hours (default: 24)
# DS2API_JWT_EXPIRE_HOURS=24
# ---------------------------------------------------------------
# Config source (choose one)
# ---------------------------------------------------------------
# Option A: config file path (local/dev recommended)
# DS2API_CONFIG_PATH=config.json
# Option B: JSON string
# DS2API_CONFIG_JSON={"keys":["your-api-key"],"accounts":[{"email":"user@example.com","password":"xxx","token":""}]}
# Option C: Base64 encoded JSON (recommended for Vercel env var)
# DS2API_CONFIG_JSON=eyJrZXlzIjpbInlvdXItYXBpLWtleSJdLCJhY2NvdW50cyI6W3siZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwicGFzc3dvcmQiOiJ4eHgiLCJ0b2tlbiI6IiJ9XX0=
#
# Generate from local config.json:
# DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
# ---------------------------------------------------------------
# Paths (optional)
# ---------------------------------------------------------------
# WASM file used for PoW solving
# DS2API_WASM_PATH=sha3_wasm_bg.7b9ca65ddd.wasm
# Built admin static assets directory
# DS2API_STATIC_ADMIN_DIR=static/admin
# Auto-build WebUI on startup when static/admin is missing.
# Default: enabled on local/Docker, disabled on Vercel.
# DS2API_AUTO_BUILD_WEBUI=true
# Internal auth secret used by the Vercel hybrid streaming path
# (Go prepare endpoint <-> Node stream function).
# Optional: falls back to DS2API_ADMIN_KEY when unset.
# DS2API_VERCEL_INTERNAL_SECRET=change-me
# Stream lease TTL seconds for Vercel hybrid streaming.
# During this window, the managed account stays occupied until Node calls release.
# Default: 900 (15 minutes)
# DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS=900
# ---------------------------------------------------------------
# Vercel sync integration (optional)
# ---------------------------------------------------------------
# VERCEL_TOKEN=your-vercel-token
# VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx
# VERCEL_TEAM_ID=team_xxxxxxxxxxxx
# Optional: Vercel deployment protection bypass secret.
# If deployment protection is enabled, DS2API will use this value as
# x-vercel-protection-bypass for internal Node->Go calls on Vercel.
# You can also use VERCEL_AUTOMATION_BYPASS_SECRET directly.
# DS2API_VERCEL_PROTECTION_BYPASS=your-bypass-secret

View File

@@ -1,20 +1,20 @@
#### 💻 变更类型 | Change Type #### 💻 变更类型 | Change Type
<!-- For change type, change [ ] to [x]. --> <!-- For change type, change [ ] to [x]. -->
- [ ] ✨ feat - [ ] ✨ feat
- [ ] 🐛 fix - [ ] 🐛 fix
- [ ] ♻️ refactor - [ ] ♻️ refactor
- [ ] 💄 style - [ ] 💄 style
- [ ] 👷 build - [ ] 👷 build
- [ ] ⚡️ perf - [ ] ⚡️ perf
- [ ] 📝 docs - [ ] 📝 docs
- [ ] 🔨 chore - [ ] 🔨 chore
#### 🔀 变更说明 | Description of Change #### 🔀 变更说明 | Description of Change
<!-- Thank you for your Pull Request. Please provide a description above. -->
#### 📝 补充信息 | Additional Information #### 📝 补充信息 | Additional Information
<!-- Add any other context about the Pull Request here. -->

127
.github/workflows/release-dockerhub.yml vendored Normal file
View File

@@ -0,0 +1,127 @@
name: Release to Docker Hub
on:
workflow_dispatch:
inputs:
version_type:
description: '版本类型'
required: true
default: 'patch'
type: choice
options:
- patch
- minor
- major
permissions:
contents: write
jobs:
release:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v5
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Get current version
id: get_version
run: |
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
TAG_VERSION=${LATEST_TAG#v}
if [ -f VERSION ]; then
FILE_VERSION=$(cat VERSION | tr -d '[:space:]')
else
FILE_VERSION="0.0.0"
fi
function version_gt() { test "$(printf '%s\n' "$@" | sort -V | head -n 1)" != "$1"; }
if version_gt "$FILE_VERSION" "$TAG_VERSION"; then
VERSION="$FILE_VERSION"
else
VERSION="$TAG_VERSION"
fi
echo "Current version: $VERSION"
echo "current_version=$VERSION" >> $GITHUB_OUTPUT
- name: Calculate next version
id: next_version
env:
VERSION_TYPE: ${{ github.event.inputs.version_type }}
run: |
VERSION="${{ steps.get_version.outputs.current_version }}"
BASE_VERSION=$(echo "$VERSION" | sed 's/-.*$//')
IFS='.' read -r -a version_parts <<< "$BASE_VERSION"
MAJOR="${version_parts[0]:-0}"
MINOR="${version_parts[1]:-0}"
PATCH="${version_parts[2]:-0}"
case "$VERSION_TYPE" in
major)
NEW_VERSION="$((MAJOR + 1)).0.0"
;;
minor)
NEW_VERSION="${MAJOR}.$((MINOR + 1)).0"
;;
*)
NEW_VERSION="${MAJOR}.${MINOR}.$((PATCH + 1))"
;;
esac
echo "New version: $NEW_VERSION"
echo "new_version=$NEW_VERSION" >> $GITHUB_OUTPUT
echo "new_tag=v$NEW_VERSION" >> $GITHUB_OUTPUT
- name: Update VERSION file
run: |
echo "${{ steps.next_version.outputs.new_version }}" > VERSION
- name: Commit VERSION and create tag
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git add VERSION
if ! git diff --cached --quiet; then
git commit -m "chore: bump version to ${{ steps.next_version.outputs.new_tag }} [skip ci]"
fi
NEW_TAG="${{ steps.next_version.outputs.new_tag }}"
git tag -a "$NEW_TAG" -m "Release $NEW_TAG"
git push origin HEAD:main "$NEW_TAG"
# Docker 构建并推送到 Docker Hub
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build and push Docker image
uses: docker/build-push-action@v6
with:
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ secrets.DOCKERHUB_USERNAME }}/ds2api:${{ steps.next_version.outputs.new_tag }}
${{ secrets.DOCKERHUB_USERNAME }}/ds2api:${{ steps.next_version.outputs.new_version }}
${{ secrets.DOCKERHUB_USERNAME }}/ds2api:latest
labels: |
org.opencontainers.image.version=${{ steps.next_version.outputs.new_version }}
org.opencontainers.image.revision=${{ github.sha }}
cache-from: type=gha
cache-to: type=gha,mode=max

128
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,128 @@
name: Release to Aliyun CR
on:
workflow_dispatch:
inputs:
version_type:
description: '版本类型'
required: true
default: 'patch'
type: choice
options:
- patch
- minor
- major
permissions:
contents: write
jobs:
release:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v5
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Get current version
id: get_version
run: |
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
TAG_VERSION=${LATEST_TAG#v}
if [ -f VERSION ]; then
FILE_VERSION=$(cat VERSION | tr -d '[:space:]')
else
FILE_VERSION="0.0.0"
fi
function version_gt() { test "$(printf '%s\n' "$@" | sort -V | head -n 1)" != "$1"; }
if version_gt "$FILE_VERSION" "$TAG_VERSION"; then
VERSION="$FILE_VERSION"
else
VERSION="$TAG_VERSION"
fi
echo "Current version: $VERSION"
echo "current_version=$VERSION" >> $GITHUB_OUTPUT
- name: Calculate next version
id: next_version
env:
VERSION_TYPE: ${{ github.event.inputs.version_type }}
run: |
VERSION="${{ steps.get_version.outputs.current_version }}"
BASE_VERSION=$(echo "$VERSION" | sed 's/-.*$//')
IFS='.' read -r -a version_parts <<< "$BASE_VERSION"
MAJOR="${version_parts[0]:-0}"
MINOR="${version_parts[1]:-0}"
PATCH="${version_parts[2]:-0}"
case "$VERSION_TYPE" in
major)
NEW_VERSION="$((MAJOR + 1)).0.0"
;;
minor)
NEW_VERSION="${MAJOR}.$((MINOR + 1)).0"
;;
*)
NEW_VERSION="${MAJOR}.${MINOR}.$((PATCH + 1))"
;;
esac
echo "New version: $NEW_VERSION"
echo "new_version=$NEW_VERSION" >> $GITHUB_OUTPUT
echo "new_tag=v$NEW_VERSION" >> $GITHUB_OUTPUT
- name: Update VERSION file
run: |
echo "${{ steps.next_version.outputs.new_version }}" > VERSION
- name: Commit VERSION and create tag
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git add VERSION
if ! git diff --cached --quiet; then
git commit -m "chore: bump version to ${{ steps.next_version.outputs.new_tag }} [skip ci]"
fi
NEW_TAG="${{ steps.next_version.outputs.new_tag }}"
git tag -a "$NEW_TAG" -m "Release $NEW_TAG"
git push origin HEAD:main "$NEW_TAG"
# Docker 构建并推送到阿里云
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Aliyun Container Registry
uses: docker/login-action@v3
with:
registry: ${{ secrets.ALIYUN_REGISTRY }}
username: ${{ secrets.ALIYUN_REGISTRY_USER }}
password: ${{ secrets.ALIYUN_REGISTRY_PASSWORD }}
- name: Build and push Docker image
uses: docker/build-push-action@v6
with:
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:${{ steps.next_version.outputs.new_tag }}
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:${{ steps.next_version.outputs.new_version }}
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:latest
labels: |
org.opencontainers.image.version=${{ steps.next_version.outputs.new_version }}
org.opencontainers.image.revision=${{ github.sha }}
cache-from: type=gha
cache-to: type=gha,mode=max

40
.gitignore vendored
View File

@@ -2,37 +2,6 @@
config.json config.json
.env .env
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# Virtual environments
venv/
ENV/
env/
.venv
# IDE # IDE
.vscode/ .vscode/
.idea/ .idea/
@@ -44,7 +13,6 @@ env/
# Logs # Logs
*.log *.log
logs/ logs/
uvicorn.log
artifacts/ artifacts/
# Vercel # Vercel
@@ -56,8 +24,6 @@ webui/node_modules/
webui/dist/ webui/dist/
.npm .npm
.pnpm-store/ .pnpm-store/
# 保留 webui/package-lock.json 用于 CI 缓存
# package-lock.json # 如果有根目录的可以忽略
yarn.lock yarn.lock
pnpm-lock.yaml pnpm-lock.yaml
@@ -86,7 +52,9 @@ coverage*.out
cover/ cover/
# Misc # Misc
*.pyc
*.pyo
.git/ .git/
Thumbs.db Thumbs.db
# Claude Code
.claude/
CLAUDE.local.md

7
API.md
View File

@@ -284,6 +284,11 @@ data: [DONE]
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content` **流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`
补充说明:
- **非代码块上下文**下,工具 JSON 即使与普通文本混合,也会按特征识别并产出可执行 tool call前后普通文本仍可透传
- Markdown fenced code block例如 ```json ... ```)中的 `tool_calls` 仅视为示例文本,不会被执行。
--- ---
### `GET /v1/models/{id}` ### `GET /v1/models/{id}`
@@ -301,7 +306,7 @@ OpenAI Responses 风格接口,兼容 `input` 或 `messages`。
| `messages` | array | ❌ | 与 `input` 二选一 | | `messages` | array | ❌ | 与 `input` 二选一 |
| `instructions` | string | ❌ | 自动前置为 system 消息 | | `instructions` | string | ❌ | 自动前置为 system 消息 |
| `stream` | boolean | ❌ | 默认 `false` | | `stream` | boolean | ❌ | 默认 `false` |
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 | | `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略(含代码块示例豁免) |
| `tool_choice` | string/object | ❌ | 支持 `auto`/`none`/`required` 与强制函数(`{"type":"function","name":"..."}` | | `tool_choice` | string/object | ❌ | 支持 `auto`/`none`/`required` 与强制函数(`{"type":"function","name":"..."}` |
**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。 **非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。

View File

@@ -99,7 +99,7 @@ ds2api/
├── api/ ├── api/
│ ├── index.go # Vercel Serverless Go entry │ ├── index.go # Vercel Serverless Go entry
│ ├── chat-stream.js # Vercel Node.js stream relay │ ├── chat-stream.js # Vercel Node.js stream relay
│ └── helpers/ # Node.js helper modules │ └── (rewrite targets in vercel.json)
├── internal/ ├── internal/
│ ├── account/ # Account pool and concurrency queue │ ├── account/ # Account pool and concurrency queue
│ ├── adapter/ │ ├── adapter/
@@ -112,6 +112,7 @@ ds2api/
│ ├── compat/ # Compatibility helpers │ ├── compat/ # Compatibility helpers
│ ├── config/ # Config loading and hot-reload │ ├── config/ # Config loading and hot-reload
│ ├── deepseek/ # DeepSeek client, PoW WASM │ ├── deepseek/ # DeepSeek client, PoW WASM
│ ├── js/ # Node runtime stream/compat logic
│ ├── devcapture/ # Dev packet capture │ ├── devcapture/ # Dev packet capture
│ ├── format/ # Output formatting │ ├── format/ # Output formatting
│ ├── prompt/ # Prompt building │ ├── prompt/ # Prompt building
@@ -123,7 +124,9 @@ ds2api/
│ └── webui/ # WebUI static hosting │ └── webui/ # WebUI static hosting
├── webui/ # React WebUI source ├── webui/ # React WebUI source
│ └── src/ │ └── src/
│ ├── components/ # Components │ ├── app/ # Routing, auth, config state
│ ├── features/ # Feature modules
│ ├── components/ # Shared components
│ └── locales/ # Language packs │ └── locales/ # Language packs
├── scripts/ # Build and test scripts ├── scripts/ # Build and test scripts
├── static/admin/ # WebUI build output (not committed) ├── static/admin/ # WebUI build output (not committed)

View File

@@ -99,7 +99,7 @@ ds2api/
├── api/ ├── api/
│ ├── index.go # Vercel Serverless Go 入口 │ ├── index.go # Vercel Serverless Go 入口
│ ├── chat-stream.js # Vercel Node.js 流式转发 │ ├── chat-stream.js # Vercel Node.js 流式转发
│ └── helpers/ # Node.js 辅助模块 │ └── (rewrite targets in vercel.json)
├── internal/ ├── internal/
│ ├── account/ # 账号池与并发队列 │ ├── account/ # 账号池与并发队列
│ ├── adapter/ │ ├── adapter/
@@ -112,6 +112,7 @@ ds2api/
│ ├── compat/ # 兼容性辅助 │ ├── compat/ # 兼容性辅助
│ ├── config/ # 配置加载与热更新 │ ├── config/ # 配置加载与热更新
│ ├── deepseek/ # DeepSeek 客户端、PoW WASM │ ├── deepseek/ # DeepSeek 客户端、PoW WASM
│ ├── js/ # Node 运行时流式/兼容逻辑
│ ├── devcapture/ # 开发抓包 │ ├── devcapture/ # 开发抓包
│ ├── format/ # 输出格式化 │ ├── format/ # 输出格式化
│ ├── prompt/ # Prompt 构建 │ ├── prompt/ # Prompt 构建
@@ -123,7 +124,9 @@ ds2api/
│ └── webui/ # WebUI 静态托管 │ └── webui/ # WebUI 静态托管
├── webui/ # React WebUI 源码 ├── webui/ # React WebUI 源码
│ └── src/ │ └── src/
│ ├── components/ # 组件 │ ├── app/ # 路由、鉴权、配置状态
│ ├── features/ # 业务功能模块
│ ├── components/ # 通用组件
│ └── locales/ # 语言包 │ └── locales/ # 语言包
├── scripts/ # 构建与测试脚本 ├── scripts/ # 构建与测试脚本
├── static/admin/ # WebUI 构建产物(不提交) ├── static/admin/ # WebUI 构建产物(不提交)

View File

@@ -113,12 +113,8 @@ go build -o ds2api ./cmd/ds2api
# Copy env template # Copy env template
cp .env.example .env cp .env.example .env
# Generate single-line Base64 from config.json # Edit .env and set at least:
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
# Edit .env and set:
# DS2API_ADMIN_KEY=your-admin-key # DS2API_ADMIN_KEY=your-admin-key
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
# Start # Start
docker-compose up -d docker-compose up -d
@@ -175,6 +171,18 @@ If container logs look normal but the admin panel is unreachable, check these fi
1. **Port alignment**: when `PORT` is not `5001`, use the same port in your URL (for example `http://localhost:8080/admin`). 1. **Port alignment**: when `PORT` is not `5001`, use the same port in your URL (for example `http://localhost:8080/admin`).
2. **WebUI assets in dev compose**: `docker-compose.dev.yml` runs `go run` in a dev image and does not auto-install Node.js inside the container; if `static/admin` is missing in your repo, `/admin` will return 404. Build once on host: `./scripts/build-webui.sh`. 2. **WebUI assets in dev compose**: `docker-compose.dev.yml` runs `go run` in a dev image and does not auto-install Node.js inside the container; if `static/admin` is missing in your repo, `/admin` will return 404. Build once on host: `./scripts/build-webui.sh`.
### 2.7 Zeabur One-Click (Dockerfile)
This repo includes a `zeabur.yaml` template for one-click deployment on Zeabur:
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/L4CFHP)
Notes:
- **Port**: DS2API listens on `5001` by default; the template sets `PORT=5001`.
- **Persistent config**: the template mounts `/data` and sets `DS2API_CONFIG_PATH=/data/config.json`. After importing config in Admin UI, it will be written and persisted to this path.
- **First login**: after deployment, open `/admin` and login with `DS2API_ADMIN_KEY` shown in Zeabur env/template instructions (recommended: rotate to a strong secret after first login).
--- ---
## 3. Vercel Deployment ## 3. Vercel Deployment
@@ -354,7 +362,7 @@ Each archive includes:
- `ds2api` executable (`ds2api.exe` on Windows) - `ds2api` executable (`ds2api.exe` on Windows)
- `static/admin/` (built WebUI assets) - `static/admin/` (built WebUI assets)
- `sha3_wasm_bg.7b9ca65ddd.wasm` - `sha3_wasm_bg.7b9ca65ddd.wasm` (optional; binary has embedded fallback)
- `config.example.json`, `.env.example` - `config.example.json`, `.env.example`
- `README.MD`, `README.en.md`, `LICENSE` - `README.MD`, `README.en.md`, `LICENSE`
@@ -443,7 +451,9 @@ server {
```bash ```bash
# Copy compiled binary and related files to target directory # Copy compiled binary and related files to target directory
sudo mkdir -p /opt/ds2api sudo mkdir -p /opt/ds2api
sudo cp ds2api config.json sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/ sudo cp ds2api config.json /opt/ds2api/
# Optional: if you want to use an external WASM file (override embedded one)
# sudo cp sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
sudo cp -r static/admin /opt/ds2api/static/admin sudo cp -r static/admin /opt/ds2api/static/admin
``` ```

View File

@@ -113,12 +113,8 @@ go build -o ds2api ./cmd/ds2api
# 复制环境变量模板 # 复制环境变量模板
cp .env.example .env cp .env.example .env
# 从 config.json 生成单行 Base64 # 编辑 .env请改成你的强密码至少设置
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
# 编辑 .env请改成你的强密码设置
# DS2API_ADMIN_KEY=your-admin-key # DS2API_ADMIN_KEY=your-admin-key
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
# 启动 # 启动
docker-compose up -d docker-compose up -d
@@ -175,6 +171,18 @@ healthcheck:
1. **端口是否一致**`PORT` 改成非 `5001` 时,访问地址也要改成对应端口(如 `http://localhost:8080/admin`)。 1. **端口是否一致**`PORT` 改成非 `5001` 时,访问地址也要改成对应端口(如 `http://localhost:8080/admin`)。
2. **开发 compose 的 WebUI 静态文件**`docker-compose.dev.yml` 使用 `go run` 开发镜像,不会在容器内自动安装 Node.js若仓库里没有 `static/admin``/admin` 会返回 404。可先在宿主机构建一次`./scripts/build-webui.sh` 2. **开发 compose 的 WebUI 静态文件**`docker-compose.dev.yml` 使用 `go run` 开发镜像,不会在容器内自动安装 Node.js若仓库里没有 `static/admin``/admin` 会返回 404。可先在宿主机构建一次`./scripts/build-webui.sh`
### 2.7 Zeabur 一键部署Dockerfile
仓库提供 `zeabur.yaml` 模板,可在 Zeabur 上一键部署:
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/L4CFHP)
部署要点:
- **端口**:服务默认监听 `5001`,模板会固定设置 `PORT=5001`
- **配置持久化**:模板挂载卷 `/data`,并设置 `DS2API_CONFIG_PATH=/data/config.json`;在管理台导入配置后,会写入并持久化到该路径。
- **首次登录**:部署完成后访问 `/admin`,使用 Zeabur 环境变量/模板指引中的 `DS2API_ADMIN_KEY` 登录(建议首次登录后自行更换为强密码)。
--- ---
## 三、Vercel 部署 ## 三、Vercel 部署
@@ -354,7 +362,7 @@ No Output Directory named "public" found after the Build completed.
- `ds2api` 可执行文件Windows 为 `ds2api.exe` - `ds2api` 可执行文件Windows 为 `ds2api.exe`
- `static/admin/`WebUI 构建产物) - `static/admin/`WebUI 构建产物)
- `sha3_wasm_bg.7b9ca65ddd.wasm` - `sha3_wasm_bg.7b9ca65ddd.wasm`(可选;程序内置 embed fallback
- `config.example.json``.env.example` - `config.example.json``.env.example`
- `README.MD``README.en.md``LICENSE` - `README.MD``README.en.md``LICENSE`
@@ -443,7 +451,9 @@ server {
```bash ```bash
# 将编译好的二进制文件和相关文件复制到目标目录 # 将编译好的二进制文件和相关文件复制到目标目录
sudo mkdir -p /opt/ds2api sudo mkdir -p /opt/ds2api
sudo cp ds2api config.json sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/ sudo cp ds2api config.json /opt/ds2api/
# 可选:若你希望使用外置 WASM 文件(覆盖内置版本)
# sudo cp sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
sudo cp -r static/admin /opt/ds2api/static/admin sudo cp -r static/admin /opt/ds2api/static/admin
``` ```

View File

@@ -8,12 +8,15 @@ RUN npm run build
FROM golang:1.24 AS go-builder FROM golang:1.24 AS go-builder
WORKDIR /app WORKDIR /app
ARG TARGETOS=linux ARG TARGETOS
ARG TARGETARCH=amd64 ARG TARGETARCH
COPY go.mod go.sum* ./ COPY go.mod go.sum* ./
RUN go mod download RUN go mod download
COPY . . COPY . .
RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -o /out/ds2api ./cmd/ds2api RUN set -eux; \
GOOS="${TARGETOS:-$(go env GOOS)}"; \
GOARCH="${TARGETARCH:-$(go env GOARCH)}"; \
CGO_ENABLED=0 GOOS="${GOOS}" GOARCH="${GOARCH}" go build -o /out/ds2api ./cmd/ds2api
FROM busybox:1.36.1-musl AS busybox-tools FROM busybox:1.36.1-musl AS busybox-tools

View File

@@ -1,3 +1,7 @@
<p align="center">
<img src="webui/public/ds2api-favicon.svg" width="128" height="128" alt="DS2API icon" />
</p>
# DS2API # DS2API
[![License](https://img.shields.io/github/license/CJackHwang/ds2api.svg)](LICENSE) [![License](https://img.shields.io/github/license/CJackHwang/ds2api.svg)](LICENSE)
@@ -5,6 +9,8 @@
![Forks](https://img.shields.io/github/forks/CJackHwang/ds2api.svg) ![Forks](https://img.shields.io/github/forks/CJackHwang/ds2api.svg)
[![Release](https://img.shields.io/github/v/release/CJackHwang/ds2api?display_name=tag)](https://github.com/CJackHwang/ds2api/releases) [![Release](https://img.shields.io/github/v/release/CJackHwang/ds2api?display_name=tag)](https://github.com/CJackHwang/ds2api/releases)
[![Docker](https://img.shields.io/badge/docker-ready-blue.svg)](DEPLOY.md) [![Docker](https://img.shields.io/badge/docker-ready-blue.svg)](DEPLOY.md)
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/L4CFHP)
[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api)
语言 / Language: [中文](README.MD) | [English](README.en.md) 语言 / Language: [中文](README.MD) | [English](README.en.md)
@@ -100,6 +106,14 @@ flowchart LR
可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。 可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。
另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。 另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。
#### Claude Code 接入避坑(实测)
- `ANTHROPIC_BASE_URL` 推荐直接指向 DS2API 根地址(例如 `http://127.0.0.1:5001`Claude Code 会请求 `/v1/messages?beta=true`。
- `ANTHROPIC_API_KEY` 需要与 `config.json` 中 `keys` 一致;建议同时保留常规 key 与 `sk-ant-*` 形态 key兼容不同客户端校验习惯。
- 若系统设置了代理,建议对 DS2API 地址配置 `NO_PROXY=127.0.0.1,localhost,<你的主机IP>`,避免本地回环请求被代理拦截。
- 如遇“工具调用输出成文本、未执行”问题,请升级到包含 Claude 工具调用多格式解析JSON/XML/ANTML/invoke的版本。
### Gemini 接口 ### Gemini 接口
Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling`functionDeclarations` → `functionCall` 输出)。 Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling`functionDeclarations` → `functionCall` 输出)。
@@ -146,22 +160,24 @@ go run ./cmd/ds2api
# 1. 准备环境变量文件 # 1. 准备环境变量文件
cp .env.example .env cp .env.example .env
# 2. 从 config.json 生成 DS2API_CONFIG_JSON单行 Base64 # 2. 编辑 .env至少设置 DS2API_ADMIN_KEY
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
# 3. 编辑 .env设置
# DS2API_ADMIN_KEY=请替换为强密码 # DS2API_ADMIN_KEY=请替换为强密码
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
# 4. 启动 # 3. 启动
docker-compose up -d docker-compose up -d
# 5. 查看日志 # 4. 查看日志
docker-compose logs -f docker-compose logs -f
``` ```
更新镜像:`docker-compose up -d --build` 更新镜像:`docker-compose up -d --build`
#### Zeabur 一键部署Dockerfile
1. 点击上方 “Deploy on Zeabur” 按钮,一键部署。
2. 部署完成后访问 `/admin`,使用 Zeabur 环境变量/模板指引中的 `DS2API_ADMIN_KEY` 登录。
3. 在管理台导入/编辑配置(会写入并持久化到 `/data/config.json`)。
### 方式三Vercel 部署 ### 方式三Vercel 部署
1. Fork 仓库到自己的 GitHub 1. Fork 仓库到自己的 GitHub
@@ -377,7 +393,7 @@ ds2api/
├── api/ ├── api/
│ ├── index.go # Vercel Serverless Go 入口 │ ├── index.go # Vercel Serverless Go 入口
│ ├── chat-stream.js # Vercel Node.js 流式转发 │ ├── chat-stream.js # Vercel Node.js 流式转发
│ └── helpers/ # Node.js 辅助模块 │ └── (rewrite targets in vercel.json)
├── internal/ ├── internal/
│ ├── account/ # 账号池与并发队列 │ ├── account/ # 账号池与并发队列
│ ├── adapter/ │ ├── adapter/
@@ -390,6 +406,7 @@ ds2api/
│ ├── compat/ # 兼容性辅助 │ ├── compat/ # 兼容性辅助
│ ├── config/ # 配置加载与热更新 │ ├── config/ # 配置加载与热更新
│ ├── deepseek/ # DeepSeek API 客户端、PoW WASM │ ├── deepseek/ # DeepSeek API 客户端、PoW WASM
│ ├── js/ # Node 运行时流式处理与兼容逻辑
│ ├── devcapture/ # 开发抓包模块 │ ├── devcapture/ # 开发抓包模块
│ ├── format/ # 输出格式化 │ ├── format/ # 输出格式化
│ ├── prompt/ # Prompt 构建 │ ├── prompt/ # Prompt 构建
@@ -400,7 +417,9 @@ ds2api/
│ └── webui/ # WebUI 静态文件托管与自动构建 │ └── webui/ # WebUI 静态文件托管与自动构建
├── webui/ # React WebUI 源码Vite + Tailwind ├── webui/ # React WebUI 源码Vite + Tailwind
│ └── src/ │ └── src/
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage │ ├── app/ # 路由、鉴权、配置状态管理
│ ├── features/ # 业务功能模块account/settings/vercel/apiTester
│ ├── components/ # 登录/落地页等通用组件
│ └── locales/ # 中英文语言包zh.json / en.json │ └── locales/ # 中英文语言包zh.json / en.json
├── scripts/ ├── scripts/
│ └── build-webui.sh # WebUI 手动构建脚本 │ └── build-webui.sh # WebUI 手动构建脚本
@@ -456,6 +475,23 @@ go run ./cmd/ds2api-tests \
npm ci --prefix webui && npm run build --prefix webui npm ci --prefix webui && npm run build --prefix webui
``` ```
## 测试
详细测试指南请参阅 [TESTING.md](TESTING.md)。
### 快速测试命令
```bash
# 运行所有单元测试
go test ./...
# 运行 tool calls 相关测试(调试工具调用问题)
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
# 运行端到端测试
./tests/scripts/run-live.sh
```
## Release 自动构建GitHub Actions ## Release 自动构建GitHub Actions
工作流文件:`.github/workflows/release-artifacts.yml` 工作流文件:`.github/workflows/release-artifacts.yml`
@@ -463,7 +499,7 @@ npm ci --prefix webui && npm run build --prefix webui
- **触发条件**:仅在 GitHub Release `published` 时触发(普通 push 不会触发) - **触发条件**:仅在 GitHub Release `published` 时触发(普通 push 不会触发)
- **构建产物**:多平台二进制包(`linux/amd64`、`linux/arm64`、`darwin/amd64`、`darwin/arm64`、`windows/amd64`+ `sha256sums.txt` - **构建产物**:多平台二进制包(`linux/amd64`、`linux/arm64`、`darwin/amd64`、`darwin/arm64`、`windows/amd64`+ `sha256sums.txt`
- **容器镜像发布**:仅推送到 GHCR`ghcr.io/cjackhwang/ds2api` - **容器镜像发布**:仅推送到 GHCR`ghcr.io/cjackhwang/ds2api`
- **每个压缩包包含**`ds2api` 可执行文件、`static/admin`、WASM 文件、配置示例、README、LICENSE - **每个压缩包包含**`ds2api` 可执行文件、`static/admin`、WASM 文件(同时支持内置 fallback、配置示例、README、LICENSE
## 免责声明 ## 免责声明

View File

@@ -1,3 +1,7 @@
<p align="center">
<img src="webui/public/ds2api-favicon.svg" width="128" height="128" alt="DS2API icon" />
</p>
# DS2API # DS2API
[![License](https://img.shields.io/github/license/CJackHwang/ds2api.svg)](LICENSE) [![License](https://img.shields.io/github/license/CJackHwang/ds2api.svg)](LICENSE)
@@ -5,6 +9,8 @@
![Forks](https://img.shields.io/github/forks/CJackHwang/ds2api.svg) ![Forks](https://img.shields.io/github/forks/CJackHwang/ds2api.svg)
[![Release](https://img.shields.io/github/v/release/CJackHwang/ds2api?display_name=tag)](https://github.com/CJackHwang/ds2api/releases) [![Release](https://img.shields.io/github/v/release/CJackHwang/ds2api?display_name=tag)](https://github.com/CJackHwang/ds2api/releases)
[![Docker](https://img.shields.io/badge/docker-ready-blue.svg)](DEPLOY.en.md) [![Docker](https://img.shields.io/badge/docker-ready-blue.svg)](DEPLOY.en.md)
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/L4CFHP)
[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api)
Language: [中文](README.MD) | [English](README.en.md) Language: [中文](README.MD) | [English](README.en.md)
@@ -100,6 +106,14 @@ flowchart LR
Override mapping via `claude_mapping` or `claude_model_mapping` in config. Override mapping via `claude_mapping` or `claude_model_mapping` in config.
In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility. In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility.
#### Claude Code integration pitfalls (validated)
- Set `ANTHROPIC_BASE_URL` to the DS2API root URL (for example `http://127.0.0.1:5001`). Claude Code sends requests to `/v1/messages?beta=true`.
- `ANTHROPIC_API_KEY` must match an entry in `keys` from `config.json`. Keeping both a regular key and an `sk-ant-*` style key improves client compatibility.
- If your environment has proxy variables, set `NO_PROXY=127.0.0.1,localhost,<your_host_ip>` for DS2API to avoid proxy interception of local traffic.
- If tool calls are rendered as plain text and not executed, upgrade to a build that includes multi-format Claude tool-call parsing (JSON/XML/ANTML/invoke).
### Gemini Endpoint ### Gemini Endpoint
The Gemini adapter maps model names to DeepSeek native models via `model_aliases` or built-in heuristics, supporting both `generateContent` and `streamGenerateContent` call patterns with full Tool Calling support (`functionDeclarations``functionCall` output). The Gemini adapter maps model names to DeepSeek native models via `model_aliases` or built-in heuristics, supporting both `generateContent` and `streamGenerateContent` call patterns with full Tool Calling support (`functionDeclarations``functionCall` output).
@@ -146,22 +160,24 @@ Default URL: `http://localhost:5001`
# 1. Prepare env file # 1. Prepare env file
cp .env.example .env cp .env.example .env
# 2. Generate DS2API_CONFIG_JSON from config.json (single-line Base64) # 2. Edit .env (at least set DS2API_ADMIN_KEY)
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
# 3. Edit .env and set:
# DS2API_ADMIN_KEY=replace-with-a-strong-secret # DS2API_ADMIN_KEY=replace-with-a-strong-secret
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
# 4. Start # 3. Start
docker-compose up -d docker-compose up -d
# 5. View logs # 4. View logs
docker-compose logs -f docker-compose logs -f
``` ```
Rebuild after updates: `docker-compose up -d --build` Rebuild after updates: `docker-compose up -d --build`
#### Zeabur One-Click (Dockerfile)
1. Click the “Deploy on Zeabur” button above to deploy.
2. After deployment, open `/admin` and login with `DS2API_ADMIN_KEY` shown in Zeabur env/template instructions.
3. Import / edit config in Admin UI (it will be written and persisted to `/data/config.json`).
### Option 3: Vercel ### Option 3: Vercel
1. Fork this repo to your GitHub account 1. Fork this repo to your GitHub account
@@ -339,6 +355,7 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
When `tools` is present in the request, DS2API performs anti-leak handling: When `tools` is present in the request, DS2API performs anti-leak handling:
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored) 1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
- In non-code-block context, tool JSON may still be recognized even when mixed with normal prose; surrounding prose can remain as text output.
2. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`) 2. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`)
3. Tool names not declared in the `tools` schema are strictly rejected and will not be emitted as valid tool calls 3. Tool names not declared in the `tools` schema are strictly rejected and will not be emitted as valid tool calls
4. `responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream 4. `responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream
@@ -377,7 +394,7 @@ ds2api/
├── api/ ├── api/
│ ├── index.go # Vercel Serverless Go entry │ ├── index.go # Vercel Serverless Go entry
│ ├── chat-stream.js # Vercel Node.js stream relay │ ├── chat-stream.js # Vercel Node.js stream relay
│ └── helpers/ # Node.js helper modules │ └── (rewrite targets in vercel.json)
├── internal/ ├── internal/
│ ├── account/ # Account pool and concurrency queue │ ├── account/ # Account pool and concurrency queue
│ ├── adapter/ │ ├── adapter/
@@ -390,6 +407,7 @@ ds2api/
│ ├── compat/ # Compatibility helpers │ ├── compat/ # Compatibility helpers
│ ├── config/ # Config loading and hot-reload │ ├── config/ # Config loading and hot-reload
│ ├── deepseek/ # DeepSeek API client, PoW WASM │ ├── deepseek/ # DeepSeek API client, PoW WASM
│ ├── js/ # Node runtime stream/compat logic
│ ├── devcapture/ # Dev packet capture module │ ├── devcapture/ # Dev packet capture module
│ ├── format/ # Output formatting │ ├── format/ # Output formatting
│ ├── prompt/ # Prompt construction │ ├── prompt/ # Prompt construction
@@ -400,7 +418,9 @@ ds2api/
│ └── webui/ # WebUI static file serving and auto-build │ └── webui/ # WebUI static file serving and auto-build
├── webui/ # React WebUI source (Vite + Tailwind) ├── webui/ # React WebUI source (Vite + Tailwind)
│ └── src/ │ └── src/
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage │ ├── app/ # Routing, auth, config state
│ ├── features/ # Feature modules (account/settings/vercel/apiTester)
│ ├── components/ # Shared UI pieces (login/landing, etc.)
│ └── locales/ # Language packs (zh.json / en.json) │ └── locales/ # Language packs (zh.json / en.json)
├── scripts/ ├── scripts/
│ └── build-webui.sh # Manual WebUI build script │ └── build-webui.sh # Manual WebUI build script
@@ -463,7 +483,7 @@ Workflow: `.github/workflows/release-artifacts.yml`
- **Trigger**: only on GitHub Release `published` (normal pushes do not trigger builds) - **Trigger**: only on GitHub Release `published` (normal pushes do not trigger builds)
- **Outputs**: multi-platform archives (`linux/amd64`, `linux/arm64`, `darwin/amd64`, `darwin/arm64`, `windows/amd64`) + `sha256sums.txt` - **Outputs**: multi-platform archives (`linux/amd64`, `linux/arm64`, `darwin/amd64`, `darwin/arm64`, `windows/amd64`) + `sha256sums.txt`
- **Container publishing**: GHCR only (`ghcr.io/cjackhwang/ds2api`) - **Container publishing**: GHCR only (`ghcr.io/cjackhwang/ds2api`)
- **Each archive includes**: `ds2api` executable, `static/admin`, WASM file, config template, README, LICENSE - **Each archive includes**: `ds2api` executable, `static/admin`, WASM file (with embedded fallback support), config template, README, LICENSE
## Disclaimer ## Disclaimer

View File

@@ -51,7 +51,7 @@ DS2API 提供两个层级的测试:
1. **Preflight 检查** 1. **Preflight 检查**
- `go test ./... -count=1`(单元测试) - `go test ./... -count=1`(单元测试)
- `./tests/scripts/check-node-split-syntax.sh`Node 拆分模块语法门禁) - `./tests/scripts/check-node-split-syntax.sh`Node 拆分模块语法门禁)
- `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js`Node 流式拦截 + compat 单测) - `node --test tests/node/stream-tool-sieve.test.js tests/node/chat-stream.test.js tests/node/js_compat_test.js`
- `npm run build --prefix webui`WebUI 构建检查) - `npm run build --prefix webui`WebUI 构建检查)
2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程 2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程
@@ -173,6 +173,50 @@ rg "<trace_id>" artifacts/testsuite/<run_id>/server.log
go test ./... go test ./...
``` ```
### 运行特定模块的单元测试
```bash
# 运行 tool calls 相关测试(推荐用于调试 tool call 解析问题)
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
# 运行单个测试用例
go test -v -run TestParseToolCallsWithDeepSeekHallucination ./internal/util/
# 运行 format 相关测试
go test -v ./internal/format/...
# 运行 adapter 相关测试
go test -v ./internal/adapter/openai/...
```
### 调试 Tool Call 问题 | Debugging Tool Call Issues
当遇到 DeepSeek 工具调用解析问题时,可以使用以下方法:
```bash
# 1. 运行 tool calls 相关的所有测试
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
# 2. 查看测试输出中的详细调试信息
go test -v -run TestParseToolCallsWithDeepSeekHallucination ./internal/util/ 2>&1
# 3. 检查具体测试用例的修复效果
# 测试用例位于 internal/util/toolcalls_test.go包含
# - TestParseToolCallsWithDeepSeekHallucination: DeepSeek 典型幻觉输出
# - TestRepairLooseJSONWithNestedObjects: 嵌套对象的方括号修复
# - TestParseToolCallsWithMixedWindowsPaths: Windows 路径处理
```
### 运行 Node.js 测试
```bash
# 运行 Node 测试
node --test tests/node/stream-tool-sieve.test.js
# 或使用脚本
./tests/scripts/run-unit-node.sh
```
### 跑端到端测试(跳过 preflight ### 跑端到端测试(跳过 preflight
```bash ```bash

1
VERSION Normal file
View File

@@ -0,0 +1 @@
0.1.0

View File

@@ -1,18 +1,14 @@
services: services:
ds2api: ds2api:
build: . image: ghcr.io/cjackhwang/ds2api:latest
image: ds2api:latest container_name: ds2api
container_name: ds2api restart: always
ports: ports:
- "${PORT:-5001}:${PORT:-5001}" - "6011:5001"
env_file: volumes:
- .env - ./config.json:/app/config.json # 配置文件
environment: - ./.env:/app/.env # 环境变量
- HOST=0.0.0.0 environment:
restart: unless-stopped - TZ=Asia/Shanghai
healthcheck: - LOG_LEVEL=INFO
test: ["CMD", "/usr/local/bin/busybox", "wget", "-qO-", "http://localhost:${PORT:-5001}/healthz"] - DS2API_ADMIN_KEY=${DS2API_ADMIN_KEY:-ds2api}
interval: 30s
timeout: 10s
retries: 3
start_period: 10s

View File

@@ -0,0 +1,41 @@
# Tool call parsing semantics (Go canonical spec)
This document defines the cross-runtime contract for `ParseToolCallsDetailed` / `parseToolCallsDetailed`.
## Output contract
- `calls`: accepted tool calls with normalized tool names.
- `sawToolCallSyntax`: true when tool-call-like syntax is detected (`tool_calls`, `<tool_call>`, `<function_call>`, `<invoke>`) or a valid call is parsed.
- `rejectedByPolicy`: true when parser extracted call syntax but all calls are rejected by allow-list policy.
- `rejectedToolNames`: de-duplicated rejected tool names in first-seen order.
## Parse pipeline
1. Strip fenced code blocks for non-standalone parsing.
2. Build candidates from:
- full text,
- fenced JSON snippets,
- extracted JSON objects around `tool_calls`,
- first `{` to last `}` object slice.
3. Parse each candidate in order:
- JSON payload parser (`tool_calls`, list, single call object),
- XML/Markup parser (`<tool_call>`, `<function_call>`, `<invoke>`; supports attributes + nested fields),
- Text KV fallback parser (`function.name: <name>` ... `function.arguments: {json}`).
4. Stop at first candidate that yields at least one call.
## Name normalization policy
When matching parsed names against configured tools:
1. exact match,
2. case-insensitive match,
3. namespace tail match (`a.b.c` => `c`),
4. loose alnum match (remove non `[a-z0-9]`, compare).
## Standalone mode
Standalone mode (`ParseStandaloneToolCallsDetailed`) parses the whole input directly (no candidate slicing), while still applying:
- example-context guard,
- JSON then markup fallback,
- the same allow-list normalization policy.

View File

@@ -183,6 +183,66 @@ func TestHandleClaudeStreamRealtimeToolSafety(t *testing.T) {
} }
} }
func TestHandleClaudeStreamRealtimeToolDetectionFromThinkingFallback(t *testing.T) {
h := &Handler{}
resp := makeClaudeSSEHTTPResponse(
`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
`data: {"p":"response/thinking_content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, true, false, []string{"search"})
frames := parseClaudeFrames(t, rec.Body.String())
foundToolUse := false
for _, f := range findClaudeFrames(frames, "content_block_start") {
contentBlock, _ := f.Payload["content_block"].(map[string]any)
if contentBlock["type"] == "tool_use" && contentBlock["name"] == "search" {
foundToolUse = true
break
}
}
if !foundToolUse {
t.Fatalf("expected tool_use block from thinking fallback, body=%s", rec.Body.String())
}
}
func TestHandleClaudeStreamRealtimeSkipsThinkingFallbackWhenFinalTextExists(t *testing.T) {
h := &Handler{}
resp := makeClaudeSSEHTTPResponse(
`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
`data: {"p":"response/thinking_content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
`data: {"p":"response/content","v":"normal answer"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, true, false, []string{"search"})
frames := parseClaudeFrames(t, rec.Body.String())
for _, f := range findClaudeFrames(frames, "content_block_start") {
contentBlock, _ := f.Payload["content_block"].(map[string]any)
if contentBlock["type"] == "tool_use" {
t.Fatalf("unexpected tool_use block when final text exists, body=%s", rec.Body.String())
}
}
foundEndTurn := false
for _, f := range findClaudeFrames(frames, "message_delta") {
delta, _ := f.Payload["delta"].(map[string]any)
if delta["stop_reason"] == "end_turn" {
foundEndTurn = true
break
}
}
if !foundEndTurn {
t.Fatalf("expected stop_reason=end_turn, body=%s", rec.Body.String())
}
}
func TestHandleClaudeStreamRealtimeUpstreamErrorEvent(t *testing.T) { func TestHandleClaudeStreamRealtimeUpstreamErrorEvent(t *testing.T) {
h := &Handler{} h := &Handler{}
resp := makeClaudeSSEHTTPResponse( resp := makeClaudeSSEHTTPResponse(
@@ -255,3 +315,83 @@ func asString(v any) string {
s, _ := v.(string) s, _ := v.(string)
return s return s
} }
func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.T) {
tests := []struct {
name string
payload string
}{
{name: "xml_tool_call", payload: `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command></parameters></tool_call>`},
{name: "xml_json_tool_call", payload: `<tool_call>{"tool":"Bash","params":{"command":"pwd"}}</tool_call>`},
{name: "nested_tool_tag_style", payload: `<tool_call><tool name="Bash"><command>pwd</command></tool></tool_call>`},
{name: "function_tag_style", payload: `<function_call>Bash</function_call><function parameter name="command">pwd</function parameter>`},
{name: "antml_argument_style", payload: `<antml:function_calls><antml:function_call id="1" name="Bash"><antml:argument name="command">pwd</antml:argument></antml:function_call></antml:function_calls>`},
{name: "antml_function_attr_parameters", payload: `<antml:function_calls><antml:function_call id="1" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call></antml:function_calls>`},
{name: "invoke_parameter_style", payload: `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></function_calls>`},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
h := &Handler{}
resp := makeClaudeSSEHTTPResponse(
`data: {"p":"response/content","v":"`+strings.ReplaceAll(tc.payload, `"`, `\"`)+`"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"Bash"})
frames := parseClaudeFrames(t, rec.Body.String())
foundToolUse := false
for _, f := range findClaudeFrames(frames, "content_block_start") {
contentBlock, _ := f.Payload["content_block"].(map[string]any)
if contentBlock["type"] == "tool_use" {
foundToolUse = true
break
}
}
if !foundToolUse {
t.Fatalf("expected tool_use block for format %s, body=%s", tc.name, rec.Body.String())
}
})
}
}
func TestHandleClaudeStreamRealtimePromotesUnclosedFencedToolExample(t *testing.T) {
h := &Handler{}
resp := makeClaudeSSEHTTPResponse(
"data: {\"p\":\"response/content\",\"v\":\"Here is an example:\\n```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"Bash\\\",\\\"input\\\":{\\\"command\\\":\\\"pwd\\\"}}]}\"}",
"data: {\"p\":\"response/content\",\"v\":\"\\n```\\nDo not execute it.\"}",
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "show example only"}}, false, false, []string{"Bash"})
frames := parseClaudeFrames(t, rec.Body.String())
foundToolUse := false
for _, f := range findClaudeFrames(frames, "content_block_start") {
contentBlock, _ := f.Payload["content_block"].(map[string]any)
if contentBlock["type"] == "tool_use" {
foundToolUse = true
break
}
}
if !foundToolUse {
t.Fatalf("expected tool_use for fenced example, body=%s", rec.Body.String())
}
foundToolStop := false
for _, f := range findClaudeFrames(frames, "message_delta") {
delta, _ := f.Payload["delta"].(map[string]any)
if delta["stop_reason"] == "tool_use" {
foundToolStop = true
break
}
}
if !foundToolStop {
t.Fatalf("expected stop_reason=tool_use, body=%s", rec.Body.String())
}
}

View File

@@ -125,8 +125,11 @@ func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
if !containsStr(prompt, "Search the web") { if !containsStr(prompt, "Search the web") {
t.Fatalf("expected description in prompt") t.Fatalf("expected description in prompt")
} }
if !containsStr(prompt, "tool_calls") { if !containsStr(prompt, "tool_use") {
t.Fatalf("expected tool_calls instruction in prompt") t.Fatalf("expected tool_use instruction in prompt")
}
if containsStr(prompt, "tool_calls") {
t.Fatalf("expected prompt to avoid tool_calls JSON instruction")
} }
} }
@@ -141,6 +144,34 @@ func TestBuildClaudeToolPromptMultipleTools(t *testing.T) {
} }
} }
func TestBuildClaudeToolPromptSupportsOpenAIStyleFunctionTool(t *testing.T) {
tools := []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "search",
"description": "Search via function tool",
"parameters": map[string]any{
"type": "object",
"properties": map[string]any{
"q": map[string]any{"type": "string"},
},
},
},
},
}
prompt := buildClaudeToolPrompt(tools)
if !containsStr(prompt, "Tool: search") {
t.Fatalf("expected OpenAI-style function tool name in prompt, got: %q", prompt)
}
if !containsStr(prompt, "Search via function tool") {
t.Fatalf("expected OpenAI-style function tool description in prompt, got: %q", prompt)
}
if !containsStr(prompt, "\"q\"") {
t.Fatalf("expected parameters schema serialized in prompt, got: %q", prompt)
}
}
func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) { func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) {
tools := []any{"not a map"} tools := []any{"not a map"}
prompt := buildClaudeToolPrompt(tools) prompt := buildClaudeToolPrompt(tools)
@@ -237,6 +268,21 @@ func TestExtractClaudeToolNamesNil(t *testing.T) {
} }
} }
func TestExtractClaudeToolNamesSupportsOpenAIStyleFunctionTool(t *testing.T) {
tools := []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "search",
},
},
}
names := extractClaudeToolNames(tools)
if len(names) != 1 || names[0] != "search" {
t.Fatalf("expected [search], got %v", names)
}
}
// ─── toMessageMaps ─────────────────────────────────────────────────── // ─── toMessageMaps ───────────────────────────────────────────────────
func TestToMessageMapsNormal(t *testing.T) { func TestToMessageMapsNormal(t *testing.T) {

View File

@@ -46,13 +46,12 @@ func buildClaudeToolPrompt(tools []any) string {
if !ok { if !ok {
continue continue
} }
name, _ := m["name"].(string) name, desc, schemaObj := extractClaudeToolMeta(m)
desc, _ := m["description"].(string) schema, _ := json.Marshal(schemaObj)
schema, _ := json.Marshal(m["input_schema"])
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema)) parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
} }
parts = append(parts, parts = append(parts,
"When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}", "When you need a tool, respond with Claude-native tool use (tool_use) using the provided tool schema. Do not print tool-call JSON in text.",
"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.", "History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.", "After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
) )
@@ -98,13 +97,43 @@ func extractClaudeToolNames(tools []any) []string {
if !ok { if !ok {
continue continue
} }
if name, ok := m["name"].(string); ok && name != "" { name, _, _ := extractClaudeToolMeta(m)
if name != "" {
out = append(out, name) out = append(out, name)
} }
} }
return out return out
} }
func extractClaudeToolMeta(m map[string]any) (string, string, any) {
name, _ := m["name"].(string)
desc, _ := m["description"].(string)
schemaObj := m["input_schema"]
if schemaObj == nil {
schemaObj = m["parameters"]
}
if fn, ok := m["function"].(map[string]any); ok {
if strings.TrimSpace(name) == "" {
name, _ = fn["name"].(string)
}
if strings.TrimSpace(desc) == "" {
desc, _ = fn["description"].(string)
}
if schemaObj == nil {
if v, ok := fn["input_schema"]; ok {
schemaObj = v
}
}
if schemaObj == nil {
if v, ok := fn["parameters"]; ok {
schemaObj = v
}
}
}
return strings.TrimSpace(name), strings.TrimSpace(desc), schemaObj
}
func toMessageMaps(v any) []map[string]any { func toMessageMaps(v any) []map[string]any {
arr, ok := v.([]any) arr, ok := v.([]any)
if !ok { if !ok {

View File

@@ -8,6 +8,7 @@ import (
"ds2api/internal/sse" "ds2api/internal/sse"
streamengine "ds2api/internal/stream" streamengine "ds2api/internal/stream"
"ds2api/internal/util"
) )
type claudeStreamRuntime struct { type claudeStreamRuntime struct {
@@ -116,6 +117,18 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
s.text.WriteString(p.Text) s.text.WriteString(p.Text)
if s.bufferToolContent { if s.bufferToolContent {
if hasUnclosedCodeFence(s.text.String()) {
continue
}
detected := util.ParseToolCalls(s.text.String(), s.toolNames)
if len(detected) > 0 {
s.finalize("tool_use")
return streamengine.ParsedDecision{
ContentSeen: true,
Stop: true,
StopReason: streamengine.StopReason("tool_use_detected"),
}
}
continue continue
} }
s.closeThinkingBlock() s.closeThinkingBlock()
@@ -144,3 +157,7 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
return streamengine.ParsedDecision{ContentSeen: contentSeen} return streamengine.ParsedDecision{ContentSeen: contentSeen}
} }
func hasUnclosedCodeFence(text string) bool {
return strings.Count(text, "```")%2 == 1
}

View File

@@ -46,6 +46,9 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
if s.bufferToolContent { if s.bufferToolContent {
detected := util.ParseToolCalls(finalText, s.toolNames) detected := util.ParseToolCalls(finalText, s.toolNames)
if len(detected) == 0 && finalText == "" && finalThinking != "" {
detected = util.ParseToolCalls(finalThinking, s.toolNames)
}
if len(detected) > 0 { if len(detected) > 0 {
stopReason = "tool_use" stopReason = "tool_use"
for i, tc := range detected { for i, tc := range detected {

View File

@@ -99,7 +99,7 @@ func TestGeminiRoutesRegistered(t *testing.T) {
func TestGenerateContentReturnsFunctionCallParts(t *testing.T) { func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
upstream := makeGeminiUpstreamResponse( upstream := makeGeminiUpstreamResponse(
`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`, `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
`data: [DONE]`, `data: [DONE]`,
) )
h := &Handler{ h := &Handler{
@@ -143,6 +143,42 @@ func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
} }
} }
func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
upstream := makeGeminiUpstreamResponse(
`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
`data: [DONE]`,
)
h := &Handler{Store: testGeminiConfig{}, Auth: testGeminiAuth{}, DS: testGeminiDS{resp: upstream}}
r := chi.NewRouter()
RegisterRoutes(r, h)
body := `{
"contents":[{"role":"user","parts":[{"text":"call tool"}]}],
"tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
}`
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
req.Header.Set("Authorization", "Bearer direct-token")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
var out map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
t.Fatalf("decode response failed: %v", err)
}
candidates, _ := out["candidates"].([]any)
c0, _ := candidates[0].(map[string]any)
content, _ := c0["content"].(map[string]any)
parts, _ := content["parts"].([]any)
part0, _ := parts[0].(map[string]any)
functionCall, _ := part0["functionCall"].(map[string]any)
if functionCall["name"] != "eval_javascript" {
t.Fatalf("expected functionCall name eval_javascript for mixed snippet, got %#v", functionCall)
}
}
func TestStreamGenerateContentEmitsSSE(t *testing.T) { func TestStreamGenerateContentEmitsSSE(t *testing.T) {
upstream := makeGeminiUpstreamResponse( upstream := makeGeminiUpstreamResponse(
`data: {"p":"response/content","v":"hello "}`, `data: {"p":"response/content","v":"hello "}`,

View File

@@ -98,11 +98,11 @@ func (s *chatStreamRuntime) sendDone() {
func (s *chatStreamRuntime) finalize(finishReason string) { func (s *chatStreamRuntime) finalize(finishReason string) {
finalThinking := s.thinking.String() finalThinking := s.thinking.String()
finalText := s.text.String() finalText := s.text.String()
detected := util.ParseToolCalls(finalText, s.toolNames) detected := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
if len(detected) > 0 && !s.toolCallsDoneEmitted { if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
finishReason = "tool_calls" finishReason = "tool_calls"
delta := map[string]any{ delta := map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs), "tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs),
} }
if !s.firstChunkSent { if !s.firstChunkSent {
delta["role"] = "assistant" delta["role"] = "assistant"
@@ -158,7 +158,7 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
} }
} }
if len(detected) > 0 || s.toolCallsEmitted { if len(detected.Calls) > 0 || s.toolCallsEmitted {
finishReason = "tool_calls" finishReason = "tool_calls"
} }
s.sendChunk(openaifmt.BuildChatStreamChunk( s.sendChunk(openaifmt.BuildChatStreamChunk(

View File

@@ -19,6 +19,7 @@ type DeepSeekCaller interface {
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
DeleteAllSessionsForToken(ctx context.Context, token string) error
} }
type ConfigReader interface { type ConfigReader interface {
@@ -28,6 +29,7 @@ type ConfigReader interface {
ToolcallEarlyEmitConfidence() string ToolcallEarlyEmitConfidence() string
ResponsesStoreTTLSeconds() int ResponsesStoreTTLSeconds() int
EmbeddingsProvider() string EmbeddingsProvider() string
AutoDeleteSessions() bool
} }
var _ AuthResolver = (*auth.Resolver)(nil) var _ AuthResolver = (*auth.Resolver)(nil)

View File

@@ -19,6 +19,7 @@ func (m mockOpenAIConfig) ToolcallMode() string { return m.toolMo
func (m mockOpenAIConfig) ToolcallEarlyEmitConfidence() string { return m.earlyEmit } func (m mockOpenAIConfig) ToolcallEarlyEmitConfidence() string { return m.earlyEmit }
func (m mockOpenAIConfig) ResponsesStoreTTLSeconds() int { return m.responsesTTL } func (m mockOpenAIConfig) ResponsesStoreTTLSeconds() int { return m.responsesTTL }
func (m mockOpenAIConfig) EmbeddingsProvider() string { return m.embedProv } func (m mockOpenAIConfig) EmbeddingsProvider() string { return m.embedProv }
func (m mockOpenAIConfig) AutoDeleteSessions() bool { return false }
func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) { func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
cfg := mockOpenAIConfig{ cfg := mockOpenAIConfig{

View File

@@ -35,7 +35,25 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, status, detail) writeOpenAIError(w, status, detail)
return return
} }
defer h.Auth.Release(a) defer func() {
// 自动删除会话(同步)
// 必须在 Release 之前同步删除,否则:
// 1. 异步删除时账号已被 Release
// 2. 新请求可能获取到同一账号并开始使用
// 3. 异步删除仍在进行,会截断新请求正在使用的会话
if h.Store.AutoDeleteSessions() && a.DeepSeekToken != "" {
deleteCtx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
err := h.DS.DeleteAllSessionsForToken(deleteCtx, a.DeepSeekToken)
if err != nil {
config.Logger.Warn("[auto_delete_sessions] failed", "account", a.AccountID, "error", err)
} else {
config.Logger.Debug("[auto_delete_sessions] success", "account", a.AccountID)
}
}
h.Auth.Release(a)
}()
r = r.WithContext(auth.WithAuth(r.Context(), a)) r = r.WithContext(auth.WithAuth(r.Context(), a))
var req map[string]any var req map[string]any

View File

@@ -53,7 +53,7 @@ func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolCh
if len(toolSchemas) == 0 { if len(toolSchemas) == 0 {
return messages, names return messages, names
} }
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block." toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY a JSON code block like this:\n```json\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n```\n\n【EXAMPLE】\nUser: Please check the weather in Beijing and Shanghai, and update my todo list.\nAssistant:\n```json\n{\"tool_calls\": [\n {\"name\": \"get_weather\", \"input\": {\"city\": \"Beijing\"}},\n {\"name\": \"get_weather\", \"input\": {\"city\": \"Shanghai\"}},\n {\"name\": \"update_todo\", \"input\": {\"todos\": [{\"content\": \"Buy milk\"}, {\"content\": \"Write report\"}]}}\n]}\n```\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON code block. The response must start with ```json and end with ```.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block.\n5) JSON SYNTAX STRICTLY REQUIRED: All property names MUST be enclosed in double quotes (e.g., \"name\", not name).\n6) ARRAY FORMAT: If providing a list of items, you MUST enclose them in square brackets `[]` (e.g., \"todos\": [{\"item\": \"a\"}, {\"item\": \"b\"}]). DO NOT output comma-separated objects without brackets."
if policy.Mode == util.ToolChoiceRequired { if policy.Mode == util.ToolChoiceRequired {
toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list." toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list."
} }

View File

@@ -3,6 +3,7 @@ package openai
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -210,7 +211,7 @@ func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
} }
} }
func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) { func TestHandleNonStreamEmbeddedToolCallExamplePromotesToolCall(t *testing.T) {
h := &Handler{} h := &Handler{}
resp := makeSSEHTTPResponse( resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"下面是示例:"}`, `data: {"p":"response/content","v":"下面是示例:"}`,
@@ -233,15 +234,16 @@ func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
} }
msg, _ := choice["message"].(map[string]any) msg, _ := choice["message"].(map[string]any)
toolCalls, _ := msg["tool_calls"].([]any) toolCalls, _ := msg["tool_calls"].([]any)
if len(toolCalls) == 0 { if len(toolCalls) != 1 {
t.Fatalf("expected tool_calls field for embedded example: %#v", msg["tool_calls"]) t.Fatalf("expected one tool_call field for embedded example: %#v", msg["tool_calls"])
} }
if msg["content"] != nil { content, _ := msg["content"].(string)
t.Fatalf("expected content nil when tool_calls detected, got %#v", msg["content"]) if strings.Contains(content, `"tool_calls"`) {
t.Fatalf("expected raw tool_calls json stripped from content, got %#v", content)
} }
} }
func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) { func TestHandleNonStreamFencedToolCallExamplePromotesToolCall(t *testing.T) {
h := &Handler{} h := &Handler{}
resp := makeSSEHTTPResponse( resp := makeSSEHTTPResponse(
"data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}", "data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}",
@@ -257,16 +259,17 @@ func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) {
out := decodeJSONBody(t, rec.Body.String()) out := decodeJSONBody(t, rec.Body.String())
choices, _ := out["choices"].([]any) choices, _ := out["choices"].([]any)
choice, _ := choices[0].(map[string]any) choice, _ := choices[0].(map[string]any)
if choice["finish_reason"] != "stop" { if choice["finish_reason"] != "tool_calls" {
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"]) t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
} }
msg, _ := choice["message"].(map[string]any) msg, _ := choice["message"].(map[string]any)
if _, ok := msg["tool_calls"]; ok { toolCalls, _ := msg["tool_calls"].([]any)
t.Fatalf("did not expect tool_calls field for fenced example: %#v", msg["tool_calls"]) if len(toolCalls) != 1 {
t.Fatalf("expected one tool_call field for fenced example: %#v", msg["tool_calls"])
} }
content, _ := msg["content"].(string) content, _ := msg["content"].(string)
if !strings.Contains(content, "```json") || !strings.Contains(content, `"tool_calls"`) { if strings.Contains(content, `"tool_calls"`) {
t.Fatalf("expected fenced tool example to pass through as text, got %q", content) t.Fatalf("expected raw tool_calls json stripped from content, got %q", content)
} }
} }
@@ -315,6 +318,36 @@ func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
} }
} }
func TestHandleStreamToolCallLargeArgumentsStillIntercepted(t *testing.T) {
h := &Handler{}
large := strings.Repeat("a", 9000)
payload := fmt.Sprintf(`{"tool_calls":[{"name":"search","input":{"q":"%s"}}]}`, large)
splitAt := len(payload) / 2
resp := makeSSEHTTPResponse(
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[:splitAt]),
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[splitAt:]),
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid3-large", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.T) { func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
h := &Handler{} h := &Handler{}
resp := makeSSEHTTPResponse( resp := makeSSEHTTPResponse(
@@ -500,15 +533,12 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") { if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got) t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
} }
if strings.Contains(strings.ToLower(got), `"tool_calls"`) {
t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got)
}
if streamFinishReason(frames) != "tool_calls" { if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String()) t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
} }
} }
func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) { func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
h := &Handler{} h := &Handler{}
resp := makeSSEHTTPResponse( resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"我将调用工具。"}`, `data: {"p":"response/content","v":"我将调用工具。"}`,
@@ -542,15 +572,13 @@ func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
if !strings.Contains(got, "我将调用工具。") { if !strings.Contains(got, "我将调用工具。") {
t.Fatalf("expected leading text to keep streaming, got=%q", got) t.Fatalf("expected leading text to keep streaming, got=%q", got)
} }
if strings.Contains(strings.ToLower(got), "tool_calls") {
t.Fatalf("unexpected raw tool json leak, got=%q", got)
}
if streamFinishReason(frames) != "tool_calls" { if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
} }
} }
func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testing.T) { func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T) {
h := &Handler{} h := &Handler{}
resp := makeSSEHTTPResponse( resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`, `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
@@ -583,15 +611,52 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testin
if !strings.Contains(got, "接下来我会继续说明。") { if !strings.Contains(got, "接下来我会继续说明。") {
t.Fatalf("expected trailing plain text to be preserved, got=%q", got) t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
} }
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamFencedToolCallSnippetPromotesToolCall(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "下面是调用示例:\n```json\n"),
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\n仅示例不要执行。"),
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid7f", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta for fenced snippet, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
if c, ok := delta["content"].(string); ok {
content.WriteString(c)
}
}
}
got := content.String()
if strings.Contains(strings.ToLower(got), "tool_calls") { if strings.Contains(strings.ToLower(got), "tool_calls") {
t.Fatalf("unexpected raw tool json leak, got=%q", got) t.Fatalf("expected raw fenced tool_calls snippet stripped from content, got=%q", got)
} }
if streamFinishReason(frames) != "tool_calls" { if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
} }
} }
func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) { func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
h := &Handler{} h := &Handler{}
spaces := strings.Repeat(" ", 200) spaces := strings.Repeat(" ", 200)
resp := makeSSEHTTPResponse( resp := makeSSEHTTPResponse(
@@ -612,9 +677,6 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
if !streamHasToolCallsDelta(frames) { if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
} }
if streamHasRawToolJSONContent(frames) {
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
}
content := strings.Builder{} content := strings.Builder{}
for _, frame := range frames { for _, frame := range frames {
choices, _ := frame["choices"].([]any) choices, _ := frame["choices"].([]any)
@@ -627,9 +689,6 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
} }
} }
got := content.String() got := content.String()
if strings.Contains(got, "{") {
t.Fatalf("unexpected suspicious prefix leak in content: %q", got)
}
if !strings.Contains(got, "后置正文C。") { if !strings.Contains(got, "后置正文C。") {
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got) t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
} }
@@ -712,7 +771,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
} }
} }
func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) { func TestHandleStreamToolCallArgumentsEmitAsSingleCompletedChunk(t *testing.T) {
h := &Handler{} h := &Handler{}
resp := makeSSEHTTPResponse( resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`, `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
@@ -735,8 +794,8 @@ func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
} }
argChunks := streamToolCallArgumentChunks(frames) argChunks := streamToolCallArgumentChunks(frames)
if len(argChunks) < 2 { if len(argChunks) == 0 {
t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String()) t.Fatalf("expected tool call arguments chunk, got=%v body=%s", argChunks, rec.Body.String())
} }
joined := strings.Join(argChunks, "") joined := strings.Join(argChunks, "")
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) { if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {

View File

@@ -3,10 +3,10 @@ package openai
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"strings" "strings"
"ds2api/internal/config" "ds2api/internal/config"
"ds2api/internal/prompt"
) )
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any { func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
@@ -34,9 +34,9 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
"role": "user", "role": "user",
"content": formatToolResultForPrompt(msg), "content": formatToolResultForPrompt(msg),
}) })
case "user", "system": case "user", "system", "developer":
out = append(out, map[string]any{ out = append(out, map[string]any{
"role": role, "role": normalizeOpenAIRoleForPrompt(role),
"content": normalizeOpenAIContentForPrompt(msg["content"]), "content": normalizeOpenAIContentForPrompt(msg["content"]),
}) })
default: default:
@@ -48,7 +48,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
role = "user" role = "user"
} }
out = append(out, map[string]any{ out = append(out, map[string]any{
"role": role, "role": normalizeOpenAIRoleForPrompt(role),
"content": content, "content": content,
}) })
} }
@@ -78,7 +78,7 @@ func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) strin
args = normalizeOpenAIArgumentsForPrompt(fn["arguments"]) args = normalizeOpenAIArgumentsForPrompt(fn["arguments"])
} }
if name == "" { if name == "" {
name = "unknown" continue
} }
if args == "" { if args == "" {
args = normalizeOpenAIArgumentsForPrompt(call["arguments"]) args = normalizeOpenAIArgumentsForPrompt(call["arguments"])
@@ -133,32 +133,7 @@ func formatToolResultForPrompt(msg map[string]any) string {
} }
func normalizeOpenAIContentForPrompt(v any) string { func normalizeOpenAIContentForPrompt(v any) string {
switch x := v.(type) { return prompt.NormalizeContent(v)
case string:
return x
case []any:
parts := make([]string, 0, len(x))
for _, item := range x {
m, ok := item.(map[string]any)
if !ok {
continue
}
t := strings.ToLower(strings.TrimSpace(asString(m["type"])))
if t != "text" && t != "output_text" && t != "input_text" {
continue
}
if text := asString(m["text"]); text != "" {
parts = append(parts, text)
continue
}
if text := asString(m["content"]); text != "" {
parts = append(parts, text)
}
}
return strings.Join(parts, "\n")
default:
return marshalToPromptString(v)
}
} }
func normalizeOpenAIArgumentsForPrompt(v any) string { func normalizeOpenAIArgumentsForPrompt(v any) string {
@@ -175,30 +150,11 @@ func normalizeToolArgumentString(raw string) string {
if trimmed == "" { if trimmed == "" {
return "" return ""
} }
if !looksLikeConcatenatedJSON(trimmed) { if looksLikeConcatenatedJSON(trimmed) {
return trimmed // Keep original payload to avoid silent argument rewrites.
return raw
} }
dec := json.NewDecoder(strings.NewReader(trimmed)) return trimmed
values := make([]any, 0, 2)
for {
var v any
if err := dec.Decode(&v); err != nil {
if err == io.EOF {
break
}
return trimmed
}
values = append(values, v)
}
if len(values) < 2 {
return trimmed
}
last := values[len(values)-1]
b, err := json.Marshal(last)
if err != nil || len(b) == 0 {
return trimmed
}
return string(b)
} }
func marshalToPromptString(v any) string { func marshalToPromptString(v any) string {
@@ -209,6 +165,14 @@ func marshalToPromptString(v any) string {
return string(b) return string(b)
} }
func normalizeOpenAIRoleForPrompt(role string) string {
role = strings.ToLower(strings.TrimSpace(role))
if role == "developer" {
return "system"
}
return role
}
func asString(v any) string { func asString(v any) string {
if s, ok := v.(string); ok { if s, ok := v.(string); ok {
return s return s

View File

@@ -168,7 +168,7 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara
} }
} }
func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *testing.T) { func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t *testing.T) {
raw := []any{ raw := []any{
map[string]any{ map[string]any{
"role": "assistant", "role": "assistant",
@@ -189,10 +189,94 @@ func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *te
t.Fatalf("expected one normalized message, got %d", len(normalized)) t.Fatalf("expected one normalized message, got %d", len(normalized))
} }
content, _ := normalized[0]["content"].(string) content, _ := normalized[0]["content"].(string)
if !strings.Contains(content, `function.arguments: {"query":"测试工具调用"}`) { if !strings.Contains(content, `function.arguments: {}{"query":"测试工具调用"}`) {
t.Fatalf("expected repaired arguments in tool history, got %q", content) t.Fatalf("expected original concatenated arguments in tool history, got %q", content)
} }
if strings.Contains(content, `{}{"query":"测试工具调用"}`) { }
t.Fatalf("expected concatenated JSON to be repaired, got %q", content)
func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsMissingNameAreDropped(t *testing.T) {
raw := []any{
map[string]any{
"role": "assistant",
"tool_calls": []any{
map[string]any{
"id": "call_missing_name",
"type": "function",
"function": map[string]any{
"arguments": `{"path":"README.MD"}`,
},
},
},
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
if len(normalized) != 0 {
t.Fatalf("expected nameless assistant tool_calls to be dropped, got %#v", normalized)
}
}
func TestNormalizeOpenAIMessagesForPrompt_AssistantNilContentDoesNotInjectNullLiteral(t *testing.T) {
raw := []any{
map[string]any{
"role": "assistant",
"content": nil,
"tool_calls": []any{
map[string]any{
"id": "call_screenshot",
"function": map[string]any{
"name": "send_file_to_user",
"arguments": `{"file_path":"/tmp/a.png"}`,
},
},
},
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
if len(normalized) != 1 {
t.Fatalf("expected one normalized message, got %d", len(normalized))
}
content, _ := normalized[0]["content"].(string)
if strings.Contains(content, "<Assistant>null") || strings.HasPrefix(strings.TrimSpace(content), "null") {
t.Fatalf("unexpected null literal injected into assistant tool history: %q", content)
}
if !strings.Contains(content, "function.name: send_file_to_user") {
t.Fatalf("expected tool history block preserved, got %q", content)
}
}
func TestNormalizeOpenAIMessagesForPrompt_DeveloperRoleMapsToSystem(t *testing.T) {
raw := []any{
map[string]any{"role": "developer", "content": "必须先走工具调用"},
map[string]any{"role": "user", "content": "你好"},
}
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
if len(normalized) != 2 {
t.Fatalf("expected 2 normalized messages, got %d", len(normalized))
}
if normalized[0]["role"] != "system" {
t.Fatalf("expected developer role converted to system, got %#v", normalized[0]["role"])
}
}
func TestNormalizeOpenAIMessagesForPrompt_AssistantArrayContentFallbackWhenTextEmpty(t *testing.T) {
raw := []any{
map[string]any{
"role": "assistant",
"content": []any{
map[string]any{"type": "text", "text": "", "content": "工具说明文本"},
},
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
if len(normalized) != 1 {
t.Fatalf("expected one normalized message, got %d", len(normalized))
}
content, _ := normalized[0]["content"].(string)
if content != "工具说明文本" {
t.Fatalf("expected content fallback text preserved, got %q", content)
} }
} }

View File

@@ -135,7 +135,7 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
} }
} }
func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArguments(t *testing.T) { func TestNormalizeResponsesInputAsMessagesFunctionCallItemPreservesConcatenatedArguments(t *testing.T) {
msgs := normalizeResponsesInputAsMessages([]any{ msgs := normalizeResponsesInputAsMessages([]any{
map[string]any{ map[string]any{
"type": "function_call", "type": "function_call",
@@ -151,8 +151,8 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArg
toolCalls, _ := m["tool_calls"].([]any) toolCalls, _ := m["tool_calls"].([]any)
call, _ := toolCalls[0].(map[string]any) call, _ := toolCalls[0].(map[string]any)
fn, _ := call["function"].(map[string]any) fn, _ := call["function"].(map[string]any)
if fn["arguments"] != `{"q":"golang"}` { if fn["arguments"] != `{}{"q":"golang"}` {
t.Fatalf("expected concatenated call arguments repaired, got %#v", fn["arguments"]) t.Fatalf("expected original concatenated call arguments preserved, got %#v", fn["arguments"])
} }
} }

View File

@@ -113,15 +113,10 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
return return
} }
result := sse.CollectStream(resp, thinkingEnabled, true) result := sse.CollectStream(resp, thinkingEnabled, true)
textParsed := util.ParseToolCallsDetailed(result.Text, toolNames) textParsed := util.ParseStandaloneToolCallsDetailed(result.Text, toolNames)
thinkingParsed := util.ParseToolCallsDetailed(result.Thinking, toolNames)
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text") logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
logResponsesToolPolicyRejection(traceID, toolChoice, thinkingParsed, "thinking")
callCount := len(textParsed.Calls) callCount := len(textParsed.Calls)
if callCount == 0 {
callCount = len(thinkingParsed.Calls)
}
if toolChoice.IsRequired() && callCount == 0 { if toolChoice.IsRequired() && callCount == 0 {
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation") writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
return return

View File

@@ -29,7 +29,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
return nil return nil
} }
return map[string]any{ return map[string]any{
"role": role, "role": normalizeOpenAIRoleForPrompt(role),
"content": content, "content": content,
} }
} }
@@ -51,7 +51,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
role = "user" role = "user"
} }
return map[string]any{ return map[string]any{
"role": role, "role": normalizeOpenAIRoleForPrompt(role),
"content": content, "content": content,
} }
case "function_call_output", "tool_result": case "function_call_output", "tool_result":

View File

@@ -102,16 +102,11 @@ func (s *responsesStreamRuntime) finalize() {
if s.bufferToolContent { if s.bufferToolContent {
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true) s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
} }
textParsed := util.ParseToolCallsDetailed(finalText, s.toolNames) textParsed := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
thinkingParsed := util.ParseToolCallsDetailed(finalThinking, s.toolNames)
detected := textParsed.Calls detected := textParsed.Calls
if len(detected) == 0 { s.logToolPolicyRejections(textParsed)
detected = thinkingParsed.Calls
}
s.logToolPolicyRejections(textParsed, thinkingParsed)
if len(detected) > 0 { if len(detected) > 0 {
s.toolCallsEmitted = true s.toolCallsEmitted = true
@@ -157,7 +152,7 @@ func (s *responsesStreamRuntime) finalize() {
s.sendDone() s.sendDone()
} }
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) { func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed util.ToolCallParseResult) {
logRejected := func(parsed util.ToolCallParseResult, channel string) { logRejected := func(parsed util.ToolCallParseResult, channel string) {
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames) rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)
if !parsed.RejectedByPolicy || len(rejected) == 0 { if !parsed.RejectedByPolicy || len(rejected) == 0 {
@@ -172,7 +167,6 @@ func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingPar
) )
} }
logRejected(textParsed, "text") logRejected(textParsed, "text")
logRejected(thinkingParsed, "thinking")
} }
func (s *responsesStreamRuntime) hasFunctionCallDone() bool { func (s *responsesStreamRuntime) hasFunctionCallDone() bool {
@@ -207,9 +201,6 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
} }
s.thinking.WriteString(p.Text) s.thinking.WriteString(p.Text)
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text)) s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
if s.bufferToolContent {
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
}
continue continue
} }

View File

@@ -94,6 +94,16 @@ func (s *responsesStreamRuntime) closeMessageItem() {
outputIndex := s.ensureMessageOutputIndex() outputIndex := s.ensureMessageOutputIndex()
text := s.visibleText.String() text := s.visibleText.String()
if s.messagePartAdded { if s.messagePartAdded {
s.sendEvent(
"response.output_text.done",
openaifmt.BuildResponsesTextDonePayload(
s.responseID,
itemID,
outputIndex,
0,
text,
),
)
s.sendEvent( s.sendEvent(
"response.content_part.done", "response.content_part.done",
openaifmt.BuildResponsesContentPartDonePayload( openaifmt.BuildResponsesContentPartDonePayload(

View File

@@ -99,9 +99,6 @@ func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
if !strings.Contains(body, "event: response.output_item.done") { if !strings.Contains(body, "event: response.output_item.done") {
t.Fatalf("expected response.output_item.done event, body=%s", body) t.Fatalf("expected response.output_item.done event, body=%s", body)
} }
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
t.Fatalf("expected response.function_call_arguments.delta event, body=%s", body)
}
if !strings.Contains(body, "event: response.function_call_arguments.done") { if !strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("expected response.function_call_arguments.done event, body=%s", body) t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
} }
@@ -229,6 +226,40 @@ func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.
} }
} }
func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine("hello") + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.output_text.done") {
t.Fatalf("expected response.output_text.done payload, body=%s", body)
}
textDoneIdx := strings.Index(body, "event: response.output_text.done")
partDoneIdx := strings.Index(body, "event: response.content_part.done")
if textDoneIdx < 0 || partDoneIdx < 0 {
t.Fatalf("expected output_text.done + content_part.done, body=%s", body)
}
if textDoneIdx > partDoneIdx {
t.Fatalf("expected output_text.done before content_part.done, body=%s", body)
}
}
func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) { func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
h := &Handler{} h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
@@ -266,7 +297,7 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
} }
} }
func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *testing.T) { func TestHandleResponsesStreamThinkingAndMixedToolExampleEmitsFunctionCall(t *testing.T) {
h := &Handler{} h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@@ -291,23 +322,8 @@ func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *tes
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added") addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
if len(addedPayloads) < 2 { if len(addedPayloads) < 1 {
t.Fatalf("expected message + function_call output_item.added events, got %d body=%s", len(addedPayloads), rec.Body.String()) t.Fatalf("expected at least one output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
}
indexes := map[int]struct{}{}
typeByIndex := map[int]string{}
addedIDs := map[string]string{}
for _, payload := range addedPayloads {
item, _ := payload["item"].(map[string]any)
itemType := strings.TrimSpace(asString(item["type"]))
outputIndex := int(asFloat(payload["output_index"]))
if _, exists := indexes[outputIndex]; exists {
t.Fatalf("found duplicated output_index=%d for item types=%q and %q payload=%#v", outputIndex, typeByIndex[outputIndex], itemType, payload)
}
indexes[outputIndex] = struct{}{}
typeByIndex[outputIndex] = itemType
addedIDs[itemType] = strings.TrimSpace(asString(payload["item_id"]))
} }
completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
@@ -316,20 +332,25 @@ func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *tes
} }
responseObj, _ := completedPayload["response"].(map[string]any) responseObj, _ := completedPayload["response"].(map[string]any)
output, _ := responseObj["output"].([]any) output, _ := responseObj["output"].([]any)
found := map[string]bool{} hasMessage := false
hasFunctionCall := false
for _, item := range output { for _, item := range output {
m, _ := item.(map[string]any) m, _ := item.(map[string]any)
itemType := strings.TrimSpace(asString(m["type"])) if m == nil {
itemID := strings.TrimSpace(asString(m["id"]))
if itemType == "" || itemID == "" {
continue continue
} }
if wantID := strings.TrimSpace(addedIDs[itemType]); wantID != "" && wantID == itemID { if asString(m["type"]) == "message" {
found[itemType] = true hasMessage = true
}
if asString(m["type"]) == "function_call" {
hasFunctionCall = true
} }
} }
if !found["message"] || !found["function_call"] { if !hasMessage {
t.Fatalf("expected completed output to contain streamed message/function_call item ids, found=%#v output=%#v", found, output) t.Fatalf("expected message output for mixed prose tool example, output=%#v", output)
}
if !hasFunctionCall {
t.Fatalf("expected function_call output for mixed prose tool example, output=%#v", output)
} }
} }
@@ -360,7 +381,7 @@ func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
} }
} }
func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *testing.T) { func TestHandleResponsesStreamMalformedToolJSONFallsBackToText(t *testing.T) {
h := &Handler{} h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@@ -373,7 +394,7 @@ func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *t
return "data: " + string(b) + "\n" return "data: " + string(b) + "\n"
} }
// invalid JSON (NaN) can still trigger incremental tool deltas before final parse rejects it // invalid JSON (NaN) should remain plain text in strict mode.
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n" streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
resp := &http.Response{ resp := &http.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
@@ -382,14 +403,11 @@ func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *t
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String() body := rec.Body.String()
if !strings.Contains(body, "event: response.function_call_arguments.delta") { if strings.Contains(body, "event: response.function_call_arguments.delta") || strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("expected response.function_call_arguments.delta event for malformed payload, body=%s", body) t.Fatalf("did not expect function_call events for malformed payload in strict mode, body=%s", body)
} }
if !strings.Contains(body, "event: response.function_call_arguments.done") { if !strings.Contains(body, "event: response.output_text.delta") {
t.Fatalf("expected runtime to close in-progress function_call with done event, body=%s", body) t.Fatalf("expected response.output_text.delta for malformed payload, body=%s", body)
}
if !strings.Contains(body, "event: response.output_item.done") {
t.Fatalf("expected runtime to close function output item, body=%s", body)
} }
if !strings.Contains(body, "event: response.completed") { if !strings.Contains(body, "event: response.completed") {
t.Fatalf("expected response.completed event, body=%s", body) t.Fatalf("expected response.completed event, body=%s", body)
@@ -430,6 +448,42 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
} }
} }
func TestHandleResponsesStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(path, value string) string {
b, _ := json.Marshal(map[string]any{
"p": path,
"v": value,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
sseLine("response/content", "plain text only") +
"data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
policy := util.ToolChoicePolicy{
Mode: util.ToolChoiceRequired,
Allowed: map[string]struct{}{"read_file": {}},
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.failed") {
t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
}
if strings.Contains(body, "event: response.completed") {
t.Fatalf("did not expect response.completed after failure, body=%s", body)
}
}
func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) { func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
h := &Handler{} h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
@@ -516,6 +570,33 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
} }
} }
func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
h := &Handler{}
rec := httptest.NewRecorder()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(
`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" +
`data: {"p":"response/content","v":"plain text only"}` + "\n" +
`data: [DONE]` + "\n",
)),
}
policy := util.ToolChoicePolicy{
Mode: util.ToolChoiceRequired,
Allowed: map[string]struct{}{"read_file": {}},
}
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, []string{"read_file"}, policy, "")
if rec.Code != http.StatusUnprocessableEntity {
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
}
out := decodeJSONBody(t, rec.Body.String())
errObj, _ := out["error"].(map[string]any)
if asString(errObj["code"]) != "tool_choice_violation" {
t.Fatalf("expected code=tool_choice_violation, got %#v", out)
}
}
func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
h := &Handler{} h := &Handler{}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()

View File

@@ -53,6 +53,10 @@ func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAut
return m.resp, nil return m.resp, nil
} }
func (m streamStatusDSStub) DeleteAllSessionsForToken(_ context.Context, _ string) error {
return nil
}
func makeOpenAISSEHTTPResponse(lines ...string) *http.Response { func makeOpenAISSEHTTPResponse(lines ...string) *http.Response {
body := strings.Join(lines, "\n") body := strings.Join(lines, "\n")
if !strings.HasSuffix(body, "\n") { if !strings.HasSuffix(body, "\n") {
@@ -168,18 +172,14 @@ func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
} }
outputText, _ := out["output_text"].(string) outputText, _ := out["output_text"].(string)
if outputText != "" { if outputText != "" {
t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText) t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
} }
output, _ := out["output"].([]any) output, _ := out["output"].([]any)
hasFunctionCall := false if len(output) != 1 {
for _, item := range output { t.Fatalf("expected one output item, got %#v", output)
m, _ := item.(map[string]any)
if m != nil && m["type"] == "function_call" {
hasFunctionCall = true
break
}
} }
if !hasFunctionCall { first, _ := output[0].(map[string]any)
if first["type"] != "function_call" {
t.Fatalf("expected function_call output item, got %#v", output) t.Fatalf("expected function_call output item, got %#v", output)
} }
} }

View File

@@ -14,6 +14,11 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
state.pending.WriteString(chunk) state.pending.WriteString(chunk)
} }
events := make([]toolStreamEvent, 0, 2) events := make([]toolStreamEvent, 0, 2)
if len(state.pendingToolCalls) > 0 {
events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
state.pendingToolRaw = ""
state.pendingToolCalls = nil
}
for { for {
if state.capturing { if state.capturing {
@@ -21,32 +26,30 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
state.capture.WriteString(state.pending.String()) state.capture.WriteString(state.pending.String())
state.pending.Reset() state.pending.Reset()
} }
if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
}
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames) prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
if !ready { if !ready {
if state.capture.Len() > toolSieveCaptureLimit {
content := state.capture.String()
state.capture.Reset()
state.capturing = false
state.resetIncrementalToolState()
state.noteText(content)
events = append(events, toolStreamEvent{Content: content})
continue
}
break break
} }
captured := state.capture.String()
state.capture.Reset() state.capture.Reset()
state.capturing = false state.capturing = false
state.resetIncrementalToolState() state.resetIncrementalToolState()
if len(calls) > 0 {
if prefix != "" {
state.noteText(prefix)
events = append(events, toolStreamEvent{Content: prefix})
}
if suffix != "" {
state.pending.WriteString(suffix)
}
_ = captured
state.pendingToolCalls = calls
continue
}
if prefix != "" { if prefix != "" {
state.noteText(prefix) state.noteText(prefix)
events = append(events, toolStreamEvent{Content: prefix}) events = append(events, toolStreamEvent{Content: prefix})
} }
if len(calls) > 0 {
events = append(events, toolStreamEvent{ToolCalls: calls})
}
if suffix != "" { if suffix != "" {
state.pending.WriteString(suffix) state.pending.WriteString(suffix)
} }
@@ -89,6 +92,11 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
return nil return nil
} }
events := processToolSieveChunk(state, "", toolNames) events := processToolSieveChunk(state, "", toolNames)
if len(state.pendingToolCalls) > 0 {
events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
state.pendingToolRaw = ""
state.pendingToolCalls = nil
}
if state.capturing { if state.capturing {
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames) consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
if ready { if ready {
@@ -159,22 +167,22 @@ func findToolSegmentStart(s string) int {
return -1 return -1
} }
lower := strings.ToLower(s) lower := strings.ToLower(s)
offset := 0 keywords := []string{"tool_calls", "function.name:", "[tool_call_history]"}
for { bestKeyIdx := -1
keyRel := strings.Index(lower[offset:], "tool_calls") for _, kw := range keywords {
if keyRel < 0 { idx := strings.Index(lower, kw)
return -1 if idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx) {
bestKeyIdx = idx
} }
keyIdx := offset + keyRel
start := strings.LastIndex(s[:keyIdx], "{")
if start < 0 {
start = keyIdx
}
if !insideCodeFence(s[:start]) {
return start
}
offset = keyIdx + len("tool_calls")
} }
if bestKeyIdx < 0 {
return -1
}
start := strings.LastIndex(s[:bestKeyIdx], "{")
if start < 0 {
start = bestKeyIdx
}
return start
} }
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
@@ -183,13 +191,22 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
return "", nil, "", false return "", nil, "", false
} }
lower := strings.ToLower(captured) lower := strings.ToLower(captured)
keyIdx := strings.Index(lower, "tool_calls")
keyIdx := -1
keywords := []string{"tool_calls", "function.name:", "[tool_call_history]"}
for _, kw := range keywords {
idx := strings.Index(lower, kw)
if idx >= 0 && (keyIdx < 0 || idx < keyIdx) {
keyIdx = idx
}
}
if keyIdx < 0 { if keyIdx < 0 {
return "", nil, "", false return "", nil, "", false
} }
start := strings.LastIndex(captured[:keyIdx], "{") start := strings.LastIndex(captured[:keyIdx], "{")
if start < 0 { if start < 0 {
return "", nil, "", false start = keyIdx
} }
obj, end, ok := extractJSONObjectFrom(captured, start) obj, end, ok := extractJSONObjectFrom(captured, start)
if !ok { if !ok {
@@ -197,9 +214,6 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
} }
prefixPart := captured[:start] prefixPart := captured[:start]
suffixPart := captured[end:] suffixPart := captured[end:]
if insideCodeFence(state.recentTextTail + prefixPart) {
return captured, nil, "", true
}
parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames) parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames)
if len(parsed.Calls) == 0 { if len(parsed.Calls) == 0 {
if parsed.SawToolCallSyntax && parsed.RejectedByPolicy { if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {
@@ -207,6 +221,9 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
// consume it to avoid leaking raw tool_calls JSON to user content. // consume it to avoid leaking raw tool_calls JSON to user content.
return prefixPart, nil, suffixPart, true return prefixPart, nil, suffixPart, true
} }
// If it has obvious keywords but failed to parse even after loose repair,
// we still might want to intercept it if it looks like an attempt at tool call.
// For now, keep the original logic but rely on loose JSON repair.
return captured, nil, "", true return captured, nil, "", true
} }
return prefixPart, parsed.Calls, suffixPart, true return prefixPart, parsed.Calls, suffixPart, true

View File

@@ -19,9 +19,6 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
if start < 0 { if start < 0 {
return nil return nil
} }
if insideCodeFence(state.recentTextTail + captured[:start]) {
return nil
}
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx) certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
if hasMultiple { if hasMultiple {
state.disableDeltas = true state.disableDeltas = true

View File

@@ -7,17 +7,19 @@ import (
) )
type toolStreamSieveState struct { type toolStreamSieveState struct {
pending strings.Builder pending strings.Builder
capture strings.Builder capture strings.Builder
capturing bool capturing bool
recentTextTail string recentTextTail string
disableDeltas bool pendingToolRaw string
toolNameSent bool pendingToolCalls []util.ParsedToolCall
toolName string disableDeltas bool
toolArgsStart int toolNameSent bool
toolArgsSent int toolName string
toolArgsString bool toolArgsStart int
toolArgsDone bool toolArgsSent int
toolArgsString bool
toolArgsDone bool
} }
type toolStreamEvent struct { type toolStreamEvent struct {
@@ -32,7 +34,6 @@ type toolCallDelta struct {
Arguments string Arguments string
} }
const toolSieveCaptureLimit = 8 * 1024
const toolSieveContextTailLimit = 256 const toolSieveContextTailLimit = 256
func (s *toolStreamSieveState) resetIncrementalToolState() { func (s *toolStreamSieveState) resetIncrementalToolState() {

View File

@@ -16,6 +16,7 @@ type ConfigStore interface {
Accounts() []config.Account Accounts() []config.Account
FindAccount(identifier string) (config.Account, bool) FindAccount(identifier string) (config.Account, bool)
UpdateAccountToken(identifier, token string) error UpdateAccountToken(identifier, token string) error
UpdateAccountTestStatus(identifier, status string) error
Update(mutator func(*config.Config) error) error Update(mutator func(*config.Config) error) error
ExportJSONAndBase64() (string, string, error) ExportJSONAndBase64() (string, string, error)
IsEnvBacked() bool IsEnvBacked() bool
@@ -26,6 +27,7 @@ type ConfigStore interface {
RuntimeAccountMaxInflight() int RuntimeAccountMaxInflight() int
RuntimeAccountMaxQueue(defaultSize int) int RuntimeAccountMaxQueue(defaultSize int) int
RuntimeGlobalMaxInflight(defaultSize int) int RuntimeGlobalMaxInflight(defaultSize int) int
AutoDeleteSessions() bool
} }
type PoolController interface { type PoolController interface {
@@ -39,6 +41,8 @@ type DeepSeekCaller interface {
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
GetSessionCountForToken(ctx context.Context, token string) (*deepseek.SessionStats, error)
DeleteAllSessionsForToken(ctx context.Context, token string) error
} }
var _ ConfigStore = (*config.Store)(nil) var _ ConfigStore = (*config.Store)(nil)

View File

@@ -31,6 +31,7 @@ func RegisterRoutes(r chi.Router, h *Handler) {
pr.Get("/queue/status", h.queueStatus) pr.Get("/queue/status", h.queueStatus)
pr.Post("/accounts/test", h.testSingleAccount) pr.Post("/accounts/test", h.testSingleAccount)
pr.Post("/accounts/test-all", h.testAllAccounts) pr.Post("/accounts/test-all", h.testAllAccounts)
pr.Post("/accounts/sessions/delete-all", h.deleteAllSessions)
pr.Post("/import", h.batchImport) pr.Post("/import", h.batchImport)
pr.Post("/test", h.testAPI) pr.Post("/test", h.testAPI)
pr.Post("/vercel/sync", h.syncVercel) pr.Post("/vercel/sync", h.syncVercel)

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"strings" "strings"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
@@ -24,8 +25,21 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
pageSize = 100 pageSize = 100
} }
accounts := h.Store.Snapshot().Accounts accounts := h.Store.Snapshot().Accounts
total := len(accounts)
reverseAccounts(accounts) reverseAccounts(accounts)
q := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("q")))
if q != "" {
filtered := make([]config.Account, 0, len(accounts))
for _, acc := range accounts {
id := strings.ToLower(acc.Identifier())
if strings.Contains(id, q) ||
strings.Contains(strings.ToLower(acc.Email), q) ||
strings.Contains(strings.ToLower(acc.Mobile), q) {
filtered = append(filtered, acc)
}
}
accounts = filtered
}
total := len(accounts)
totalPages := 1 totalPages := 1
if total > 0 { if total > 0 {
totalPages = (total + pageSize - 1) / pageSize totalPages = (total + pageSize - 1) / pageSize
@@ -56,6 +70,7 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
"has_password": acc.Password != "", "has_password": acc.Password != "",
"has_token": token != "", "has_token": token != "",
"token_preview": preview, "token_preview": preview,
"test_status": acc.TestStatus,
}) })
} }
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages}) writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
@@ -70,11 +85,12 @@ func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
return return
} }
err := h.Store.Update(func(c *config.Config) error { err := h.Store.Update(func(c *config.Config) error {
mobileKey := config.CanonicalMobileKey(acc.Mobile)
for _, a := range c.Accounts { for _, a := range c.Accounts {
if acc.Email != "" && a.Email == acc.Email { if acc.Email != "" && a.Email == acc.Email {
return fmt.Errorf("邮箱已存在") return fmt.Errorf("邮箱已存在")
} }
if acc.Mobile != "" && a.Mobile == acc.Mobile { if mobileKey != "" && config.CanonicalMobileKey(a.Mobile) == mobileKey {
return fmt.Errorf("手机号已存在") return fmt.Errorf("手机号已存在")
} }
} }
@@ -91,6 +107,9 @@ func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) { func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
identifier := chi.URLParam(r, "identifier") identifier := chi.URLParam(r, "identifier")
if decoded, err := url.PathUnescape(identifier); err == nil {
identifier = decoded
}
err := h.Store.Update(func(c *config.Config) error { err := h.Store.Update(func(c *config.Config) error {
idx := -1 idx := -1
for i, a := range c.Accounts { for i, a := range c.Accounts {

View File

@@ -1,6 +1,7 @@
package admin package admin
import ( import (
"bytes"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -102,6 +103,45 @@ func TestDeleteAccountSupportsMobileAlias(t *testing.T) {
} }
} }
func TestDeleteAccountSupportsEncodedPlusMobile(t *testing.T) {
h := newAdminTestHandler(t, `{
"accounts":[{"mobile":"+8613800138000","password":"pwd"}]
}`)
r := chi.NewRouter()
r.Delete("/admin/accounts/{identifier}", h.deleteAccount)
req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape("+8613800138000"), nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
}
if got := len(h.Store.Accounts()); got != 0 {
t.Fatalf("expected account removed, remaining=%d", got)
}
}
func TestAddAccountRejectsCanonicalMobileDuplicate(t *testing.T) {
h := newAdminTestHandler(t, `{
"accounts":[{"mobile":"+8613800138000","password":"pwd"}]
}`)
r := chi.NewRouter()
r.Post("/admin/accounts", h.addAccount)
body := []byte(`{"mobile":"13800138000","password":"pwd2"}`)
req := httptest.NewRequest(http.MethodPost, "/admin/accounts", bytes.NewReader(body))
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
}
if got := len(h.Store.Accounts()); got != 1 {
t.Fatalf("expected no duplicate insert, got=%d", got)
}
}
func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) { func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
h := newAdminTestHandler(t, `{ h := newAdminTestHandler(t, `{
"accounts":[ "accounts":[
@@ -117,6 +157,13 @@ func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
if accByMobile.Email != "u@example.com" { if accByMobile.Email != "u@example.com" {
t.Fatalf("unexpected account by mobile: %#v", accByMobile) t.Fatalf("unexpected account by mobile: %#v", accByMobile)
} }
accByMobileWithCountryCode, ok := findAccountByIdentifier(h.Store, "+8613800138000")
if !ok {
t.Fatal("expected find by +86 mobile")
}
if accByMobileWithCountryCode.Email != "u@example.com" {
t.Fatalf("unexpected account by +86 mobile: %#v", accByMobileWithCountryCode)
}
tokenOnlyID := "" tokenOnlyID := ""
for _, acc := range h.Store.Accounts() { for _, acc := range h.Store.Accounts() {

View File

@@ -88,7 +88,15 @@ func runAccountTestsConcurrently(accounts []config.Account, maxConcurrency int,
func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any { func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
start := time.Now() start := time.Now()
result := map[string]any{"account": acc.Identifier(), "success": false, "response_time": 0, "message": "", "model": model} identifier := acc.Identifier()
result := map[string]any{"account": identifier, "success": false, "response_time": 0, "message": "", "model": model, "session_count": 0}
defer func() {
status := "failed"
if ok, _ := result["success"].(bool); ok {
status = "ok"
}
_ = h.Store.UpdateAccountTestStatus(identifier, status)
}()
token := strings.TrimSpace(acc.Token) token := strings.TrimSpace(acc.Token)
if token == "" { if token == "" {
newToken, err := h.DS.Login(ctx, acc) newToken, err := h.DS.Login(ctx, acc)
@@ -116,6 +124,13 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me
return result return result
} }
} }
// 获取会话数量
sessionStats, sessionErr := h.DS.GetSessionCountForToken(ctx, token)
if sessionErr == nil && sessionStats != nil {
result["session_count"] = sessionStats.FirstPageCount
}
if strings.TrimSpace(message) == "" { if strings.TrimSpace(message) == "" {
result["success"] = true result["success"] = true
result["message"] = "API 测试成功(仅会话创建)" result["message"] = "API 测试成功(仅会话创建)"
@@ -202,3 +217,49 @@ func (h *Handler) testAPI(w http.ResponseWriter, r *http.Request) {
} }
writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": string(body)}) writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": string(body)})
} }
func (h *Handler) deleteAllSessions(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
identifier, _ := req["identifier"].(string)
if strings.TrimSpace(identifier) == "" {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识identifier / email / mobile"})
return
}
acc, ok := findAccountByIdentifier(h.Store, identifier)
if !ok {
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
return
}
// 获取 token
token := strings.TrimSpace(acc.Token)
if token == "" {
newToken, err := h.DS.Login(r.Context(), acc)
if err != nil {
writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "登录失败: " + err.Error()})
return
}
token = newToken
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
}
// 删除所有会话
err := h.DS.DeleteAllSessionsForToken(r.Context(), token)
if err != nil {
// token 可能过期,尝试重新登录并重试一次
newToken, loginErr := h.DS.Login(r.Context(), acc)
if loginErr != nil {
writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "删除失败: " + err.Error()})
return
}
token = newToken
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
if retryErr := h.DS.DeleteAllSessionsForToken(r.Context(), token); retryErr != nil {
writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "删除失败: " + retryErr.Error()})
return
}
}
writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "删除成功"})
}

View File

@@ -0,0 +1,134 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/deepseek"
)
type testingDSMock struct {
loginCalls int
createSessionCalls int
getPowCalls int
callCompletionCalls int
deleteAllSessionsCalls int
deleteAllSessionsError error
deleteAllSessionsErrorOnce bool
}
func (m *testingDSMock) Login(_ context.Context, _ config.Account) (string, error) {
m.loginCalls++
return "new-token", nil
}
func (m *testingDSMock) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
m.createSessionCalls++
return "session-id", nil
}
func (m *testingDSMock) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
m.getPowCalls++
return "", errors.New("should not call GetPow in this test")
}
func (m *testingDSMock) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
m.callCompletionCalls++
return nil, errors.New("should not call CallCompletion in this test")
}
func (m *testingDSMock) DeleteAllSessionsForToken(_ context.Context, _ string) error {
m.deleteAllSessionsCalls++
if m.deleteAllSessionsError != nil {
err := m.deleteAllSessionsError
if m.deleteAllSessionsErrorOnce {
m.deleteAllSessionsError = nil
}
return err
}
return nil
}
func (m *testingDSMock) GetSessionCountForToken(_ context.Context, _ string) (*deepseek.SessionStats, error) {
return &deepseek.SessionStats{Success: true}, nil
}
func TestTestAccount_BatchModeOnlyCreatesSession(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"accounts":[{"email":"batch@example.com","password":"pwd","token":""}]}`)
store := config.LoadStore()
ds := &testingDSMock{}
h := &Handler{Store: store, DS: ds}
acc, ok := store.FindAccount("batch@example.com")
if !ok {
t.Fatal("expected test account")
}
result := h.testAccount(context.Background(), acc, "deepseek-chat", "")
if ok, _ := result["success"].(bool); !ok {
t.Fatalf("expected success=true, got %#v", result)
}
msg, _ := result["message"].(string)
if !strings.Contains(msg, "仅会话创建") {
t.Fatalf("expected session-only success message, got %q", msg)
}
if ds.loginCalls != 1 || ds.createSessionCalls != 1 {
t.Fatalf("unexpected Login/CreateSession calls: login=%d createSession=%d", ds.loginCalls, ds.createSessionCalls)
}
if ds.getPowCalls != 0 || ds.callCompletionCalls != 0 {
t.Fatalf("expected no completion flow calls, got getPow=%d callCompletion=%d", ds.getPowCalls, ds.callCompletionCalls)
}
updated, ok := store.FindAccount("batch@example.com")
if !ok {
t.Fatal("expected updated account")
}
if updated.Token != "new-token" {
t.Fatalf("expected refreshed token to be persisted, got %q", updated.Token)
}
if updated.TestStatus != "ok" {
t.Fatalf("expected test status ok, got %q", updated.TestStatus)
}
}
func TestDeleteAllSessions_RetryWithReloginOnDeleteFailure(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"accounts":[{"email":"batch@example.com","password":"pwd","token":"expired-token"}]}`)
store := config.LoadStore()
ds := &testingDSMock{deleteAllSessionsError: errors.New("token expired"), deleteAllSessionsErrorOnce: true}
h := &Handler{Store: store, DS: ds}
req := httptest.NewRequest(http.MethodPost, "/delete-all", bytes.NewBufferString(`{"identifier":"batch@example.com"}`))
rec := httptest.NewRecorder()
h.deleteAllSessions(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", rec.Code)
}
var resp map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if ok, _ := resp["success"].(bool); !ok {
t.Fatalf("expected success response, got %#v", resp)
}
if ds.loginCalls != 1 {
t.Fatalf("expected relogin once, got %d", ds.loginCalls)
}
if ds.deleteAllSessionsCalls != 2 {
t.Fatalf("expected delete called twice, got %d", ds.deleteAllSessionsCalls)
}
updated, ok := store.FindAccount("batch@example.com")
if !ok {
t.Fatal("expected account")
}
if updated.Token != "new-token" {
t.Fatalf("expected refreshed token persisted, got %q", updated.Token)
}
}

View File

@@ -49,6 +49,7 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
next := c.Clone() next := c.Clone()
if mode == "replace" { if mode == "replace" {
next = incoming.Clone() next = incoming.Clone()
next.Accounts = normalizeAndDedupeAccounts(next.Accounts)
next.VercelSyncHash = c.VercelSyncHash next.VercelSyncHash = c.VercelSyncHash
next.VercelSyncTime = c.VercelSyncTime next.VercelSyncTime = c.VercelSyncTime
importedKeys = len(next.Keys) importedKeys = len(next.Keys)
@@ -73,17 +74,22 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
existingAccounts := map[string]struct{}{} existingAccounts := map[string]struct{}{}
for _, acc := range next.Accounts { for _, acc := range next.Accounts {
existingAccounts[acc.Identifier()] = struct{}{} acc = normalizeAccountForStorage(acc)
key := accountDedupeKey(acc)
if key != "" {
existingAccounts[key] = struct{}{}
}
} }
for _, acc := range incoming.Accounts { for _, acc := range incoming.Accounts {
id := acc.Identifier() acc = normalizeAccountForStorage(acc)
if id == "" { key := accountDedupeKey(acc)
if key == "" {
continue continue
} }
if _, ok := existingAccounts[id]; ok { if _, ok := existingAccounts[key]; ok {
continue continue
} }
existingAccounts[id] = struct{}{} existingAccounts[key] = struct{}{}
next.Accounts = append(next.Accounts, acc) next.Accounts = append(next.Accounts, acc)
importedAccounts++ importedAccounts++
} }

View File

@@ -25,17 +25,28 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
if accountsRaw, ok := req["accounts"].([]any); ok { if accountsRaw, ok := req["accounts"].([]any); ok {
existing := map[string]config.Account{} existing := map[string]config.Account{}
for _, a := range old.Accounts { for _, a := range old.Accounts {
existing[a.Identifier()] = a a = normalizeAccountForStorage(a)
key := accountDedupeKey(a)
if key != "" {
existing[key] = a
}
} }
seen := map[string]struct{}{}
accounts := make([]config.Account, 0, len(accountsRaw)) accounts := make([]config.Account, 0, len(accountsRaw))
for _, item := range accountsRaw { for _, item := range accountsRaw {
m, ok := item.(map[string]any) m, ok := item.(map[string]any)
if !ok { if !ok {
continue continue
} }
acc := toAccount(m) acc := normalizeAccountForStorage(toAccount(m))
id := acc.Identifier() key := accountDedupeKey(acc)
if prev, ok := existing[id]; ok { if key == "" {
continue
}
if _, ok := seen[key]; ok {
continue
}
if prev, ok := existing[key]; ok {
if strings.TrimSpace(acc.Password) == "" { if strings.TrimSpace(acc.Password) == "" {
acc.Password = prev.Password acc.Password = prev.Password
} }
@@ -43,6 +54,7 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
acc.Token = prev.Token acc.Token = prev.Token
} }
} }
seen[key] = struct{}{}
accounts = append(accounts, acc) accounts = append(accounts, acc)
} }
c.Accounts = accounts c.Accounts = accounts
@@ -138,20 +150,24 @@ func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
if accounts, ok := req["accounts"].([]any); ok { if accounts, ok := req["accounts"].([]any); ok {
existing := map[string]bool{} existing := map[string]bool{}
for _, a := range c.Accounts { for _, a := range c.Accounts {
existing[a.Identifier()] = true a = normalizeAccountForStorage(a)
key := accountDedupeKey(a)
if key != "" {
existing[key] = true
}
} }
for _, item := range accounts { for _, item := range accounts {
m, ok := item.(map[string]any) m, ok := item.(map[string]any)
if !ok { if !ok {
continue continue
} }
acc := toAccount(m) acc := normalizeAccountForStorage(toAccount(m))
id := acc.Identifier() key := accountDedupeKey(acc)
if id == "" || existing[id] { if key == "" || existing[key] {
continue continue
} }
c.Accounts = append(c.Accounts, acc) c.Accounts = append(c.Accounts, acc)
existing[id] = true existing[key] = true
importedAccounts++ importedAccounts++
} }
} }

View File

@@ -7,15 +7,30 @@ import (
"ds2api/internal/config" "ds2api/internal/config"
) )
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) { func boolFrom(v any) bool {
if v == nil {
return false
}
switch x := v.(type) {
case bool:
return x
case string:
return strings.ToLower(strings.TrimSpace(x)) == "true"
default:
return false
}
}
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, *config.AutoDeleteConfig, map[string]string, map[string]string, error) {
var ( var (
adminCfg *config.AdminConfig adminCfg *config.AdminConfig
runtimeCfg *config.RuntimeConfig runtimeCfg *config.RuntimeConfig
toolcallCfg *config.ToolcallConfig toolcallCfg *config.ToolcallConfig
respCfg *config.ResponsesConfig respCfg *config.ResponsesConfig
embCfg *config.EmbeddingsConfig embCfg *config.EmbeddingsConfig
claudeMap map[string]string autoDeleteCfg *config.AutoDeleteConfig
aliasMap map[string]string claudeMap map[string]string
aliasMap map[string]string
) )
if raw, ok := req["admin"].(map[string]any); ok { if raw, ok := req["admin"].(map[string]any); ok {
@@ -23,7 +38,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
if v, exists := raw["jwt_expire_hours"]; exists { if v, exists := raw["jwt_expire_hours"]; exists {
n := intFrom(v) n := intFrom(v)
if n < 1 || n > 720 { if n < 1 || n > 720 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720") return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
} }
cfg.JWTExpireHours = n cfg.JWTExpireHours = n
} }
@@ -35,26 +50,26 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
if v, exists := raw["account_max_inflight"]; exists { if v, exists := raw["account_max_inflight"]; exists {
n := intFrom(v) n := intFrom(v)
if n < 1 || n > 256 { if n < 1 || n > 256 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256") return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
} }
cfg.AccountMaxInflight = n cfg.AccountMaxInflight = n
} }
if v, exists := raw["account_max_queue"]; exists { if v, exists := raw["account_max_queue"]; exists {
n := intFrom(v) n := intFrom(v)
if n < 1 || n > 200000 { if n < 1 || n > 200000 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000") return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
} }
cfg.AccountMaxQueue = n cfg.AccountMaxQueue = n
} }
if v, exists := raw["global_max_inflight"]; exists { if v, exists := raw["global_max_inflight"]; exists {
n := intFrom(v) n := intFrom(v)
if n < 1 || n > 200000 { if n < 1 || n > 200000 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000") return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
} }
cfg.GlobalMaxInflight = n cfg.GlobalMaxInflight = n
} }
if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight { if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
} }
runtimeCfg = cfg runtimeCfg = cfg
} }
@@ -67,7 +82,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
case "feature_match", "off": case "feature_match", "off":
cfg.Mode = mode cfg.Mode = mode
default: default:
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off") return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
} }
} }
if v, exists := raw["early_emit_confidence"]; exists { if v, exists := raw["early_emit_confidence"]; exists {
@@ -76,7 +91,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
case "high", "low", "off": case "high", "low", "off":
cfg.EarlyEmitConfidence = level cfg.EarlyEmitConfidence = level
default: default:
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off") return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
} }
} }
toolcallCfg = cfg toolcallCfg = cfg
@@ -87,7 +102,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
if v, exists := raw["store_ttl_seconds"]; exists { if v, exists := raw["store_ttl_seconds"]; exists {
n := intFrom(v) n := intFrom(v)
if n < 30 || n > 86400 { if n < 30 || n > 86400 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400") return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
} }
cfg.StoreTTLSeconds = n cfg.StoreTTLSeconds = n
} }
@@ -98,9 +113,6 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
cfg := &config.EmbeddingsConfig{} cfg := &config.EmbeddingsConfig{}
if v, exists := raw["provider"]; exists { if v, exists := raw["provider"]; exists {
p := strings.TrimSpace(fmt.Sprintf("%v", v)) p := strings.TrimSpace(fmt.Sprintf("%v", v))
if p == "" {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty")
}
cfg.Provider = p cfg.Provider = p
} }
embCfg = cfg embCfg = cfg
@@ -130,5 +142,13 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
} }
} }
return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil if raw, ok := req["auto_delete"].(map[string]any); ok {
cfg := &config.AutoDeleteConfig{}
if v, exists := raw["sessions"]; exists {
cfg.Sessions = boolFrom(v)
}
autoDeleteCfg = cfg
}
return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, autoDeleteCfg, claudeMap, aliasMap, nil
} }

View File

@@ -28,6 +28,7 @@ func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
"toolcall": snap.Toolcall, "toolcall": snap.Toolcall,
"responses": snap.Responses, "responses": snap.Responses,
"embeddings": snap.Embeddings, "embeddings": snap.Embeddings,
"auto_delete": snap.AutoDelete,
"claude_mapping": settingsClaudeMapping(snap), "claude_mapping": settingsClaudeMapping(snap),
"model_aliases": snap.ModelAliases, "model_aliases": snap.ModelAliases,
"env_backed": h.Store.IsEnvBacked(), "env_backed": h.Store.IsEnvBacked(),

View File

@@ -265,3 +265,57 @@ func TestConfigImportRejectsMergedRuntimeConflict(t *testing.T) {
t.Fatalf("runtime should remain unchanged, runtime=%+v", snap.Runtime) t.Fatalf("runtime should remain unchanged, runtime=%+v", snap.Runtime)
} }
} }
func TestConfigImportMergeDedupesMobileAliases(t *testing.T) {
h := newAdminTestHandler(t, `{
"keys":["k1"],
"accounts":[{"mobile":"+8613800138000","password":"p1"}]
}`)
merge := map[string]any{
"mode": "merge",
"config": map[string]any{
"accounts": []any{
map[string]any{"mobile": "13800138000", "password": "p2"},
},
},
}
b, _ := json.Marshal(merge)
req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(b))
rec := httptest.NewRecorder()
h.configImport(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
}
if got := len(h.Store.Accounts()); got != 1 {
t.Fatalf("expected merge dedupe by canonical mobile, got=%d", got)
}
}
func TestUpdateConfigDedupesMobileAliases(t *testing.T) {
h := newAdminTestHandler(t, `{
"keys":["k1"],
"accounts":[{"mobile":"+8613800138000","password":"old"}]
}`)
reqBody := map[string]any{
"accounts": []any{
map[string]any{"mobile": "+8613800138000"},
map[string]any{"mobile": "13800138000"},
},
}
b, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/admin/config", bytes.NewReader(b))
rec := httptest.NewRecorder()
h.updateConfig(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
}
accounts := h.Store.Accounts()
if len(accounts) != 1 {
t.Fatalf("expected update dedupe by canonical mobile, got=%d", len(accounts))
}
if accounts[0].Identifier() != "+8613800138000" {
t.Fatalf("unexpected identifier: %q", accounts[0].Identifier())
}
}

View File

@@ -17,7 +17,7 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
return return
} }
adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req) adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, autoDeleteCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
if err != nil { if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return return
@@ -60,6 +60,9 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" { if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" {
c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider) c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider)
} }
if autoDeleteCfg != nil {
c.AutoDelete.Sessions = autoDeleteCfg.Sessions
}
if claudeMap != nil { if claudeMap != nil {
c.ClaudeMapping = claudeMap c.ClaudeMapping = claudeMap
c.ClaudeModelMap = nil c.ClaudeModelMap = nil

View File

@@ -59,9 +59,11 @@ func toStringSlice(v any) ([]string, bool) {
} }
func toAccount(m map[string]any) config.Account { func toAccount(m map[string]any) config.Account {
email := fieldString(m, "email")
mobile := config.NormalizeMobileForStorage(fieldString(m, "mobile"))
return config.Account{ return config.Account{
Email: fieldString(m, "email"), Email: email,
Mobile: fieldString(m, "mobile"), Mobile: mobile,
Password: fieldString(m, "password"), Password: fieldString(m, "password"),
Token: fieldString(m, "token"), Token: fieldString(m, "token"),
} }
@@ -90,12 +92,52 @@ func accountMatchesIdentifier(acc config.Account, identifier string) bool {
if strings.TrimSpace(acc.Email) == id { if strings.TrimSpace(acc.Email) == id {
return true return true
} }
if strings.TrimSpace(acc.Mobile) == id { if mobileKey := config.CanonicalMobileKey(id); mobileKey != "" && mobileKey == config.CanonicalMobileKey(acc.Mobile) {
return true return true
} }
return acc.Identifier() == id return acc.Identifier() == id
} }
func normalizeAccountForStorage(acc config.Account) config.Account {
acc.Email = strings.TrimSpace(acc.Email)
acc.Mobile = config.NormalizeMobileForStorage(acc.Mobile)
return acc
}
func accountDedupeKey(acc config.Account) string {
if email := strings.TrimSpace(acc.Email); email != "" {
return "email:" + email
}
if mobile := config.CanonicalMobileKey(acc.Mobile); mobile != "" {
return "mobile:" + mobile
}
if id := strings.TrimSpace(acc.Identifier()); id != "" {
return "id:" + id
}
return ""
}
func normalizeAndDedupeAccounts(accounts []config.Account) []config.Account {
if len(accounts) == 0 {
return nil
}
out := make([]config.Account, 0, len(accounts))
seen := make(map[string]struct{}, len(accounts))
for _, acc := range accounts {
acc = normalizeAccountForStorage(acc)
key := accountDedupeKey(acc)
if key == "" {
continue
}
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, acc)
}
return out
}
func findAccountByIdentifier(store ConfigStore, identifier string) (config.Account, bool) { func findAccountByIdentifier(store ConfigStore, identifier string) (config.Account, bool) {
id := strings.TrimSpace(identifier) id := strings.TrimSpace(identifier)
if id == "" { if id == "" {

View File

@@ -182,7 +182,7 @@ func TestToAccountAllFields(t *testing.T) {
if acc.Email != "user@test.com" { if acc.Email != "user@test.com" {
t.Fatalf("unexpected email: %q", acc.Email) t.Fatalf("unexpected email: %q", acc.Email)
} }
if acc.Mobile != "13800138000" { if acc.Mobile != "+8613800138000" {
t.Fatalf("unexpected mobile: %q", acc.Mobile) t.Fatalf("unexpected mobile: %q", acc.Mobile)
} }
if acc.Password != "secret" { if acc.Password != "secret" {

View File

@@ -5,6 +5,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"strings"
"testing" "testing"
"ds2api/internal/sse" "ds2api/internal/sse"
@@ -67,20 +68,36 @@ func TestGoCompatToolcallFixtures(t *testing.T) {
var fixture struct { var fixture struct {
Text string `json:"text"` Text string `json:"text"`
ToolNames []string `json:"tool_names"` ToolNames []string `json:"tool_names"`
Mode string `json:"mode"`
} }
mustLoadJSON(t, fixturePath, &fixture) mustLoadJSON(t, fixturePath, &fixture)
var expected struct { var expected struct {
Calls []util.ParsedToolCall `json:"calls"` Calls []util.ParsedToolCall `json:"calls"`
SawToolCallSyntax bool `json:"sawToolCallSyntax"`
RejectedByPolicy bool `json:"rejectedByPolicy"`
RejectedToolNames []string `json:"rejectedToolNames"`
} }
mustLoadJSON(t, expectedPath, &expected) mustLoadJSON(t, expectedPath, &expected)
got := util.ParseToolCalls(fixture.Text, fixture.ToolNames) var got util.ToolCallParseResult
if len(got) == 0 && len(expected.Calls) == 0 { switch strings.ToLower(strings.TrimSpace(fixture.Mode)) {
continue case "standalone":
got = util.ParseStandaloneToolCallsDetailed(fixture.Text, fixture.ToolNames)
default:
got = util.ParseToolCallsDetailed(fixture.Text, fixture.ToolNames)
} }
if !reflect.DeepEqual(got, expected.Calls) { if got.Calls == nil {
t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected.Calls) got.Calls = []util.ParsedToolCall{}
}
if got.RejectedToolNames == nil {
got.RejectedToolNames = []string{}
}
if !reflect.DeepEqual(got.Calls, expected.Calls) ||
got.SawToolCallSyntax != expected.SawToolCallSyntax ||
got.RejectedByPolicy != expected.RejectedByPolicy ||
!reflect.DeepEqual(got.RejectedToolNames, expected.RejectedToolNames) {
t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected)
} }
} }
} }

View File

@@ -10,8 +10,8 @@ func (a Account) Identifier() string {
if strings.TrimSpace(a.Email) != "" { if strings.TrimSpace(a.Email) != "" {
return strings.TrimSpace(a.Email) return strings.TrimSpace(a.Email)
} }
if strings.TrimSpace(a.Mobile) != "" { if mobile := NormalizeMobileForStorage(a.Mobile); mobile != "" {
return strings.TrimSpace(a.Mobile) return mobile
} }
// Backward compatibility: old configs may contain token-only accounts. // Backward compatibility: old configs may contain token-only accounts.
// Use a stable non-sensitive synthetic id so they can still join the pool. // Use a stable non-sensitive synthetic id so they can still join the pool.

View File

@@ -47,6 +47,7 @@ func (c Config) MarshalJSON() ([]byte, error) {
if strings.TrimSpace(c.Embeddings.Provider) != "" { if strings.TrimSpace(c.Embeddings.Provider) != "" {
m["embeddings"] = c.Embeddings m["embeddings"] = c.Embeddings
} }
m["auto_delete"] = c.AutoDelete
if c.VercelSyncHash != "" { if c.VercelSyncHash != "" {
m["_vercel_sync_hash"] = c.VercelSyncHash m["_vercel_sync_hash"] = c.VercelSyncHash
} }
@@ -108,6 +109,10 @@ func (c *Config) UnmarshalJSON(b []byte) error {
if err := json.Unmarshal(v, &c.Embeddings); err != nil { if err := json.Unmarshal(v, &c.Embeddings); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err) return fmt.Errorf("invalid field %q: %w", k, err)
} }
case "auto_delete":
if err := json.Unmarshal(v, &c.AutoDelete); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "_vercel_sync_hash": case "_vercel_sync_hash":
if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil { if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err) return fmt.Errorf("invalid field %q: %w", k, err)
@@ -141,6 +146,7 @@ func (c Config) Clone() Config {
Toolcall: c.Toolcall, Toolcall: c.Toolcall,
Responses: c.Responses, Responses: c.Responses,
Embeddings: c.Embeddings, Embeddings: c.Embeddings,
AutoDelete: c.AutoDelete,
VercelSyncHash: c.VercelSyncHash, VercelSyncHash: c.VercelSyncHash,
VercelSyncTime: c.VercelSyncTime, VercelSyncTime: c.VercelSyncTime,
AdditionalFields: map[string]any{}, AdditionalFields: map[string]any{},

View File

@@ -12,16 +12,18 @@ type Config struct {
Toolcall ToolcallConfig `json:"toolcall,omitempty"` Toolcall ToolcallConfig `json:"toolcall,omitempty"`
Responses ResponsesConfig `json:"responses,omitempty"` Responses ResponsesConfig `json:"responses,omitempty"`
Embeddings EmbeddingsConfig `json:"embeddings,omitempty"` Embeddings EmbeddingsConfig `json:"embeddings,omitempty"`
VercelSyncHash string `json:"_vercel_sync_hash,omitempty"` AutoDelete AutoDeleteConfig `json:"auto_delete"`
VercelSyncHash string `json:"_vercel_sync_hash,omitempty"`
VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"` VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"`
AdditionalFields map[string]any `json:"-"` AdditionalFields map[string]any `json:"-"`
} }
type Account struct { type Account struct {
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
Mobile string `json:"mobile,omitempty"` Mobile string `json:"mobile,omitempty"`
Password string `json:"password,omitempty"` Password string `json:"password,omitempty"`
Token string `json:"token,omitempty"` Token string `json:"token,omitempty"`
TestStatus string `json:"test_status,omitempty"`
} }
type CompatConfig struct { type CompatConfig struct {
@@ -52,3 +54,7 @@ type ResponsesConfig struct {
type EmbeddingsConfig struct { type EmbeddingsConfig struct {
Provider string `json:"provider,omitempty"` Provider string `json:"provider,omitempty"`
} }
type AutoDeleteConfig struct {
Sessions bool `json:"sessions"`
}

View File

@@ -202,7 +202,7 @@ func TestConfigCloneNilMaps(t *testing.T) {
func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) { func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) {
acc := Account{Mobile: "13800138000", Token: "tok"} acc := Account{Mobile: "13800138000", Token: "tok"}
if acc.Identifier() != "13800138000" { if acc.Identifier() != "+8613800138000" {
t.Fatalf("expected mobile identifier, got %q", acc.Identifier()) t.Fatalf("expected mobile identifier, got %q", acc.Identifier())
} }
} }

82
internal/config/mobile.go Normal file
View File

@@ -0,0 +1,82 @@
package config
import "strings"
// NormalizeMobileForStorage normalizes user input to a stable storage format.
// It keeps existing country codes and auto-prefixes mainland China numbers with +86.
func NormalizeMobileForStorage(raw string) string {
digits, hasPlus := extractMobileDigits(raw)
if digits == "" {
return ""
}
if hasPlus {
return "+" + digits
}
if isChinaMobileWithCountryCode(digits) {
return "+86" + digits[2:]
}
if isChinaMainlandMobileDigits(digits) {
return "+86" + digits
}
// For non-China numbers without a leading +, preserve semantics by adding it.
return "+" + digits
}
// CanonicalMobileKey returns the comparison key used by dedupe/matching logic.
func CanonicalMobileKey(raw string) string {
return NormalizeMobileForStorage(raw)
}
func extractMobileDigits(raw string) (digits string, hasPlus bool) {
s := strings.TrimSpace(raw)
if s == "" {
return "", false
}
for _, r := range s {
switch {
case r >= '0' && r <= '9':
goto collect
case isMobileSeparator(r):
continue
case r == '+':
hasPlus = true
goto collect
default:
goto collect
}
}
collect:
var b strings.Builder
b.Grow(len(s))
for _, r := range s {
if r >= '0' && r <= '9' {
b.WriteRune(r)
}
}
return b.String(), hasPlus
}
func isChinaMainlandMobileDigits(digits string) bool {
if len(digits) != 11 || digits[0] != '1' {
return false
}
return digits[1] >= '3' && digits[1] <= '9'
}
func isChinaMobileWithCountryCode(digits string) bool {
if len(digits) != 13 || !strings.HasPrefix(digits, "86") {
return false
}
return isChinaMainlandMobileDigits(digits[2:])
}
func isMobileSeparator(r rune) bool {
switch r {
case ' ', '\t', '\n', '\r', '-', '(', ')', '.', '/':
return true
default:
return false
}
}

View File

@@ -0,0 +1,36 @@
package config
import "testing"
func TestNormalizeMobileForStorageChinaMainlandAddsPlus86(t *testing.T) {
if got := NormalizeMobileForStorage("13800138000"); got != "+8613800138000" {
t.Fatalf("got %q", got)
}
}
func TestNormalizeMobileForStorageChinaWithCountryCode(t *testing.T) {
if got := NormalizeMobileForStorage("8613800138000"); got != "+8613800138000" {
t.Fatalf("got %q", got)
}
}
func TestNormalizeMobileForStorageKeepsExistingCountryCode(t *testing.T) {
if got := NormalizeMobileForStorage(" +1 (415) 555-2671 "); got != "+14155552671" {
t.Fatalf("got %q", got)
}
}
func TestCanonicalMobileKeyMatchesChinaAliases(t *testing.T) {
a := CanonicalMobileKey("+8613800138000")
b := CanonicalMobileKey("13800138000")
c := CanonicalMobileKey("86 13800138000")
if a == "" || a != b || b != c {
t.Fatalf("alias mismatch: a=%q b=%q c=%q", a, b, c)
}
}
func TestCanonicalMobileKeyEmptyForInvalidInput(t *testing.T) {
if got := CanonicalMobileKey("() --"); got != "" {
t.Fatalf("got %q", got)
}
}

View File

@@ -97,6 +97,18 @@ func (s *Store) FindAccount(identifier string) (Account, bool) {
return Account{}, false return Account{}, false
} }
func (s *Store) UpdateAccountTestStatus(identifier, status string) error {
identifier = strings.TrimSpace(identifier)
s.mu.Lock()
defer s.mu.Unlock()
idx, ok := s.findAccountIndexLocked(identifier)
if !ok {
return errors.New("account not found")
}
s.cfg.Accounts[idx].TestStatus = status
return s.saveLocked()
}
func (s *Store) UpdateAccountToken(identifier, token string) error { func (s *Store) UpdateAccountToken(identifier, token string) error {
identifier = strings.TrimSpace(identifier) identifier = strings.TrimSpace(identifier)
s.mu.Lock() s.mu.Lock()

View File

@@ -165,3 +165,9 @@ func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int {
} }
return defaultSize return defaultSize
} }
func (s *Store) AutoDeleteSessions() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.cfg.AutoDelete.Sessions
}

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"unicode"
"ds2api/internal/auth" "ds2api/internal/auth"
"ds2api/internal/config" "ds2api/internal/config"
@@ -20,8 +21,9 @@ func (c *Client) Login(ctx context.Context, acc config.Account) (string, error)
if email := strings.TrimSpace(acc.Email); email != "" { if email := strings.TrimSpace(acc.Email); email != "" {
payload["email"] = email payload["email"] = email
} else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" { } else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" {
payload["mobile"] = mobile loginMobile, areaCode := normalizeMobileForLogin(mobile)
payload["area_code"] = nil payload["mobile"] = loginMobile
payload["area_code"] = areaCode
} else { } else {
return "", errors.New("missing email/mobile") return "", errors.New("missing email/mobile")
} }
@@ -60,8 +62,8 @@ func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAtte
attempts++ attempts++
continue continue
} }
code := intFrom(resp["code"]) code, bizCode, msg, bizMsg := extractResponseStatus(resp)
if status == http.StatusOK && code == 0 { if status == http.StatusOK && code == 0 && bizCode == 0 {
data, _ := resp["data"].(map[string]any) data, _ := resp["data"].(map[string]any)
bizData, _ := data["biz_data"].(map[string]any) bizData, _ := data["biz_data"].(map[string]any)
sessionID, _ := bizData["id"].(string) sessionID, _ := bizData["id"].(string)
@@ -69,10 +71,9 @@ func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAtte
return sessionID, nil return sessionID, nil
} }
} }
msg, _ := resp["msg"].(string) config.Logger.Warn("[create_session] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
if a.UseConfigToken { if a.UseConfigToken {
if isTokenInvalid(status, code, msg) && !refreshed { if isTokenInvalid(status, code, bizCode, msg, bizMsg) && !refreshed {
if c.Auth.RefreshToken(ctx, a) { if c.Auth.RefreshToken(ctx, a) {
refreshed = true refreshed = true
continue continue
@@ -94,6 +95,7 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
maxAttempts = c.maxRetries maxAttempts = c.maxRetries
} }
attempts := 0 attempts := 0
refreshed := false
for attempts < maxAttempts { for attempts < maxAttempts {
headers := c.authHeaders(a.DeepSeekToken) headers := c.authHeaders(a.DeepSeekToken)
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"}) resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
@@ -102,8 +104,8 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
attempts++ attempts++
continue continue
} }
code := intFrom(resp["code"]) code, bizCode, msg, bizMsg := extractResponseStatus(resp)
if status == http.StatusOK && code == 0 { if status == http.StatusOK && code == 0 && bizCode == 0 {
data, _ := resp["data"].(map[string]any) data, _ := resp["data"].(map[string]any)
bizData, _ := data["biz_data"].(map[string]any) bizData, _ := data["biz_data"].(map[string]any)
challenge, _ := bizData["challenge"].(map[string]any) challenge, _ := bizData["challenge"].(map[string]any)
@@ -114,15 +116,16 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
} }
return BuildPowHeader(challenge, answer) return BuildPowHeader(challenge, answer)
} }
msg, _ := resp["msg"].(string) config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
if a.UseConfigToken { if a.UseConfigToken {
if isTokenInvalid(status, code, msg) { if isTokenInvalid(status, code, bizCode, msg, bizMsg) && !refreshed {
if c.Auth.RefreshToken(ctx, a) { if c.Auth.RefreshToken(ctx, a) {
refreshed = true
continue continue
} }
} }
if c.Auth.SwitchAccount(ctx, a) { if c.Auth.SwitchAccount(ctx, a) {
refreshed = false
attempts++ attempts++
continue continue
} }
@@ -141,13 +144,55 @@ func (c *Client) authHeaders(token string) map[string]string {
return headers return headers
} }
func isTokenInvalid(status int, code int, msg string) bool { func isTokenInvalid(status int, code int, bizCode int, msg string, bizMsg string) bool {
msg = strings.ToLower(msg) msg = strings.ToLower(strings.TrimSpace(msg) + " " + strings.TrimSpace(bizMsg))
if status == http.StatusUnauthorized || status == http.StatusForbidden { if status == http.StatusUnauthorized || status == http.StatusForbidden {
return true return true
} }
if code == 40001 || code == 40002 || code == 40003 { if code == 40001 || code == 40002 || code == 40003 || bizCode == 40001 || bizCode == 40002 || bizCode == 40003 {
return true return true
} }
return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized") return strings.Contains(msg, "token") ||
strings.Contains(msg, "unauthorized") ||
strings.Contains(msg, "expired") ||
strings.Contains(msg, "not login") ||
strings.Contains(msg, "login required") ||
strings.Contains(msg, "invalid jwt")
}
func extractResponseStatus(resp map[string]any) (code int, bizCode int, msg string, bizMsg string) {
code = intFrom(resp["code"])
msg, _ = resp["msg"].(string)
data, _ := resp["data"].(map[string]any)
bizCode = intFrom(data["biz_code"])
bizMsg, _ = data["biz_msg"].(string)
if strings.TrimSpace(bizMsg) == "" {
if bizData, ok := data["biz_data"].(map[string]any); ok {
bizMsg, _ = bizData["msg"].(string)
}
}
return code, bizCode, msg, bizMsg
}
func normalizeMobileForLogin(raw string) (mobile string, areaCode any) {
s := strings.TrimSpace(raw)
if s == "" {
return "", nil
}
hasPlus := strings.HasPrefix(s, "+")
var b strings.Builder
b.Grow(len(s))
for _, r := range s {
if unicode.IsDigit(r) {
b.WriteRune(r)
}
}
digits := b.String()
if digits == "" {
return "", nil
}
if (hasPlus || strings.HasPrefix(digits, "86")) && strings.HasPrefix(digits, "86") && len(digits) == 13 {
return digits[2:], nil
}
return digits, nil
} }

View File

@@ -0,0 +1,33 @@
package deepseek
import "testing"
func TestNormalizeMobileForLogin_ChinaWithPlus86(t *testing.T) {
mobile, areaCode := normalizeMobileForLogin("+8613800138000")
if mobile != "13800138000" {
t.Fatalf("unexpected mobile: %q", mobile)
}
if areaCode != nil {
t.Fatalf("expected nil areaCode, got %#v", areaCode)
}
}
func TestNormalizeMobileForLogin_ChinaWith86Prefix(t *testing.T) {
mobile, areaCode := normalizeMobileForLogin("8613800138000")
if mobile != "13800138000" {
t.Fatalf("unexpected mobile: %q", mobile)
}
if areaCode != nil {
t.Fatalf("expected nil areaCode, got %#v", areaCode)
}
}
func TestNormalizeMobileForLogin_KeepPlainDigits(t *testing.T) {
mobile, areaCode := normalizeMobileForLogin("13800138000")
if mobile != "13800138000" {
t.Fatalf("unexpected mobile: %q", mobile)
}
if areaCode != nil {
t.Fatalf("expected nil areaCode, got %#v", areaCode)
}
}

View File

@@ -62,3 +62,51 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url st
} }
return out, resp.StatusCode, nil return out, resp.StatusCode, nil
} }
func (c *Client) getJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string) (map[string]any, error) {
body, status, err := c.getJSONWithStatus(ctx, doer, url, headers)
if err != nil {
return nil, err
}
if status == 0 {
return nil, errors.New("request failed")
}
return body, nil
}
func (c *Client) getJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string) (map[string]any, int, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, 0, err
}
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := doer.Do(req)
if err != nil {
config.Logger.Warn("[deepseek] fingerprint GET request failed, fallback to std transport", "url", url, "error", err)
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if reqErr != nil {
return nil, 0, err
}
for k, v := range headers {
req2.Header.Set(k, v)
}
resp, err = c.fallback.Do(req2)
if err != nil {
return nil, 0, err
}
}
defer resp.Body.Close()
payloadBytes, err := readResponseBody(resp)
if err != nil {
return nil, resp.StatusCode, err
}
out := map[string]any{}
if len(payloadBytes) > 0 {
if err := json.Unmarshal(payloadBytes, &out); err != nil {
config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes))
}
}
return out, resp.StatusCode, nil
}

View File

@@ -0,0 +1,256 @@
package deepseek
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"ds2api/internal/auth"
"ds2api/internal/config"
)
// SessionInfo 会话信息
type SessionInfo struct {
ID string `json:"id"`
Title string `json:"title"`
TitleType string `json:"title_type"`
Pinned bool `json:"pinned"`
UpdatedAt float64 `json:"updated_at"`
}
// SessionStats 会话统计结果
type SessionStats struct {
AccountID string // 账号标识 (email 或 mobile)
FirstPageCount int // 第一页会话数量(当 HasMore 为 true 时,真实总数可能更大)
PinnedCount int // 置顶会话数量
HasMore bool // 是否还有更多页
Success bool // 请求是否成功
ErrorMessage string // 错误信息
}
// GetSessionCount 获取单个账号的会话数量
func (c *Client) GetSessionCount(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (*SessionStats, error) {
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
stats := &SessionStats{
AccountID: a.AccountID,
}
attempts := 0
refreshed := false
for attempts < maxAttempts {
headers := c.authHeaders(a.DeepSeekToken)
// 构建请求 URL
reqURL := DeepSeekFetchSessionURL + "?lte_cursor.pinned=false"
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
if err != nil {
config.Logger.Warn("[get_session_count] request error", "error", err, "account", a.AccountID)
attempts++
continue
}
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
if status == http.StatusOK && code == 0 && bizCode == 0 {
data, _ := resp["data"].(map[string]any)
bizData, _ := data["biz_data"].(map[string]any)
chatSessions, _ := bizData["chat_sessions"].([]any)
hasMore, _ := bizData["has_more"].(bool)
stats.FirstPageCount = len(chatSessions)
stats.HasMore = hasMore
stats.Success = true
// 统计置顶会话数量
for _, session := range chatSessions {
if s, ok := session.(map[string]any); ok {
if pinned, ok := s["pinned"].(bool); ok && pinned {
stats.PinnedCount++
}
}
}
return stats, nil
}
stats.ErrorMessage = fmt.Sprintf("status=%d, code=%d, msg=%s", status, code, msg)
config.Logger.Warn("[get_session_count] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "account", a.AccountID)
if a.UseConfigToken {
if isTokenInvalid(status, code, bizCode, msg, bizMsg) && !refreshed {
if c.Auth.RefreshToken(ctx, a) {
refreshed = true
continue
}
}
if c.Auth.SwitchAccount(ctx, a) {
refreshed = false
attempts++
continue
}
}
attempts++
}
stats.Success = false
stats.ErrorMessage = "get session count failed after retries"
return stats, errors.New(stats.ErrorMessage)
}
// GetSessionCountForToken 直接使用 token 获取会话数量(直通模式)
func (c *Client) GetSessionCountForToken(ctx context.Context, token string) (*SessionStats, error) {
headers := c.authHeaders(token)
reqURL := DeepSeekFetchSessionURL + "?lte_cursor.pinned=false"
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
if err != nil {
return nil, err
}
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
if status != http.StatusOK || code != 0 || bizCode != 0 {
if strings.TrimSpace(bizMsg) != "" {
msg = bizMsg
}
return nil, fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
}
data, _ := resp["data"].(map[string]any)
bizData, _ := data["biz_data"].(map[string]any)
chatSessions, _ := bizData["chat_sessions"].([]any)
hasMore, _ := bizData["has_more"].(bool)
stats := &SessionStats{
FirstPageCount: len(chatSessions),
HasMore: hasMore,
Success: true,
}
// 统计置顶会话数量
for _, session := range chatSessions {
if s, ok := session.(map[string]any); ok {
if pinned, ok := s["pinned"].(bool); ok && pinned {
stats.PinnedCount++
}
}
}
return stats, nil
}
// GetSessionCountAll 获取所有账号的会话数量统计
func (c *Client) GetSessionCountAll(ctx context.Context) []*SessionStats {
accounts := c.Store.Accounts()
results := make([]*SessionStats, 0, len(accounts))
for _, acc := range accounts {
token := acc.Token
accountID := acc.Email
if accountID == "" {
accountID = acc.Mobile
}
// 如果没有 token尝试登录获取
if token == "" {
var err error
token, err = c.Login(ctx, acc)
if err != nil {
results = append(results, &SessionStats{
AccountID: accountID,
Success: false,
ErrorMessage: fmt.Sprintf("login failed: %v", err),
})
continue
}
}
stats, err := c.GetSessionCountForToken(ctx, token)
if err != nil {
results = append(results, &SessionStats{
AccountID: accountID,
Success: false,
ErrorMessage: err.Error(),
})
continue
}
stats.AccountID = accountID
results = append(results, stats)
}
return results
}
// FetchSessionPage 获取会话列表(支持分页)
func (c *Client) FetchSessionPage(ctx context.Context, a *auth.RequestAuth, cursor string) ([]SessionInfo, bool, error) {
headers := c.authHeaders(a.DeepSeekToken)
// 构建请求 URL
params := url.Values{}
params.Set("lte_cursor.pinned", "false")
if cursor != "" {
params.Set("lte_cursor", cursor)
}
reqURL := DeepSeekFetchSessionURL + "?" + params.Encode()
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
if err != nil {
return nil, false, err
}
code := intFrom(resp["code"])
if status != http.StatusOK || code != 0 {
msg, _ := resp["msg"].(string)
return nil, false, fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
}
data, _ := resp["data"].(map[string]any)
bizData, _ := data["biz_data"].(map[string]any)
chatSessions, _ := bizData["chat_sessions"].([]any)
hasMore, _ := bizData["has_more"].(bool)
sessions := make([]SessionInfo, 0, len(chatSessions))
for _, s := range chatSessions {
if m, ok := s.(map[string]any); ok {
session := SessionInfo{
ID: stringFromMap(m, "id"),
Title: stringFromMap(m, "title"),
TitleType: stringFromMap(m, "title_type"),
Pinned: boolFromMap(m, "pinned"),
UpdatedAt: floatFromMap(m, "updated_at"),
}
sessions = append(sessions, session)
}
}
return sessions, hasMore, nil
}
// 辅助函数
func stringFromMap(m map[string]any, key string) string {
if v, ok := m[key].(string); ok {
return v
}
return ""
}
func boolFromMap(m map[string]any, key string) bool {
if v, ok := m[key].(bool); ok {
return v
}
return false
}
func floatFromMap(m map[string]any, key string) float64 {
if v, ok := m[key].(float64); ok {
return v
}
return 0
}

View File

@@ -0,0 +1,155 @@
package deepseek
import (
"context"
"errors"
"fmt"
"net/http"
"ds2api/internal/auth"
"ds2api/internal/config"
)
// DeleteSessionResult 删除会话结果
type DeleteSessionResult struct {
SessionID string // 会话 ID
Success bool // 是否成功
ErrorMessage string // 错误信息
}
// DeleteSession 删除单个会话
func (c *Client) DeleteSession(ctx context.Context, a *auth.RequestAuth, sessionID string, maxAttempts int) (*DeleteSessionResult, error) {
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
result := &DeleteSessionResult{
SessionID: sessionID,
}
if sessionID == "" {
result.ErrorMessage = "session_id is required"
return result, errors.New(result.ErrorMessage)
}
attempts := 0
refreshed := false
for attempts < maxAttempts {
headers := c.authHeaders(a.DeepSeekToken)
payload := map[string]any{
"chat_session_id": sessionID,
}
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteSessionURL, headers, payload)
if err != nil {
config.Logger.Warn("[delete_session] request error", "error", err, "session_id", sessionID)
attempts++
continue
}
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
if status == http.StatusOK && code == 0 && bizCode == 0 {
result.Success = true
return result, nil
}
result.ErrorMessage = fmt.Sprintf("status=%d, code=%d, msg=%s", status, code, msg)
config.Logger.Warn("[delete_session] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "session_id", sessionID)
if a.UseConfigToken {
if isTokenInvalid(status, code, bizCode, msg, bizMsg) && !refreshed {
if c.Auth.RefreshToken(ctx, a) {
refreshed = true
continue
}
}
if c.Auth.SwitchAccount(ctx, a) {
refreshed = false
attempts++
continue
}
}
attempts++
}
result.Success = false
result.ErrorMessage = "delete session failed after retries"
return result, errors.New(result.ErrorMessage)
}
// DeleteSessionForToken 直接使用 token 删除会话(直通模式)
func (c *Client) DeleteSessionForToken(ctx context.Context, token string, sessionID string) (*DeleteSessionResult, error) {
result := &DeleteSessionResult{
SessionID: sessionID,
}
if sessionID == "" {
result.ErrorMessage = "session_id is required"
return result, errors.New(result.ErrorMessage)
}
headers := c.authHeaders(token)
payload := map[string]any{
"chat_session_id": sessionID,
}
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteSessionURL, headers, payload)
if err != nil {
result.ErrorMessage = err.Error()
return result, err
}
code := intFrom(resp["code"])
if status != http.StatusOK || code != 0 {
msg, _ := resp["msg"].(string)
result.ErrorMessage = fmt.Sprintf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
return result, errors.New(result.ErrorMessage)
}
result.Success = true
return result, nil
}
// DeleteAllSessions 删除所有会话(谨慎使用)
func (c *Client) DeleteAllSessions(ctx context.Context, a *auth.RequestAuth) error {
headers := c.authHeaders(a.DeepSeekToken)
payload := map[string]any{}
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
if err != nil {
config.Logger.Warn("[delete_all_sessions] request error", "error", err)
return err
}
code := intFrom(resp["code"])
if status != http.StatusOK || code != 0 {
msg, _ := resp["msg"].(string)
config.Logger.Warn("[delete_all_sessions] failed", "status", status, "code", code, "msg", msg)
return fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
}
return nil
}
// DeleteAllSessionsForToken 直接使用 token 删除所有会话(直通模式)
func (c *Client) DeleteAllSessionsForToken(ctx context.Context, token string) error {
headers := c.authHeaders(token)
payload := map[string]any{}
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
if err != nil {
config.Logger.Warn("[delete_all_sessions_for_token] request error", "error", err)
return err
}
code := intFrom(resp["code"])
if status != http.StatusOK || code != 0 {
msg, _ := resp["msg"].(string)
config.Logger.Warn("[delete_all_sessions_for_token] failed", "status", status, "code", code, "msg", msg)
return fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
}
return nil
}

View File

@@ -11,6 +11,9 @@ const (
DeepSeekCreateSessionURL = "https://chat.deepseek.com/api/v0/chat_session/create" DeepSeekCreateSessionURL = "https://chat.deepseek.com/api/v0/chat_session/create"
DeepSeekCreatePowURL = "https://chat.deepseek.com/api/v0/chat/create_pow_challenge" DeepSeekCreatePowURL = "https://chat.deepseek.com/api/v0/chat/create_pow_challenge"
DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion" DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion"
DeepSeekFetchSessionURL = "https://chat.deepseek.com/api/v0/chat_session/fetch_page"
DeepSeekDeleteSessionURL = "https://chat.deepseek.com/api/v0/chat_session/delete"
DeepSeekDeleteAllSessionsURL = "https://chat.deepseek.com/api/v0/chat_session/delete_all"
) )
var defaultBaseHeaders = map[string]string{ var defaultBaseHeaders = map[string]string{

View File

@@ -9,6 +9,9 @@ import (
func BuildMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any { func BuildMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any {
detected := util.ParseToolCalls(finalText, toolNames) detected := util.ParseToolCalls(finalText, toolNames)
if len(detected) == 0 && finalText == "" && finalThinking != "" {
detected = util.ParseToolCalls(finalThinking, toolNames)
}
content := make([]map[string]any, 0, 4) content := make([]map[string]any, 0, 4)
if finalThinking != "" { if finalThinking != "" {
content = append(content, map[string]any{"type": "thinking", "thinking": finalThinking}) content = append(content, map[string]any{"type": "thinking", "thinking": finalThinking})

View File

@@ -0,0 +1,62 @@
package claude
import "testing"
func TestBuildMessageResponseDetectsToolCallsFromThinkingFallback(t *testing.T) {
resp := BuildMessageResponse(
"msg_1",
"claude-sonnet-4-5",
[]any{map[string]any{"role": "user", "content": "hi"}},
`{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`,
"",
[]string{"search"},
)
if resp["stop_reason"] != "tool_use" {
t.Fatalf("expected stop_reason=tool_use, got=%#v", resp["stop_reason"])
}
content, _ := resp["content"].([]map[string]any)
if len(content) < 2 {
t.Fatalf("expected thinking + tool_use content blocks, got=%#v", resp["content"])
}
last := content[len(content)-1]
if last["type"] != "tool_use" {
t.Fatalf("expected last content block tool_use, got=%#v", last["type"])
}
if last["name"] != "search" {
t.Fatalf("expected tool name search, got=%#v", last["name"])
}
}
func TestBuildMessageResponseSkipsThinkingFallbackWhenFinalTextExists(t *testing.T) {
resp := BuildMessageResponse(
"msg_1",
"claude-sonnet-4-5",
[]any{map[string]any{"role": "user", "content": "hi"}},
`{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`,
"normal answer",
[]string{"search"},
)
if resp["stop_reason"] != "end_turn" {
t.Fatalf("expected stop_reason=end_turn, got=%#v", resp["stop_reason"])
}
content, _ := resp["content"].([]map[string]any)
foundText := false
foundTool := false
for _, block := range content {
if block["type"] == "text" && block["text"] == "normal answer" {
foundText = true
}
if block["type"] == "tool_use" {
foundTool = true
}
}
if !foundText {
t.Fatalf("expected text block with finalText, got=%#v", resp["content"])
}
if foundTool {
t.Fatalf("unexpected tool_use block when finalText exists, got=%#v", resp["content"])
}
}

View File

@@ -8,15 +8,15 @@ import (
) )
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
detected := util.ParseToolCalls(finalText, toolNames) detected := util.ParseStandaloneToolCallsDetailed(finalText, toolNames)
finishReason := "stop" finishReason := "stop"
messageObj := map[string]any{"role": "assistant", "content": finalText} messageObj := map[string]any{"role": "assistant", "content": finalText}
if strings.TrimSpace(finalThinking) != "" { if strings.TrimSpace(finalThinking) != "" {
messageObj["reasoning_content"] = finalThinking messageObj["reasoning_content"] = finalThinking
} }
if len(detected) > 0 { if len(detected.Calls) > 0 {
finishReason = "tool_calls" finishReason = "tool_calls"
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected) messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected.Calls)
messageObj["content"] = nil messageObj["content"] = nil
} }

View File

@@ -11,17 +11,14 @@ import (
) )
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
// Align responses tool-call semantics with chat/completions: // Strict mode: only standalone, structured tool-call payloads are treated
// mixed prose + tool_call payloads should still be interpreted as tool calls. // as executable tool calls.
detected := util.ParseToolCalls(finalText, toolNames) detected := util.ParseStandaloneToolCallsDetailed(finalText, toolNames)
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
detected = util.ParseToolCalls(finalThinking, toolNames)
}
exposedOutputText := finalText exposedOutputText := finalText
output := make([]any, 0, 2) output := make([]any, 0, 2)
if len(detected) > 0 { if len(detected.Calls) > 0 {
exposedOutputText = "" exposedOutputText = ""
output = append(output, toResponsesFunctionCallItems(detected)...) output = append(output, toResponsesFunctionCallItems(detected.Calls)...)
} else { } else {
content := make([]any, 0, 2) content := make([]any, 0, 2)
if finalThinking != "" { if finalThinking != "" {

View File

@@ -71,6 +71,19 @@ func BuildResponsesTextDeltaPayload(responseID, itemID string, outputIndex, cont
} }
} }
func BuildResponsesTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
return map[string]any{
"type": "response.output_text.done",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"text": text,
}
}
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
return map[string]any{ return map[string]any{
"type": "response.reasoning.delta", "type": "response.reasoning.delta",

View File

@@ -45,7 +45,7 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
} }
} }
func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) { func TestBuildResponseObjectPromotesMixedProseToolPayloadToFunctionCall(t *testing.T) {
obj := BuildResponseObject( obj := BuildResponseObject(
"resp_test", "resp_test",
"gpt-4o", "gpt-4o",
@@ -57,20 +57,19 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T)
outputText, _ := obj["output_text"].(string) outputText, _ := obj["output_text"].(string)
if outputText != "" { if outputText != "" {
t.Fatalf("expected output_text hidden once tool calls are detected, got %q", outputText) t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
} }
output, _ := obj["output"].([]any) output, _ := obj["output"].([]any)
if len(output) != 1 { if len(output) != 1 {
t.Fatalf("expected function_call output only, got %#v", obj["output"]) t.Fatalf("expected one function_call output item, got %#v", obj["output"])
} }
first, _ := output[0].(map[string]any) first, _ := output[0].(map[string]any)
if first["type"] != "function_call" { if first["type"] != "function_call" {
t.Fatalf("expected first output type function_call, got %#v", first["type"]) t.Fatalf("expected function_call output type, got %#v", first["type"])
} }
} }
func TestBuildResponseObjectFencedToolPayloadRemainsText(t *testing.T) { func TestBuildResponseObjectPromotesFencedToolPayloadToFunctionCall(t *testing.T) {
obj := BuildResponseObject( obj := BuildResponseObject(
"resp_test", "resp_test",
"gpt-4o", "gpt-4o",
@@ -81,16 +80,16 @@ func TestBuildResponseObjectFencedToolPayloadRemainsText(t *testing.T) {
) )
outputText, _ := obj["output_text"].(string) outputText, _ := obj["output_text"].(string)
if outputText == "" { if outputText != "" {
t.Fatalf("expected output_text preserved for fenced example") t.Fatalf("expected output_text hidden for fenced tool payload, got %q", outputText)
} }
output, _ := obj["output"].([]any) output, _ := obj["output"].([]any)
if len(output) != 1 { if len(output) != 1 {
t.Fatalf("expected one message output item, got %#v", obj["output"]) t.Fatalf("expected one function_call output item, got %#v", obj["output"])
} }
first, _ := output[0].(map[string]any) first, _ := output[0].(map[string]any)
if first["type"] != "message" { if first["type"] != "function_call" {
t.Fatalf("expected message output type, got %#v", first["type"]) t.Fatalf("expected function_call output type, got %#v", first["type"])
} }
} }
@@ -127,7 +126,7 @@ func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
} }
} }
func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) { func TestBuildResponseObjectIgnoresToolCallFromThinkingChannel(t *testing.T) {
obj := BuildResponseObject( obj := BuildResponseObject(
"resp_test", "resp_test",
"gpt-4o", "gpt-4o",
@@ -139,10 +138,10 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
output, _ := obj["output"].([]any) output, _ := obj["output"].([]any)
if len(output) != 1 { if len(output) != 1 {
t.Fatalf("expected function_call output only, got %#v", obj["output"]) t.Fatalf("expected one message output item, got %#v", obj["output"])
} }
first, _ := output[0].(map[string]any) first, _ := output[0].(map[string]any)
if first["type"] != "function_call" { if first["type"] != "message" {
t.Fatalf("expected output function_call, got %#v", first["type"]) t.Fatalf("expected output message, got %#v", first["type"])
} }
} }

View File

@@ -10,8 +10,10 @@ const {
} = require('./sse_parse'); } = require('./sse_parse');
const { const {
resolveToolcallPolicy, resolveToolcallPolicy,
formatIncrementalToolCallDeltas,
normalizePreparedToolNames, normalizePreparedToolNames,
boolDefaultTrue, boolDefaultTrue,
filterIncrementalToolCallDeltasByAllowed,
} = require('./toolcall_policy'); } = require('./toolcall_policy');
const { const {
estimateTokens, estimateTokens,
@@ -82,7 +84,9 @@ module.exports.__test = {
shouldSkipPath, shouldSkipPath,
asString, asString,
resolveToolcallPolicy, resolveToolcallPolicy,
formatIncrementalToolCallDeltas,
normalizePreparedToolNames, normalizePreparedToolNames,
boolDefaultTrue, boolDefaultTrue,
filterIncrementalToolCallDeltasByAllowed,
estimateTokens, estimateTokens,
}; };

View File

@@ -68,6 +68,47 @@ function formatIncrementalToolCallDeltas(deltas, idStore) {
return out; return out;
} }
function filterIncrementalToolCallDeltasByAllowed(deltas, allowedNames, seenNames) {
if (!Array.isArray(deltas) || deltas.length === 0) {
return [];
}
const seen = seenNames instanceof Map ? seenNames : new Map();
const allowed = new Set((allowedNames || []).filter((name) => asString(name) !== ''));
if (allowed.size === 0) {
for (const d of deltas) {
if (d && typeof d === 'object' && asString(d.name)) {
const index = Number.isInteger(d.index) ? d.index : 0;
seen.set(index, '__blocked__');
}
}
return [];
}
const out = [];
for (const d of deltas) {
if (!d || typeof d !== 'object') {
continue;
}
const index = Number.isInteger(d.index) ? d.index : 0;
const name = asString(d.name);
if (name) {
if (!allowed.has(name)) {
seen.set(index, '__blocked__');
continue;
}
seen.set(index, name);
out.push(d);
continue;
}
const existing = asString(seen.get(index));
if (!existing || existing === '__blocked__') {
continue;
}
out.push(d);
}
return out;
}
function ensureStreamToolCallID(idStore, index) { function ensureStreamToolCallID(idStore, index) {
const key = Number.isInteger(index) ? index : 0; const key = Number.isInteger(index) ? index : 0;
const existing = idStore.get(key); const existing = idStore.get(key);
@@ -104,4 +145,5 @@ module.exports = {
normalizePreparedToolNames, normalizePreparedToolNames,
boolDefaultTrue, boolDefaultTrue,
formatIncrementalToolCallDeltas, formatIncrementalToolCallDeltas,
filterIncrementalToolCallDeltasByAllowed,
}; };

View File

@@ -1,34 +1,22 @@
'use strict'; 'use strict';
const { const {
extractToolNames,
createToolSieveState, createToolSieveState,
processToolSieveChunk, processToolSieveChunk,
flushToolSieve, flushToolSieve,
parseToolCalls, parseStandaloneToolCalls,
formatOpenAIStreamToolCalls, formatOpenAIStreamToolCalls,
} = require('../helpers/stream-tool-sieve'); } = require('../helpers/stream-tool-sieve');
const { const { BASE_HEADERS } = require('../shared/deepseek-constants');
BASE_HEADERS, const { writeOpenAIError } = require('./error_shape');
} = require('../shared/deepseek-constants'); const { parseChunkForContent, isCitation } = require('./sse_parse');
const { buildUsage } = require('./token_usage');
const {
writeOpenAIError,
} = require('./error_shape');
const {
parseChunkForContent,
isCitation,
} = require('./sse_parse');
const {
buildUsage,
} = require('./token_usage');
const { const {
resolveToolcallPolicy, resolveToolcallPolicy,
formatIncrementalToolCallDeltas, formatIncrementalToolCallDeltas,
filterIncrementalToolCallDeltasByAllowed,
} = require('./toolcall_policy'); } = require('./toolcall_policy');
const { const { createChatCompletionEmitter } = require('./stream_emitter');
createChatCompletionEmitter,
} = require('./stream_emitter');
const { const {
asString, asString,
isAbortError, isAbortError,
@@ -58,6 +46,7 @@ async function handleVercelStream(req, res, rawBody, payload) {
const searchEnabled = toBool(prep.body.search_enabled); const searchEnabled = toBool(prep.body.search_enabled);
const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools); const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools);
const toolNames = toolPolicy.toolNames; const toolNames = toolPolicy.toolNames;
const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) { if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) {
writeOpenAIError(res, 500, 'invalid vercel prepare response'); writeOpenAIError(res, 500, 'invalid vercel prepare response');
@@ -130,10 +119,10 @@ async function handleVercelStream(req, res, rawBody, payload) {
let thinkingText = ''; let thinkingText = '';
let outputText = ''; let outputText = '';
const toolSieveEnabled = toolPolicy.toolSieveEnabled; const toolSieveEnabled = toolPolicy.toolSieveEnabled;
const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
const toolSieveState = createToolSieveState(); const toolSieveState = createToolSieveState();
let toolCallsEmitted = false; let toolCallsEmitted = false;
const streamToolCallIDs = new Map(); const streamToolCallIDs = new Map();
const streamToolNames = new Map();
const decoder = new TextDecoder(); const decoder = new TextDecoder();
reader = completionRes.body.getReader(); reader = completionRes.body.getReader();
let buffered = ''; let buffered = '';
@@ -155,13 +144,18 @@ async function handleVercelStream(req, res, rawBody, payload) {
await releaseLease(); await releaseLease();
return; return;
} }
const detected = parseToolCalls(outputText, toolNames); const detected = parseStandaloneToolCalls(outputText, toolNames);
if (detected.length > 0 && !toolCallsEmitted) { if (detected.length > 0 && !toolCallsEmitted) {
toolCallsEmitted = true; toolCallsEmitted = true;
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) }); sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected, streamToolCallIDs) });
} else if (toolSieveEnabled) { } else if (toolSieveEnabled) {
const tailEvents = flushToolSieve(toolSieveState, toolNames); const tailEvents = flushToolSieve(toolSieveState, toolNames);
for (const evt of tailEvents) { for (const evt of tailEvents) {
if (evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0) {
toolCallsEmitted = true;
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
continue;
}
if (evt.text) { if (evt.text) {
sendDeltaFrame({ content: evt.text }); sendDeltaFrame({ content: evt.text });
} }
@@ -252,17 +246,21 @@ async function handleVercelStream(req, res, rawBody, payload) {
} }
const events = processToolSieveChunk(toolSieveState, p.text, toolNames); const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
for (const evt of events) { for (const evt of events) {
if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) { if (evt.type === 'tool_call_deltas') {
if (!emitEarlyToolDeltas) { if (!emitEarlyToolDeltas) {
continue; continue;
} }
toolCallsEmitted = true; const filtered = filterIncrementalToolCallDeltasByAllowed(evt.deltas, toolNames, streamToolNames);
sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) }); const formatted = formatIncrementalToolCallDeltas(filtered, streamToolCallIDs);
if (formatted.length > 0) {
toolCallsEmitted = true;
sendDeltaFrame({ tool_calls: formatted });
}
continue; continue;
} }
if (evt.type === 'tool_calls') { if (evt.type === 'tool_calls') {
toolCallsEmitted = true; toolCallsEmitted = true;
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) }); sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
continue; continue;
} }
if (evt.text) { if (evt.text) {

View File

@@ -2,13 +2,13 @@
const crypto = require('crypto'); const crypto = require('crypto');
function formatOpenAIStreamToolCalls(calls) { function formatOpenAIStreamToolCalls(calls, idStore) {
if (!Array.isArray(calls) || calls.length === 0) { if (!Array.isArray(calls) || calls.length === 0) {
return []; return [];
} }
return calls.map((c, idx) => ({ return calls.map((c, idx) => ({
index: idx, index: idx,
id: `call_${newCallID()}`, id: ensureStreamToolCallID(idStore, idx),
type: 'function', type: 'function',
function: { function: {
name: c.name, name: c.name,
@@ -17,6 +17,20 @@ function formatOpenAIStreamToolCalls(calls) {
})); }));
} }
function ensureStreamToolCallID(idStore, index) {
if (!(idStore instanceof Map)) {
return `call_${newCallID()}`;
}
const key = Number.isInteger(index) ? index : 0;
const existing = idStore.get(key);
if (existing) {
return existing;
}
const next = `call_${newCallID()}`;
idStore.set(key, next);
return next;
}
function newCallID() { function newCallID() {
if (typeof crypto.randomUUID === 'function') { if (typeof crypto.randomUUID === 'function') {
return crypto.randomUUID().replace(/-/g, ''); return crypto.randomUUID().replace(/-/g, '');

View File

@@ -1,226 +0,0 @@
'use strict';
const {
looksLikeToolExampleContext,
insideCodeFence,
} = require('./state');
const {
findObjectFieldValueStart,
parseJSONStringLiteral,
skipSpaces,
} = require('./jsonscan');
function buildIncrementalToolDeltas(state) {
const captured = state.capture || '';
if (!captured) {
return [];
}
if (looksLikeToolExampleContext(state.recentTextTail)) {
return [];
}
const lower = captured.toLowerCase();
const keyIdx = lower.indexOf('tool_calls');
if (keyIdx < 0) {
return [];
}
const start = captured.slice(0, keyIdx).lastIndexOf('{');
if (start < 0) {
return [];
}
if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) {
return [];
}
const callStart = findFirstToolCallObjectStart(captured, keyIdx);
if (callStart < 0) {
return [];
}
const deltas = [];
if (!state.toolName) {
const name = extractToolCallName(captured, callStart);
if (!name) {
return [];
}
state.toolName = name;
}
if (state.toolArgsStart < 0) {
const args = findToolCallArgsStart(captured, callStart);
if (args) {
state.toolArgsString = Boolean(args.stringMode);
state.toolArgsStart = state.toolArgsString ? args.start + 1 : args.start;
state.toolArgsSent = state.toolArgsStart;
}
}
if (!state.toolNameSent) {
if (state.toolArgsStart < 0) {
return [];
}
state.toolNameSent = true;
deltas.push({ index: 0, name: state.toolName });
}
if (state.toolArgsStart < 0 || state.toolArgsDone) {
return deltas;
}
const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString);
if (!progress) {
return deltas;
}
if (progress.end > state.toolArgsSent) {
deltas.push({
index: 0,
arguments: captured.slice(state.toolArgsSent, progress.end),
});
state.toolArgsSent = progress.end;
}
if (progress.complete) {
state.toolArgsDone = true;
}
return deltas;
}
function findFirstToolCallObjectStart(text, keyIdx) {
const arrStart = findToolCallsArrayStart(text, keyIdx);
if (arrStart < 0) {
return -1;
}
const i = skipSpaces(text, arrStart + 1);
if (i >= text.length || text[i] !== '{') {
return -1;
}
return i;
}
function findToolCallsArrayStart(text, keyIdx) {
let i = keyIdx + 'tool_calls'.length;
while (i < text.length && text[i] !== ':') {
i += 1;
}
if (i >= text.length) {
return -1;
}
i = skipSpaces(text, i + 1);
if (i >= text.length || text[i] !== '[') {
return -1;
}
return i;
}
function extractToolCallName(text, callStart) {
let valueStart = findObjectFieldValueStart(text, callStart, ['name']);
if (valueStart < 0 || text[valueStart] !== '"') {
const fnStart = findFunctionObjectStart(text, callStart);
if (fnStart < 0) {
return '';
}
valueStart = findObjectFieldValueStart(text, fnStart, ['name']);
if (valueStart < 0 || text[valueStart] !== '"') {
return '';
}
}
const parsed = parseJSONStringLiteral(text, valueStart);
if (!parsed) {
return '';
}
return parsed.value;
}
function findToolCallArgsStart(text, callStart) {
const keys = ['input', 'arguments', 'args', 'parameters', 'params'];
let valueStart = findObjectFieldValueStart(text, callStart, keys);
if (valueStart < 0) {
const fnStart = findFunctionObjectStart(text, callStart);
if (fnStart < 0) {
return null;
}
valueStart = findObjectFieldValueStart(text, fnStart, keys);
if (valueStart < 0) {
return null;
}
}
if (valueStart >= text.length) {
return null;
}
const ch = text[valueStart];
if (ch === '{' || ch === '[') {
return { start: valueStart, stringMode: false };
}
if (ch === '"') {
return { start: valueStart, stringMode: true };
}
return null;
}
function scanToolCallArgsProgress(text, start, stringMode) {
if (start < 0 || start > text.length) {
return null;
}
if (stringMode) {
let escaped = false;
for (let i = start; i < text.length; i += 1) {
const ch = text[i];
if (escaped) {
escaped = false;
continue;
}
if (ch === '\\') {
escaped = true;
continue;
}
if (ch === '"') {
return { end: i, complete: true };
}
}
return { end: text.length, complete: false };
}
if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) {
return null;
}
let depth = 0;
let quote = '';
let escaped = false;
for (let i = start; i < text.length; i += 1) {
const ch = text[i];
if (quote) {
if (escaped) {
escaped = false;
continue;
}
if (ch === '\\') {
escaped = true;
continue;
}
if (ch === quote) {
quote = '';
}
continue;
}
if (ch === '"' || ch === "'") {
quote = ch;
continue;
}
if (ch === '{' || ch === '[') {
depth += 1;
continue;
}
if (ch === '}' || ch === ']') {
depth -= 1;
if (depth === 0) {
return { end: i + 1, complete: true };
}
}
}
return { end: text.length, complete: false };
}
function findFunctionObjectStart(text, callStart) {
const valueStart = findObjectFieldValueStart(text, callStart, ['function']);
if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') {
return -1;
}
return valueStart;
}
module.exports = {
buildIncrementalToolDeltas,
};

View File

@@ -10,7 +10,9 @@ const {
const { const {
extractToolNames, extractToolNames,
parseToolCalls, parseToolCalls,
parseToolCallsDetailed,
parseStandaloneToolCalls, parseStandaloneToolCalls,
parseStandaloneToolCallsDetailed,
} = require('./parse'); } = require('./parse');
const { const {
formatOpenAIStreamToolCalls, formatOpenAIStreamToolCalls,
@@ -22,6 +24,8 @@ module.exports = {
processToolSieveChunk, processToolSieveChunk,
flushToolSieve, flushToolSieve,
parseToolCalls, parseToolCalls,
parseToolCallsDetailed,
parseStandaloneToolCalls, parseStandaloneToolCalls,
parseStandaloneToolCallsDetailed,
formatOpenAIStreamToolCalls, formatOpenAIStreamToolCalls,
}; };

View File

@@ -1,14 +1,16 @@
'use strict'; 'use strict';
const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
const { const {
toStringSafe, toStringSafe,
looksLikeToolExampleContext,
} = require('./state'); } = require('./state');
const { const {
extractJSONObjectFrom, buildToolCallCandidates,
} = require('./jsonscan'); parseToolCallsPayload,
parseMarkupToolCalls,
parseTextKVToolCalls,
} = require('./parse_payload');
const TOOL_NAME_LOOSE_PATTERN = /[^a-z0-9]+/g;
function extractToolNames(tools) { function extractToolNames(tools) {
if (!Array.isArray(tools) || tools.length === 0) { if (!Array.isArray(tools) || tools.length === 0) {
@@ -29,245 +31,206 @@ function extractToolNames(tools) {
} }
function parseToolCalls(text, toolNames) { function parseToolCalls(text, toolNames) {
if (!toStringSafe(text)) { return parseToolCallsDetailed(text, toolNames).calls;
return []; }
function parseToolCallsDetailed(text, toolNames) {
const result = emptyParseResult();
const normalized = toStringSafe(text);
if (!normalized) {
return result;
} }
const sanitized = stripFencedCodeBlocks(text); result.sawToolCallSyntax = looksLikeToolCallSyntax(normalized);
if (!toStringSafe(sanitized)) {
return []; const candidates = buildToolCallCandidates(normalized);
}
const candidates = buildToolCallCandidates(sanitized);
let parsed = []; let parsed = [];
for (const c of candidates) { for (const c of candidates) {
parsed = parseToolCallsPayload(c); parsed = parseToolCallsPayload(c);
if (parsed.length === 0) {
parsed = parseMarkupToolCalls(c);
}
if (parsed.length === 0) {
parsed = parseTextKVToolCalls(c);
}
if (parsed.length > 0) {
result.sawToolCallSyntax = true;
break;
}
}
if (parsed.length === 0) {
parsed = parseMarkupToolCalls(normalized);
if (parsed.length === 0) {
parsed = parseTextKVToolCalls(normalized);
if (parsed.length === 0) {
return result;
}
}
result.sawToolCallSyntax = true;
}
const filtered = filterToolCallsDetailed(parsed, toolNames);
result.calls = filtered.calls;
result.rejectedToolNames = filtered.rejectedToolNames;
result.rejectedByPolicy = filtered.rejectedToolNames.length > 0 && filtered.calls.length === 0;
return result;
}
function parseStandaloneToolCalls(text, toolNames) {
return parseStandaloneToolCallsDetailed(text, toolNames).calls;
}
function parseStandaloneToolCallsDetailed(text, toolNames) {
const result = emptyParseResult();
const trimmed = toStringSafe(text);
if (!trimmed) {
return result;
}
result.sawToolCallSyntax = looksLikeToolCallSyntax(trimmed);
const candidates = buildToolCallCandidates(trimmed);
let parsed = [];
for (const c of candidates) {
parsed = parseToolCallsPayload(c);
if (parsed.length === 0) {
parsed = parseMarkupToolCalls(c);
}
if (parsed.length === 0) {
parsed = parseTextKVToolCalls(c);
}
if (parsed.length > 0) { if (parsed.length > 0) {
break; break;
} }
} }
if (parsed.length === 0) { if (parsed.length === 0) {
return []; parsed = parseMarkupToolCalls(trimmed);
} if (parsed.length === 0) {
return filterToolCalls(parsed, toolNames); parsed = parseTextKVToolCalls(trimmed);
} if (parsed.length === 0) {
return result;
function stripFencedCodeBlocks(text) {
const t = typeof text === 'string' ? text : '';
if (!t) {
return '';
}
return t.replace(/```[\s\S]*?```/g, ' ');
}
function parseStandaloneToolCalls(text, toolNames) {
const trimmed = toStringSafe(text);
if (!trimmed) {
return [];
}
if ((trimmed.startsWith('```') && trimmed.endsWith('```')) || trimmed.includes('```')) {
return [];
}
if (looksLikeToolExampleContext(trimmed)) {
return [];
}
const candidates = [trimmed];
if (trimmed.startsWith('```') && trimmed.endsWith('```')) {
const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
if (m && m[1]) {
candidates.push(toStringSafe(m[1]));
}
}
for (const candidate of candidates) {
const c = toStringSafe(candidate);
if (!c) {
continue;
}
if (!c.startsWith('{') && !c.startsWith('[')) {
continue;
}
const parsed = parseToolCallsPayload(c);
if (parsed.length > 0) {
return filterToolCalls(parsed, toolNames);
}
}
return [];
}
function buildToolCallCandidates(text) {
const trimmed = toStringSafe(text);
const candidates = [trimmed];
const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
for (const block of fenced) {
const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
if (m && m[1]) {
candidates.push(toStringSafe(m[1]));
}
}
for (const candidate of extractToolCallObjects(trimmed)) {
candidates.push(toStringSafe(candidate));
}
const first = trimmed.indexOf('{');
const last = trimmed.lastIndexOf('}');
if (first >= 0 && last > first) {
candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
}
const m = trimmed.match(TOOL_CALL_PATTERN);
if (m && m[1]) {
candidates.push(`{"tool_calls":[${m[1]}]}`);
}
return [...new Set(candidates.filter(Boolean))];
}
function extractToolCallObjects(text) {
const raw = toStringSafe(text);
if (!raw) {
return [];
}
const lower = raw.toLowerCase();
const out = [];
let offset = 0;
// eslint-disable-next-line no-constant-condition
while (true) {
let idx = lower.indexOf('tool_calls', offset);
if (idx < 0) {
break;
}
let start = raw.slice(0, idx).lastIndexOf('{');
while (start >= 0) {
const obj = extractJSONObjectFrom(raw, start);
if (obj.ok) {
out.push(raw.slice(start, obj.end).trim());
offset = obj.end;
idx = -1;
break;
}
start = raw.slice(0, start).lastIndexOf('{');
}
if (idx >= 0) {
offset = idx + 'tool_calls'.length;
}
}
return out;
}
function parseToolCallsPayload(payload) {
let decoded;
try {
decoded = JSON.parse(payload);
} catch (_err) {
return [];
}
if (Array.isArray(decoded)) {
return parseToolCallList(decoded);
}
if (!decoded || typeof decoded !== 'object') {
return [];
}
if (decoded.tool_calls) {
return parseToolCallList(decoded.tool_calls);
}
const one = parseToolCallItem(decoded);
return one ? [one] : [];
}
function parseToolCallList(v) {
if (!Array.isArray(v)) {
return [];
}
const out = [];
for (const item of v) {
if (!item || typeof item !== 'object') {
continue;
}
const one = parseToolCallItem(item);
if (one) {
out.push(one);
}
}
return out;
}
function parseToolCallItem(m) {
let name = toStringSafe(m.name);
let inputRaw = m.input;
let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
const fn = m.function && typeof m.function === 'object' ? m.function : null;
if (fn) {
if (!name) {
name = toStringSafe(fn.name);
}
if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
inputRaw = fn.arguments;
hasInput = true;
}
}
if (!hasInput) {
for (const k of ['arguments', 'args', 'parameters', 'params']) {
if (Object.prototype.hasOwnProperty.call(m, k)) {
inputRaw = m[k];
hasInput = true;
break;
} }
} }
} }
if (!name) {
return null; result.sawToolCallSyntax = true;
} const filtered = filterToolCallsDetailed(parsed, toolNames);
result.calls = filtered.calls;
result.rejectedToolNames = filtered.rejectedToolNames;
result.rejectedByPolicy = filtered.rejectedToolNames.length > 0 && filtered.calls.length === 0;
return result;
}
function emptyParseResult() {
return { return {
name, calls: [],
input: parseToolCallInput(inputRaw), sawToolCallSyntax: false,
rejectedByPolicy: false,
rejectedToolNames: [],
}; };
} }
function parseToolCallInput(v) { function filterToolCallsDetailed(parsed, toolNames) {
if (v == null) { const sourceNames = Array.isArray(toolNames) ? toolNames : [];
return {}; const allowed = new Set();
} const allowedCanonical = new Map();
if (typeof v === 'string') { for (const item of sourceNames) {
const raw = toStringSafe(v); const name = toStringSafe(item);
if (!raw) { if (!name) {
return {}; continue;
} }
try { allowed.add(name);
const parsed = JSON.parse(raw); const lower = name.toLowerCase();
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { if (!allowedCanonical.has(lower)) {
return parsed; allowedCanonical.set(lower, name);
}
return { _raw: raw };
} catch (_err) {
return { _raw: raw };
} }
} }
if (typeof v === 'object' && !Array.isArray(v)) {
return v;
}
try {
const parsed = JSON.parse(JSON.stringify(v));
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
return parsed;
}
} catch (_err) {
return {};
}
return {};
}
function filterToolCalls(parsed, toolNames) { if (allowed.size === 0) {
const allowed = new Set((toolNames || []).filter(Boolean)); const rejected = [];
const out = []; const seen = new Set();
for (const tc of parsed) {
if (!tc || !tc.name) {
continue;
}
if (seen.has(tc.name)) {
continue;
}
seen.add(tc.name);
rejected.push(tc.name);
}
return { calls: [], rejectedToolNames: rejected };
}
const calls = [];
const rejected = [];
const seenRejected = new Set();
for (const tc of parsed) { for (const tc of parsed) {
if (!tc || !tc.name) { if (!tc || !tc.name) {
continue; continue;
} }
if (allowed.size > 0 && !allowed.has(tc.name)) { let matchedName = '';
if (allowed.has(tc.name)) {
matchedName = tc.name;
} else {
matchedName = resolveAllowedToolName(tc.name, allowed, allowedCanonical);
}
if (!matchedName) {
if (!seenRejected.has(tc.name)) {
seenRejected.add(tc.name);
rejected.push(tc.name);
}
continue; continue;
} }
out.push({ name: tc.name, input: tc.input || {} }); calls.push({
name: matchedName,
input: tc.input && typeof tc.input === 'object' && !Array.isArray(tc.input) ? tc.input : {},
});
} }
return out; return { calls, rejectedToolNames: rejected };
}
function resolveAllowedToolName(name, allowed, allowedCanonical) {
const normalizedName = toStringSafe(name).trim();
if (!normalizedName) {
return '';
}
if (allowed.has(normalizedName)) {
return normalizedName;
}
const lower = normalizedName.toLowerCase();
if (allowedCanonical.has(lower)) {
return allowedCanonical.get(lower);
}
const idx = lower.lastIndexOf('.');
if (idx >= 0 && idx < lower.length - 1) {
const tail = lower.slice(idx + 1);
if (allowedCanonical.has(tail)) {
return allowedCanonical.get(tail);
}
}
const loose = lower.replace(TOOL_NAME_LOOSE_PATTERN, '');
if (!loose) {
return '';
}
for (const [candidateLower, canonical] of allowedCanonical.entries()) {
if (candidateLower.replace(TOOL_NAME_LOOSE_PATTERN, '') === loose) {
return canonical;
}
}
return '';
}
function looksLikeToolCallSyntax(text) {
const lower = toStringSafe(text).toLowerCase();
return lower.includes('tool_calls')
|| lower.includes('<tool_call')
|| lower.includes('<function_call')
|| lower.includes('<invoke')
|| lower.includes('function.name:');
} }
module.exports = { module.exports = {
extractToolNames, extractToolNames,
parseToolCalls, parseToolCalls,
parseToolCallsDetailed,
parseStandaloneToolCalls, parseStandaloneToolCalls,
parseStandaloneToolCallsDetailed,
}; };

View File

@@ -0,0 +1,363 @@
'use strict';
const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
const TOOL_CALL_MARKUP_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?(tool_call|function_call|invoke)\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
const TOOL_CALL_MARKUP_SELFCLOSE_PATTERN = /<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)\/>/gi;
const TOOL_CALL_MARKUP_KV_PATTERN = /<(?:[a-z0-9_:-]+:)?([a-z0-9_.-]+)\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
const TOOL_CALL_MARKUP_ATTR_PATTERN = /(name|function|tool)\s*=\s*"([^"]+)"/i;
const TOOL_CALL_MARKUP_NAME_PATTERNS = [
/<(?:[a-z0-9_:-]+:)?name\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?name>/i,
/<(?:[a-z0-9_:-]+:)?function\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?function>/i,
];
const TOOL_CALL_MARKUP_ARGS_PATTERNS = [
/<(?:[a-z0-9_:-]+:)?input\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?input>/i,
/<(?:[a-z0-9_:-]+:)?arguments\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?arguments>/i,
/<(?:[a-z0-9_:-]+:)?argument\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?argument>/i,
/<(?:[a-z0-9_:-]+:)?parameters\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameters>/i,
/<(?:[a-z0-9_:-]+:)?parameter\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameter>/i,
/<(?:[a-z0-9_:-]+:)?args\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?args>/i,
/<(?:[a-z0-9_:-]+:)?params\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?params>/i,
];
const TEXT_KV_NAME_PATTERN = /function\.name:\s*([a-zA-Z0-9_.-]+)/gi;
const {
toStringSafe,
} = require('./state');
const {
extractJSONObjectFrom,
} = require('./jsonscan');
function stripFencedCodeBlocks(text) {
const t = typeof text === 'string' ? text : '';
if (!t) {
return '';
}
return t.replace(/```[\s\S]*?```/g, ' ');
}
function buildToolCallCandidates(text) {
const trimmed = toStringSafe(text);
const candidates = [trimmed];
const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
for (const block of fenced) {
const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
if (m && m[1]) {
candidates.push(toStringSafe(m[1]));
}
}
for (const candidate of extractToolCallObjects(trimmed)) {
candidates.push(toStringSafe(candidate));
}
const first = trimmed.indexOf('{');
const last = trimmed.lastIndexOf('}');
if (first >= 0 && last > first) {
candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
}
const m = trimmed.match(TOOL_CALL_PATTERN);
if (m && m[1]) {
candidates.push(`{"tool_calls":[${m[1]}]}`);
}
return [...new Set(candidates.filter(Boolean))];
}
function extractToolCallObjects(text) {
const raw = toStringSafe(text);
if (!raw) {
return [];
}
const lower = raw.toLowerCase();
const out = [];
let offset = 0;
// eslint-disable-next-line no-constant-condition
while (true) {
let idx = lower.indexOf('tool_calls', offset);
if (idx < 0) {
break;
}
let start = raw.slice(0, idx).lastIndexOf('{');
while (start >= 0) {
const obj = extractJSONObjectFrom(raw, start);
if (obj.ok) {
out.push(raw.slice(start, obj.end).trim());
offset = obj.end;
idx = -1;
break;
}
start = raw.slice(0, start).lastIndexOf('{');
}
if (idx >= 0) {
offset = idx + 'tool_calls'.length;
}
}
return out;
}
function parseToolCallsPayload(payload) {
let decoded;
try {
decoded = JSON.parse(payload);
} catch (_err) {
return [];
}
if (Array.isArray(decoded)) {
return parseToolCallList(decoded);
}
if (!decoded || typeof decoded !== 'object') {
return [];
}
if (decoded.tool_calls) {
return parseToolCallList(decoded.tool_calls);
}
const one = parseToolCallItem(decoded);
return one ? [one] : [];
}
function parseMarkupToolCalls(text) {
const raw = toStringSafe(text).trim();
if (!raw) {
return [];
}
const out = [];
for (const m of raw.matchAll(TOOL_CALL_MARKUP_BLOCK_PATTERN)) {
const parsed = parseMarkupSingleToolCall(toStringSafe(m[2]).trim(), toStringSafe(m[3]).trim());
if (parsed) {
out.push(parsed);
}
}
for (const m of raw.matchAll(TOOL_CALL_MARKUP_SELFCLOSE_PATTERN)) {
const parsed = parseMarkupSingleToolCall(toStringSafe(m[1]).trim(), '');
if (parsed) {
out.push(parsed);
}
}
return out;
}
function parseTextKVToolCalls(text) {
const raw = toStringSafe(text);
if (!raw) {
return [];
}
const out = [];
const matches = [...raw.matchAll(TEXT_KV_NAME_PATTERN)];
if (matches.length === 0) {
return out;
}
for (let i = 0; i < matches.length; i += 1) {
const match = matches[i];
const name = toStringSafe(match[1]).trim();
if (!name) {
continue;
}
const nameEnd = match.index + toStringSafe(match[0]).length;
const searchEnd = i + 1 < matches.length ? matches[i + 1].index : raw.length;
const searchArea = raw.slice(nameEnd, searchEnd);
const argIdx = searchArea.indexOf('function.arguments:');
if (argIdx < 0) {
continue;
}
const argStart = nameEnd + argIdx + 'function.arguments:'.length;
const bracePos = raw.slice(argStart, searchEnd).indexOf('{');
if (bracePos < 0) {
continue;
}
const objStart = argStart + bracePos;
const obj = extractJSONObjectFrom(raw, objStart);
if (!obj.ok) {
continue;
}
out.push({
name,
input: parseToolCallInput(raw.slice(objStart, obj.end)),
});
}
return out;
}
function parseMarkupSingleToolCall(attrs, inner) {
const embedded = parseToolCallsPayload(inner);
if (embedded.length > 0) {
return embedded[0];
}
let name = '';
const attrMatch = attrs.match(TOOL_CALL_MARKUP_ATTR_PATTERN);
if (attrMatch && attrMatch[2]) {
name = toStringSafe(attrMatch[2]).trim();
}
if (!name) {
name = stripTagText(findMarkupTagValue(inner, TOOL_CALL_MARKUP_NAME_PATTERNS));
}
if (!name) {
return null;
}
let input = {};
const argsRaw = findMarkupTagValue(inner, TOOL_CALL_MARKUP_ARGS_PATTERNS);
if (argsRaw) {
input = parseMarkupInput(argsRaw);
} else {
const kv = parseMarkupKVObject(inner);
if (Object.keys(kv).length > 0) {
input = kv;
}
}
return { name, input };
}
function parseMarkupInput(raw) {
const s = toStringSafe(raw).trim();
if (!s) {
return {};
}
const parsed = parseToolCallInput(s);
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed) && Object.keys(parsed).length > 0) {
return parsed;
}
const kv = parseMarkupKVObject(s);
if (Object.keys(kv).length > 0) {
return kv;
}
return { _raw: stripTagText(s) };
}
function parseMarkupKVObject(text) {
const raw = toStringSafe(text).trim();
if (!raw) {
return {};
}
const out = {};
for (const m of raw.matchAll(TOOL_CALL_MARKUP_KV_PATTERN)) {
const key = toStringSafe(m[1]).trim();
if (!key) {
continue;
}
const valueRaw = stripTagText(m[2]);
if (!valueRaw) {
continue;
}
try {
out[key] = JSON.parse(valueRaw);
} catch (_err) {
out[key] = valueRaw;
}
}
return out;
}
function stripTagText(text) {
return toStringSafe(text).replace(/<[^>]+>/g, ' ').trim();
}
function findMarkupTagValue(text, patterns) {
const source = toStringSafe(text);
for (const p of patterns) {
const m = source.match(p);
if (m && m[1]) {
return toStringSafe(m[1]);
}
}
return '';
}
function parseToolCallList(v) {
if (!Array.isArray(v)) {
return [];
}
const out = [];
for (const item of v) {
if (!item || typeof item !== 'object') {
continue;
}
const one = parseToolCallItem(item);
if (one) {
out.push(one);
}
}
return out;
}
function parseToolCallItem(m) {
let name = toStringSafe(m.name);
let inputRaw = m.input;
let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
const fn = m.function && typeof m.function === 'object' ? m.function : null;
if (fn) {
if (!name) {
name = toStringSafe(fn.name);
}
if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
inputRaw = fn.arguments;
hasInput = true;
}
}
if (!hasInput) {
for (const k of ['arguments', 'args', 'parameters', 'params']) {
if (Object.prototype.hasOwnProperty.call(m, k)) {
inputRaw = m[k];
hasInput = true;
break;
}
}
}
if (!name) {
return null;
}
return {
name,
input: parseToolCallInput(inputRaw),
};
}
function parseToolCallInput(v) {
if (v == null) {
return {};
}
if (typeof v === 'string') {
const raw = toStringSafe(v);
if (!raw) {
return {};
}
try {
const parsed = JSON.parse(raw);
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
return parsed;
}
return { _raw: raw };
} catch (_err) {
return { _raw: raw };
}
}
if (typeof v === 'object' && !Array.isArray(v)) {
return v;
}
try {
const parsed = JSON.parse(JSON.stringify(v));
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
return parsed;
}
} catch (_err) {
return {};
}
return {};
}
module.exports = {
stripFencedCodeBlocks,
buildToolCallCandidates,
parseToolCallsPayload,
parseMarkupToolCalls,
parseTextKVToolCalls,
};

View File

@@ -1,16 +1,12 @@
'use strict'; 'use strict';
const { const {
TOOL_SIEVE_CAPTURE_LIMIT,
resetIncrementalToolState, resetIncrementalToolState,
noteText, noteText,
insideCodeFence, insideCodeFence,
} = require('./state'); } = require('./state');
const { const {
buildIncrementalToolDeltas, parseStandaloneToolCallsDetailed,
} = require('./incremental');
const {
parseStandaloneToolCalls,
} = require('./parse'); } = require('./parse');
const { const {
extractJSONObjectFrom, extractJSONObjectFrom,
@@ -24,64 +20,67 @@ function processToolSieveChunk(state, chunk, toolNames) {
state.pending += chunk; state.pending += chunk;
} }
const events = []; const events = [];
// eslint-disable-next-line no-constant-condition // eslint-disable-next-line no-constant-condition
while (true) { while (true) {
if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
events.push({ type: 'tool_calls', calls: state.pendingToolCalls });
state.pendingToolRaw = '';
state.pendingToolCalls = [];
continue;
}
if (state.capturing) { if (state.capturing) {
if (state.pending) { if (state.pending) {
state.capture += state.pending; state.capture += state.pending;
state.pending = ''; state.pending = '';
} }
const deltas = buildIncrementalToolDeltas(state);
if (deltas.length > 0) {
events.push({ type: 'tool_call_deltas', deltas });
}
const consumed = consumeToolCapture(state, toolNames); const consumed = consumeToolCapture(state, toolNames);
if (!consumed.ready) { if (!consumed.ready) {
if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) {
noteText(state, state.capture);
events.push({ type: 'text', text: state.capture });
state.capture = '';
state.capturing = false;
resetIncrementalToolState(state);
continue;
}
break; break;
} }
const captured = state.capture;
state.capture = ''; state.capture = '';
state.capturing = false; state.capturing = false;
resetIncrementalToolState(state); resetIncrementalToolState(state);
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
state.pendingToolRaw = captured;
state.pendingToolCalls = consumed.calls;
if (consumed.suffix) {
state.pending = consumed.suffix + state.pending;
}
continue;
}
if (consumed.prefix) { if (consumed.prefix) {
noteText(state, consumed.prefix); noteText(state, consumed.prefix);
events.push({ type: 'text', text: consumed.prefix }); events.push({ type: 'text', text: consumed.prefix });
} }
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
events.push({ type: 'tool_calls', calls: consumed.calls });
}
if (consumed.suffix) { if (consumed.suffix) {
state.pending += consumed.suffix; state.pending += consumed.suffix;
} }
continue; continue;
} }
if (!state.pending) { const pending = state.pending || '';
if (!pending) {
break; break;
} }
const start = findToolSegmentStart(state.pending); const start = findToolSegmentStart(pending);
if (start >= 0) { if (start >= 0) {
const prefix = state.pending.slice(0, start); const prefix = pending.slice(0, start);
if (prefix) { if (prefix) {
noteText(state, prefix); noteText(state, prefix);
events.push({ type: 'text', text: prefix }); events.push({ type: 'text', text: prefix });
} }
state.capture = state.pending.slice(start);
state.pending = ''; state.pending = '';
state.capture += pending.slice(start);
state.capturing = true; state.capturing = true;
resetIncrementalToolState(state); resetIncrementalToolState(state);
continue; continue;
} }
const [safe, hold] = splitSafeContentForToolDetection(state.pending); const [safe, hold] = splitSafeContentForToolDetection(pending);
if (!safe) { if (!safe) {
break; break;
} }
@@ -97,6 +96,13 @@ function flushToolSieve(state, toolNames) {
return []; return [];
} }
const events = processToolSieveChunk(state, '', toolNames); const events = processToolSieveChunk(state, '', toolNames);
if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
events.push({ type: 'tool_calls', calls: state.pendingToolCalls });
state.pendingToolRaw = '';
state.pendingToolCalls = [];
}
if (state.capturing) { if (state.capturing) {
const consumed = consumeToolCapture(state, toolNames); const consumed = consumeToolCapture(state, toolNames);
if (consumed.ready) { if (consumed.ready) {
@@ -119,11 +125,13 @@ function flushToolSieve(state, toolNames) {
state.capturing = false; state.capturing = false;
resetIncrementalToolState(state); resetIncrementalToolState(state);
} }
if (state.pending) { if (state.pending) {
noteText(state, state.pending); noteText(state, state.pending);
events.push({ type: 'text', text: state.pending }); events.push({ type: 'text', text: state.pending });
state.pending = ''; state.pending = '';
} }
return events; return events;
} }
@@ -160,43 +168,67 @@ function findToolSegmentStart(s) {
return -1; return -1;
} }
const lower = s.toLowerCase(); const lower = s.toLowerCase();
const keywords = ['tool_calls', 'function.name:', '[tool_call_history]'];
let offset = 0; let offset = 0;
// eslint-disable-next-line no-constant-condition // eslint-disable-next-line no-constant-condition
while (true) { while (true) {
const keyRel = lower.indexOf('tool_calls', offset); let bestKeyIdx = -1;
if (keyRel < 0) { let matchedKeyword = '';
for (const kw of keywords) {
const idx = lower.indexOf(kw, offset);
if (idx >= 0) {
if (bestKeyIdx < 0 || idx < bestKeyIdx) {
bestKeyIdx = idx;
matchedKeyword = kw;
}
}
}
if (bestKeyIdx < 0) {
return -1; return -1;
} }
const keyIdx = keyRel;
const keyIdx = bestKeyIdx;
const start = s.slice(0, keyIdx).lastIndexOf('{'); const start = s.slice(0, keyIdx).lastIndexOf('{');
const candidateStart = start >= 0 ? start : keyIdx; const candidateStart = start >= 0 ? start : keyIdx;
if (!insideCodeFence(s.slice(0, candidateStart))) { if (!insideCodeFence(s.slice(0, candidateStart))) {
return candidateStart; return candidateStart;
} }
offset = keyIdx + 'tool_calls'.length; offset = keyIdx + matchedKeyword.length;
} }
} }
function consumeToolCapture(state, toolNames) { function consumeToolCapture(state, toolNames) {
const captured = state.capture; const captured = state.capture || '';
if (!captured) { if (!captured) {
return { ready: false, prefix: '', calls: [], suffix: '' }; return { ready: false, prefix: '', calls: [], suffix: '' };
} }
const lower = captured.toLowerCase(); const lower = captured.toLowerCase();
const keyIdx = lower.indexOf('tool_calls');
let keyIdx = -1;
const keywords = ['tool_calls', 'function.name:', '[tool_call_history]'];
for (const kw of keywords) {
const idx = lower.indexOf(kw);
if (idx >= 0 && (keyIdx < 0 || idx < keyIdx)) {
keyIdx = idx;
}
}
if (keyIdx < 0) { if (keyIdx < 0) {
return { ready: false, prefix: '', calls: [], suffix: '' }; return { ready: false, prefix: '', calls: [], suffix: '' };
} }
const start = captured.slice(0, keyIdx).lastIndexOf('{'); const start = captured.slice(0, keyIdx).lastIndexOf('{');
if (start < 0) { const actualStart = start >= 0 ? start : keyIdx;
return { ready: false, prefix: '', calls: [], suffix: '' };
} const obj = extractJSONObjectFrom(captured, actualStart);
const obj = extractJSONObjectFrom(captured, start);
if (!obj.ok) { if (!obj.ok) {
return { ready: false, prefix: '', calls: [], suffix: '' }; return { ready: false, prefix: '', calls: [], suffix: '' };
} }
const prefixPart = captured.slice(0, start);
const prefixPart = captured.slice(0, actualStart);
const suffixPart = captured.slice(obj.end); const suffixPart = captured.slice(obj.end);
if (insideCodeFence((state.recentTextTail || '') + prefixPart)) { if (insideCodeFence((state.recentTextTail || '') + prefixPart)) {
return { return {
ready: true, ready: true,
@@ -205,18 +237,10 @@ function consumeToolCapture(state, toolNames) {
suffix: '', suffix: '',
}; };
} }
const rawParsed = parseStandaloneToolCalls(captured.slice(start, obj.end), []);
const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames); const parsed = parseStandaloneToolCallsDetailed(captured.slice(actualStart, obj.end), toolNames);
if (parsed.length === 0) { if (!Array.isArray(parsed.calls) || parsed.calls.length === 0) {
if (rawParsed.length > 0 && Array.isArray(toolNames) && toolNames.length > 0) { if (parsed.sawToolCallSyntax && parsed.rejectedByPolicy) {
return {
ready: true,
prefix: prefixPart,
calls: [],
suffix: suffixPart,
};
}
if (state.toolNameSent) {
return { return {
ready: true, ready: true,
prefix: prefixPart, prefix: prefixPart,
@@ -231,26 +255,11 @@ function consumeToolCapture(state, toolNames) {
suffix: '', suffix: '',
}; };
} }
if (state.toolNameSent) {
if (parsed.length > 1) {
return {
ready: true,
prefix: prefixPart,
calls: parsed.slice(1),
suffix: suffixPart,
};
}
return {
ready: true,
prefix: prefixPart,
calls: [],
suffix: suffixPart,
};
}
return { return {
ready: true, ready: true,
prefix: prefixPart, prefix: prefixPart,
calls: parsed, calls: parsed.calls,
suffix: suffixPart, suffix: suffixPart,
}; };
} }

View File

@@ -1,6 +1,5 @@
'use strict'; 'use strict';
const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024;
const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256; const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256;
function createToolSieveState() { function createToolSieveState() {
@@ -9,6 +8,9 @@ function createToolSieveState() {
capture: '', capture: '',
capturing: false, capturing: false,
recentTextTail: '', recentTextTail: '',
pendingToolRaw: '',
pendingToolCalls: [],
disableDeltas: false,
toolNameSent: false, toolNameSent: false,
toolName: '', toolName: '',
toolArgsStart: -1, toolArgsStart: -1,
@@ -19,6 +21,7 @@ function createToolSieveState() {
} }
function resetIncrementalToolState(state) { function resetIncrementalToolState(state) {
state.disableDeltas = false;
state.toolNameSent = false; state.toolNameSent = false;
state.toolName = ''; state.toolName = '';
state.toolArgsStart = -1; state.toolArgsStart = -1;
@@ -78,7 +81,6 @@ function toStringSafe(v) {
} }
module.exports = { module.exports = {
TOOL_SIEVE_CAPTURE_LIMIT,
TOOL_SIEVE_CONTEXT_TAIL_LIMIT, TOOL_SIEVE_CONTEXT_TAIL_LIMIT,
createToolSieveState, createToolSieveState,
resetIncrementalToolState, resetIncrementalToolState,

View File

@@ -51,6 +51,9 @@ func MessagesPrepare(messages []map[string]any) string {
} }
func NormalizeContent(v any) string { func NormalizeContent(v any) string {
if v == nil {
return ""
}
switch x := v.(type) { switch x := v.(type) {
case string: case string:
return x return x
@@ -64,11 +67,11 @@ func NormalizeContent(v any) string {
typeStr, _ := m["type"].(string) typeStr, _ := m["type"].(string)
typeStr = strings.ToLower(strings.TrimSpace(typeStr)) typeStr = strings.ToLower(strings.TrimSpace(typeStr))
if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" { if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" {
if txt, ok := m["text"].(string); ok { if txt, ok := m["text"].(string); ok && txt != "" {
parts = append(parts, txt) parts = append(parts, txt)
continue continue
} }
if txt, ok := m["content"].(string); ok { if txt, ok := m["content"].(string); ok && txt != "" {
parts = append(parts, txt) parts = append(parts, txt)
} }
} }

View File

@@ -0,0 +1,32 @@
package prompt
import "testing"
func TestNormalizeContentNilReturnsEmpty(t *testing.T) {
if got := NormalizeContent(nil); got != "" {
t.Fatalf("expected empty string for nil content, got %q", got)
}
}
func TestMessagesPrepareNilContentNoNullLiteral(t *testing.T) {
messages := []map[string]any{
{"role": "assistant", "content": nil},
{"role": "user", "content": "ok"},
}
got := MessagesPrepare(messages)
if got == "" {
t.Fatalf("expected non-empty output")
}
if got == "null" {
t.Fatalf("expected no null literal output, got %q", got)
}
}
func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
got := NormalizeContent([]any{
map[string]any{"type": "text", "text": "", "content": "from-content"},
})
if got != "from-content" {
t.Fatalf("expected fallback to content when text is empty, got %q", got)
}
}

View File

@@ -57,16 +57,20 @@ func NewApp() *App {
r.Use(cors) r.Use(cors)
r.Use(timeout(0)) r.Use(timeout(0))
r.Get("/healthz", func(w http.ResponseWriter, _ *http.Request) { healthzHandler := func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"status":"ok"}`)) _, _ = w.Write([]byte(`{"status":"ok"}`))
}) }
r.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) { readyzHandler := func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"status":"ready"}`)) _, _ = w.Write([]byte(`{"status":"ready"}`))
}) }
r.Get("/healthz", healthzHandler)
r.Head("/healthz", healthzHandler)
r.Get("/readyz", readyzHandler)
r.Head("/readyz", readyzHandler)
openai.RegisterRoutes(r, openaiHandler) openai.RegisterRoutes(r, openaiHandler)
claude.RegisterRoutes(r, claudeHandler) claude.RegisterRoutes(r, claudeHandler)
gemini.RegisterRoutes(r, geminiHandler) gemini.RegisterRoutes(r, geminiHandler)

View File

@@ -0,0 +1,20 @@
package server
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestHealthEndpointsSupportHEAD(t *testing.T) {
app := NewApp()
for _, path := range []string{"/healthz", "/readyz"} {
req := httptest.NewRequest(http.MethodHead, path, nil)
rec := httptest.NewRecorder()
app.Router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected %s HEAD status 200, got %d", path, rec.Code)
}
}
}

View File

@@ -17,6 +17,12 @@ func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error {
var m map[string]any var m map[string]any
_ = json.Unmarshal(resp.Body, &m) _ = json.Unmarshal(resp.Body, &m)
cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body))) cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body)))
headResp, headErr := cc.request(ctx, requestSpec{Method: http.MethodHead, Path: "/healthz", Retryable: true})
if headErr != nil {
return headErr
}
cc.assert("head_status_200", headResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", headResp.StatusCode))
return nil return nil
} }
@@ -29,6 +35,12 @@ func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error {
var m map[string]any var m map[string]any
_ = json.Unmarshal(resp.Body, &m) _ = json.Unmarshal(resp.Body, &m)
cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body))) cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body)))
headResp, headErr := cc.request(ctx, requestSpec{Method: http.MethodHead, Path: "/readyz", Retryable: true})
if headErr != nil {
return headErr
}
cc.assert("head_status_200", headResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", headResp.StatusCode))
return nil return nil
} }

View File

@@ -20,7 +20,7 @@ func buildToolCallCandidates(text string) []string {
} }
} }
// best-effort extraction around "tool_calls" key in mixed text payloads. // best-effort extraction around tool call keywords in mixed text payloads.
candidates = append(candidates, extractToolCallObjects(trimmed)...) candidates = append(candidates, extractToolCallObjects(trimmed)...)
// best-effort object slice: from first '{' to last '}' // best-effort object slice: from first '{' to last '}'
@@ -57,25 +57,65 @@ func extractToolCallObjects(text string) []string {
lower := strings.ToLower(text) lower := strings.ToLower(text)
out := []string{} out := []string{}
offset := 0 offset := 0
keywords := []string{"tool_calls", "function.name:", "[tool_call_history]"}
for { for {
idx := strings.Index(lower[offset:], "tool_calls") bestIdx := -1
if idx < 0 { matchedKeyword := ""
for _, kw := range keywords {
idx := strings.Index(lower[offset:], kw)
if idx >= 0 {
absIdx := offset + idx
if bestIdx < 0 || absIdx < bestIdx {
bestIdx = absIdx
matchedKeyword = kw
}
}
}
if bestIdx < 0 {
break break
} }
idx += offset
start := strings.LastIndex(text[:idx], "{") idx := bestIdx
for start >= 0 { // Avoid backtracking too far to prevent OOM on malicious or very long strings
searchLimit := idx - 2000
if searchLimit < offset {
searchLimit = offset
}
start := strings.LastIndex(text[searchLimit:idx], "{")
if start >= 0 {
start += searchLimit
}
if start < 0 {
offset = idx + len(matchedKeyword)
continue
}
foundObj := false
for start >= searchLimit {
candidate, end, ok := extractJSONObject(text, start) candidate, end, ok := extractJSONObject(text, start)
if ok { if ok {
// Move forward to avoid repeatedly matching the same object. // Move forward to avoid repeatedly matching the same object.
offset = end offset = end
out = append(out, strings.TrimSpace(candidate)) out = append(out, strings.TrimSpace(candidate))
foundObj = true
break break
} }
start = strings.LastIndex(text[:start], "{") // Try previous '{'
if start > searchLimit {
prevStart := strings.LastIndex(text[searchLimit:start], "{")
if prevStart >= 0 {
start = searchLimit + prevStart
continue
}
}
break
} }
if start < 0 {
offset = idx + len("tool_calls") if !foundObj {
offset = idx + len(matchedKeyword)
} }
} }
return out return out
@@ -88,7 +128,12 @@ func extractJSONObject(text string, start int) (string, int, bool) {
depth := 0 depth := 0
quote := byte(0) quote := byte(0)
escaped := false escaped := false
for i := start; i < len(text); i++ { // Limit scan length to avoid OOM on unclosed objects
maxLen := start + 50000
if maxLen > len(text) {
maxLen = len(text)
}
for i := start; i < maxLen; i++ {
ch := text[i] ch := text[i]
if quote != 0 { if quote != 0 {
if escaped { if escaped {

View File

@@ -0,0 +1,108 @@
package util
import (
"encoding/json"
"strings"
"unicode"
)
func parseToolCallInput(v any) map[string]any {
switch x := v.(type) {
case nil:
return map[string]any{}
case map[string]any:
return x
case string:
raw := strings.TrimSpace(x)
if raw == "" {
return map[string]any{}
}
var parsed map[string]any
if err := json.Unmarshal([]byte(raw), &parsed); err == nil && parsed != nil {
repairPathLikeControlChars(parsed)
return parsed
}
// Try to repair invalid backslashes (common in Windows paths output by models)
repaired := repairInvalidJSONBackslashes(raw)
if repaired != raw {
if err := json.Unmarshal([]byte(repaired), &parsed); err == nil && parsed != nil {
repairPathLikeControlChars(parsed)
return parsed
}
}
// Try to repair loose JSON in string argument as well
repairedLoose := RepairLooseJSON(raw)
if repairedLoose != raw {
if err := json.Unmarshal([]byte(repairedLoose), &parsed); err == nil && parsed != nil {
repairPathLikeControlChars(parsed)
return parsed
}
}
return map[string]any{"_raw": raw}
default:
b, err := json.Marshal(x)
if err != nil {
return map[string]any{}
}
var parsed map[string]any
if err := json.Unmarshal(b, &parsed); err == nil && parsed != nil {
return parsed
}
return map[string]any{}
}
}
func repairPathLikeControlChars(m map[string]any) {
for k, v := range m {
switch vv := v.(type) {
case map[string]any:
repairPathLikeControlChars(vv)
case []any:
for _, item := range vv {
if child, ok := item.(map[string]any); ok {
repairPathLikeControlChars(child)
}
}
case string:
if isPathLikeKey(k) && containsControlRune(vv) {
m[k] = escapeControlRunes(vv)
}
}
}
}
func isPathLikeKey(key string) bool {
k := strings.ToLower(strings.TrimSpace(key))
return strings.Contains(k, "path") || strings.Contains(k, "file")
}
func containsControlRune(s string) bool {
for _, r := range s {
if unicode.IsControl(r) {
return true
}
}
return false
}
func escapeControlRunes(s string) string {
var b strings.Builder
b.Grow(len(s) + 8)
for _, r := range s {
switch r {
case '\b':
b.WriteString(`\b`)
case '\f':
b.WriteString(`\f`)
case '\n':
b.WriteString(`\n`)
case '\r':
b.WriteString(`\r`)
case '\t':
b.WriteString(`\t`)
default:
b.WriteRune(r)
}
}
return b.String()
}

View File

@@ -0,0 +1,79 @@
package util
import (
"regexp"
"strings"
)
func repairInvalidJSONBackslashes(s string) string {
if !strings.Contains(s, "\\") {
return s
}
var out strings.Builder
out.Grow(len(s) + 10)
runes := []rune(s)
for i := 0; i < len(runes); i++ {
if runes[i] == '\\' {
if i+1 < len(runes) {
next := runes[i+1]
switch next {
case '"', '\\', '/', 'b', 'f', 'n', 'r', 't':
out.WriteRune('\\')
out.WriteRune(next)
i++
continue
case 'u':
if i+5 < len(runes) {
isHex := true
for j := 1; j <= 4; j++ {
r := runes[i+1+j]
if !((r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')) {
isHex = false
break
}
}
if isHex {
out.WriteRune('\\')
out.WriteRune('u')
for j := 1; j <= 4; j++ {
out.WriteRune(runes[i+1+j])
}
i += 5
continue
}
}
}
}
// Not a valid escape sequence, double it
out.WriteString("\\\\")
} else {
out.WriteRune(runes[i])
}
}
return out.String()
}
var unquotedKeyPattern = regexp.MustCompile(`([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:`)
// missingArrayBracketsPattern identifies a sequence of two or more JSON objects separated by commas
// that immediately follow a colon, which indicates a missing array bracket `[` `]`.
// E.g., "key": {"a": 1}, {"b": 2} -> "key": [{"a": 1}, {"b": 2}]
// NOTE: The pattern uses (?:[^{}]|\{[^{}]*\})* to support single-level nested {} objects,
// which handles cases like {"content": "x", "input": {"q": "y"}}
var missingArrayBracketsPattern = regexp.MustCompile(`(:\s*)(\{(?:[^{}]|\{[^{}]*\})*\}(?:\s*,\s*\{(?:[^{}]|\{[^{}]*\})*\})+)`)
func RepairLooseJSON(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return s
}
// 1. Replace unquoted keys: {key: -> {"key":
s = unquotedKeyPattern.ReplaceAllString(s, `$1"$2":`)
// 2. Heuristic: Fix missing array brackets for list of objects
// e.g., : {obj1}, {obj2} -> : [{obj1}, {obj2}]
// This specifically addresses DeepSeek's "list hallucination"
s = missingArrayBracketsPattern.ReplaceAllString(s, `$1[$2]`)
return s
}

View File

@@ -0,0 +1,161 @@
package util
import (
"encoding/json"
"regexp"
"strings"
)
var toolCallMarkupTagNames = []string{"tool_call", "function_call", "invoke"}
var toolCallMarkupTagPatternByName = map[string]*regexp.Regexp{
"tool_call": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?tool_call\b([^>]*)>(.*?)</(?:[a-z0-9_:-]+:)?tool_call>`),
"function_call": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function_call\b([^>]*)>(.*?)</(?:[a-z0-9_:-]+:)?function_call>`),
"invoke": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)>(.*?)</(?:[a-z0-9_:-]+:)?invoke>`),
}
var toolCallMarkupSelfClosingPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)/>`)
var toolCallMarkupKVPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?([a-z0-9_\-.]+)\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?([a-z0-9_\-.]+)>`)
var toolCallMarkupAttrPattern = regexp.MustCompile(`(?is)(name|function|tool)\s*=\s*"([^"]+)"`)
var anyTagPattern = regexp.MustCompile(`(?is)<[^>]+>`)
var toolCallMarkupNameTagNames = []string{"name", "function"}
var toolCallMarkupNamePatternByTag = map[string]*regexp.Regexp{
"name": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?name\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?name>`),
"function": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?function>`),
}
var toolCallMarkupArgsTagNames = []string{"input", "arguments", "argument", "parameters", "parameter", "args", "params"}
var toolCallMarkupArgsPatternByTag = map[string]*regexp.Regexp{
"input": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?input\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?input>`),
"arguments": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?arguments\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?arguments>`),
"argument": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?argument\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?argument>`),
"parameters": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?parameters\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?parameters>`),
"parameter": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?parameter\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?parameter>`),
"args": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?args\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?args>`),
"params": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?params\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?params>`),
}
func parseMarkupToolCalls(text string) []ParsedToolCall {
trimmed := strings.TrimSpace(text)
if trimmed == "" {
return nil
}
out := make([]ParsedToolCall, 0)
for _, tagName := range toolCallMarkupTagNames {
pattern := toolCallMarkupTagPatternByName[tagName]
for _, m := range pattern.FindAllStringSubmatch(trimmed, -1) {
if len(m) < 3 {
continue
}
attrs := strings.TrimSpace(m[1])
inner := strings.TrimSpace(m[2])
if parsed := parseMarkupSingleToolCall(attrs, inner); parsed.Name != "" {
out = append(out, parsed)
}
}
}
for _, m := range toolCallMarkupSelfClosingPattern.FindAllStringSubmatch(trimmed, -1) {
if len(m) < 2 {
continue
}
if parsed := parseMarkupSingleToolCall(strings.TrimSpace(m[1]), ""); parsed.Name != "" {
out = append(out, parsed)
}
}
if len(out) == 0 {
return nil
}
return out
}
func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall {
if parsed := parseToolCallsPayload(inner); len(parsed) > 0 {
return parsed[0]
}
name := ""
if m := toolCallMarkupAttrPattern.FindStringSubmatch(attrs); len(m) >= 3 {
name = strings.TrimSpace(m[2])
}
if name == "" {
name = findMarkupTagValue(inner, toolCallMarkupNameTagNames, toolCallMarkupNamePatternByTag)
}
if name == "" {
return ParsedToolCall{}
}
input := map[string]any{}
if argsRaw := findMarkupTagValue(inner, toolCallMarkupArgsTagNames, toolCallMarkupArgsPatternByTag); argsRaw != "" {
input = parseMarkupInput(argsRaw)
} else if kv := parseMarkupKVObject(inner); len(kv) > 0 {
input = kv
}
return ParsedToolCall{Name: name, Input: input}
}
func parseMarkupInput(raw string) map[string]any {
raw = strings.TrimSpace(raw)
if raw == "" {
return map[string]any{}
}
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
return parsed
}
if kv := parseMarkupKVObject(raw); len(kv) > 0 {
return kv
}
return map[string]any{"_raw": stripTagText(raw)}
}
func parseMarkupKVObject(text string) map[string]any {
matches := toolCallMarkupKVPattern.FindAllStringSubmatch(strings.TrimSpace(text), -1)
if len(matches) == 0 {
return nil
}
out := map[string]any{}
for _, m := range matches {
if len(m) < 4 {
continue
}
key := strings.TrimSpace(m[1])
endKey := strings.TrimSpace(m[3])
if key == "" {
continue
}
if !strings.EqualFold(key, endKey) {
continue
}
value := strings.TrimSpace(stripTagText(m[2]))
if value == "" {
continue
}
var jsonValue any
if json.Unmarshal([]byte(value), &jsonValue) == nil {
out[key] = jsonValue
continue
}
out[key] = value
}
if len(out) == 0 {
return nil
}
return out
}
func stripTagText(text string) string {
return strings.TrimSpace(anyTagPattern.ReplaceAllString(text, ""))
}
func findMarkupTagValue(text string, tagNames []string, patternByTag map[string]*regexp.Regexp) string {
for _, tag := range tagNames {
pattern := patternByTag[tag]
if pattern == nil {
continue
}
if m := pattern.FindStringSubmatch(text); len(m) >= 2 {
value := strings.TrimSpace(m[1])
if value != "" {
return value
}
}
}
return ""
}

View File

@@ -0,0 +1,33 @@
package util
import (
"regexp"
"strings"
)
var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`)
func resolveAllowedToolNameWithLooseMatch(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
if _, ok := allowed[name]; ok {
return name
}
lower := strings.ToLower(strings.TrimSpace(name))
if canonical, ok := allowedCanonical[lower]; ok {
return canonical
}
if idx := strings.LastIndex(lower, "."); idx >= 0 && idx < len(lower)-1 {
if canonical, ok := allowedCanonical[lower[idx+1:]]; ok {
return canonical
}
}
loose := toolNameLoosePattern.ReplaceAllString(lower, "")
if loose == "" {
return ""
}
for candidateLower, canonical := range allowedCanonical {
if toolNameLoosePattern.ReplaceAllString(candidateLower, "") == loose {
return canonical
}
}
return ""
}

View File

@@ -16,7 +16,6 @@ type ToolCallParseResult struct {
RejectedByPolicy bool RejectedByPolicy bool
RejectedToolNames []string RejectedToolNames []string
} }
func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall { func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
return ParseToolCallsDetailed(text, availableToolNames).Calls return ParseToolCallsDetailed(text, availableToolNames).Calls
} }
@@ -26,23 +25,36 @@ func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallPa
if strings.TrimSpace(text) == "" { if strings.TrimSpace(text) == "" {
return result return result
} }
text = stripFencedCodeBlocks(text) result.SawToolCallSyntax = looksLikeToolCallSyntax(text)
if strings.TrimSpace(text) == "" {
return result
}
result.SawToolCallSyntax = strings.Contains(strings.ToLower(text), "tool_calls")
candidates := buildToolCallCandidates(text) candidates := buildToolCallCandidates(text)
var parsed []ParsedToolCall var parsed []ParsedToolCall
for _, candidate := range candidates { for _, candidate := range candidates {
if tc := parseToolCallsPayload(candidate); len(tc) > 0 { tc := parseToolCallsPayload(candidate)
if len(tc) == 0 {
tc = parseXMLToolCalls(candidate)
}
if len(tc) == 0 {
tc = parseMarkupToolCalls(candidate)
}
if len(tc) == 0 {
tc = parseTextKVToolCalls(candidate)
}
if len(tc) > 0 {
parsed = tc parsed = tc
result.SawToolCallSyntax = true result.SawToolCallSyntax = true
break break
} }
} }
if len(parsed) == 0 { if len(parsed) == 0 {
return result parsed = parseXMLToolCalls(text)
if len(parsed) == 0 {
parsed = parseTextKVToolCalls(text)
if len(parsed) == 0 {
return result
}
}
result.SawToolCallSyntax = true
} }
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
@@ -51,7 +63,6 @@ func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallPa
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0 result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
return result return result
} }
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall { func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
return ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls return ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls
} }
@@ -62,76 +73,112 @@ func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string)
if trimmed == "" { if trimmed == "" {
return result return result
} }
if looksLikeToolExampleContext(trimmed) { result.SawToolCallSyntax = looksLikeToolCallSyntax(trimmed)
return result candidates := buildToolCallCandidates(trimmed)
} var parsed []ParsedToolCall
result.SawToolCallSyntax = strings.Contains(strings.ToLower(trimmed), "tool_calls")
candidates := []string{trimmed}
for _, candidate := range candidates { for _, candidate := range candidates {
candidate = strings.TrimSpace(candidate) candidate = strings.TrimSpace(candidate)
if candidate == "" { if candidate == "" {
continue continue
} }
if !strings.HasPrefix(candidate, "{") && !strings.HasPrefix(candidate, "[") { parsed = parseToolCallsPayload(candidate)
continue if len(parsed) == 0 {
parsed = parseXMLToolCalls(candidate)
} }
if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 { if len(parsed) == 0 {
result.SawToolCallSyntax = true parsed = parseMarkupToolCalls(candidate)
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) }
result.Calls = calls if len(parsed) == 0 {
result.RejectedToolNames = rejectedNames parsed = parseTextKVToolCalls(candidate)
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0 }
return result if len(parsed) > 0 {
break
} }
} }
if len(parsed) == 0 {
parsed = parseXMLToolCalls(trimmed)
if len(parsed) == 0 {
parsed = parseTextKVToolCalls(trimmed)
if len(parsed) == 0 {
return result
}
}
}
result.SawToolCallSyntax = true
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
result.Calls = calls
result.RejectedToolNames = rejectedNames
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
return result return result
} }
func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []string) ([]ParsedToolCall, []string) { func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []string) ([]ParsedToolCall, []string) {
allowed := map[string]struct{}{} allowed := map[string]struct{}{}
allowedCanonical := map[string]string{}
for _, name := range availableToolNames { for _, name := range availableToolNames {
allowed[name] = struct{}{} trimmed := strings.TrimSpace(name)
if trimmed == "" {
continue
}
allowed[trimmed] = struct{}{}
lower := strings.ToLower(trimmed)
if _, exists := allowedCanonical[lower]; !exists {
allowedCanonical[lower] = trimmed
}
} }
if len(allowed) == 0 { if len(allowed) == 0 {
rejectedSet := map[string]struct{}{} rejectedSet := map[string]struct{}{}
rejected := make([]string, 0, len(parsed))
for _, tc := range parsed { for _, tc := range parsed {
if tc.Name == "" { if tc.Name == "" {
continue continue
} }
if _, ok := rejectedSet[tc.Name]; ok {
continue
}
rejectedSet[tc.Name] = struct{}{} rejectedSet[tc.Name] = struct{}{}
} rejected = append(rejected, tc.Name)
rejected := make([]string, 0, len(rejectedSet))
for name := range rejectedSet {
rejected = append(rejected, name)
} }
return nil, rejected return nil, rejected
} }
out := make([]ParsedToolCall, 0, len(parsed)) out := make([]ParsedToolCall, 0, len(parsed))
rejectedSet := map[string]struct{}{} rejectedSet := map[string]struct{}{}
rejected := make([]string, 0)
for _, tc := range parsed { for _, tc := range parsed {
if tc.Name == "" { if tc.Name == "" {
continue continue
} }
if _, ok := allowed[tc.Name]; !ok { matchedName := resolveAllowedToolName(tc.Name, allowed, allowedCanonical)
rejectedSet[tc.Name] = struct{}{} if matchedName == "" {
if _, ok := rejectedSet[tc.Name]; !ok {
rejectedSet[tc.Name] = struct{}{}
rejected = append(rejected, tc.Name)
}
continue continue
} }
tc.Name = matchedName
if tc.Input == nil { if tc.Input == nil {
tc.Input = map[string]any{} tc.Input = map[string]any{}
} }
out = append(out, tc) out = append(out, tc)
} }
rejected := make([]string, 0, len(rejectedSet))
for name := range rejectedSet {
rejected = append(rejected, name)
}
return out, rejected return out, rejected
} }
func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
return resolveAllowedToolNameWithLooseMatch(name, allowed, allowedCanonical)
}
func parseToolCallsPayload(payload string) []ParsedToolCall { func parseToolCallsPayload(payload string) []ParsedToolCall {
var decoded any var decoded any
if err := json.Unmarshal([]byte(payload), &decoded); err != nil { if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
return nil // Try to repair backslashes first! Because LLMs often mix these two problems.
repaired := repairInvalidJSONBackslashes(payload)
// Try loose repair on top of that
repaired = RepairLooseJSON(repaired)
if err := json.Unmarshal([]byte(repaired), &decoded); err != nil {
return nil
}
} }
switch v := decoded.(type) { switch v := decoded.(type) {
case map[string]any: case map[string]any:
@@ -147,6 +194,15 @@ func parseToolCallsPayload(payload string) []ParsedToolCall {
return nil return nil
} }
func looksLikeToolCallSyntax(text string) bool {
lower := strings.ToLower(text)
return strings.Contains(lower, "tool_calls") ||
strings.Contains(lower, "<tool_call") ||
strings.Contains(lower, "<function_call") ||
strings.Contains(lower, "<invoke") ||
strings.Contains(lower, "function.name:")
}
func parseToolCallList(v any) []ParsedToolCall { func parseToolCallList(v any) []ParsedToolCall {
items, ok := v.([]any) items, ok := v.([]any)
if !ok { if !ok {
@@ -199,32 +255,3 @@ func parseToolCallItem(m map[string]any) (ParsedToolCall, bool) {
Input: parseToolCallInput(inputRaw), Input: parseToolCallInput(inputRaw),
}, true }, true
} }
func parseToolCallInput(v any) map[string]any {
switch x := v.(type) {
case nil:
return map[string]any{}
case map[string]any:
return x
case string:
raw := strings.TrimSpace(x)
if raw == "" {
return map[string]any{}
}
var parsed map[string]any
if err := json.Unmarshal([]byte(raw), &parsed); err == nil && parsed != nil {
return parsed
}
return map[string]any{"_raw": raw}
default:
b, err := json.Marshal(x)
if err != nil {
return map[string]any{}
}
var parsed map[string]any
if err := json.Unmarshal(b, &parsed); err == nil && parsed != nil {
return parsed
}
return map[string]any{}
}
}

View File

@@ -0,0 +1,235 @@
package util
import (
"encoding/json"
"encoding/xml"
"regexp"
"strings"
)
var xmlToolCallPattern = regexp.MustCompile(`(?is)<tool_call>\s*(.*?)\s*</tool_call>`)
var functionCallPattern = regexp.MustCompile(`(?is)<function_call>\s*([^<]+?)\s*</function_call>`)
var functionParamPattern = regexp.MustCompile(`(?is)<function\s+parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</function\s+parameter>`)
var antmlFunctionCallPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?function_call[^>]*(?:name|function)="([^"]+)"[^>]*>\s*(.*?)\s*</(?:[a-z0-9_]+:)?function_call>`)
var antmlArgumentPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?argument\s+name="([^"]+)"\s*>\s*(.*?)\s*</(?:[a-z0-9_]+:)?argument>`)
var antmlParametersPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?parameters\s*>\s*(\{.*?\})\s*</(?:[a-z0-9_]+:)?parameters>`)
var invokeCallPattern = regexp.MustCompile(`(?is)<invoke\s+name="([^"]+)"\s*>(.*?)</invoke>`)
var invokeParamPattern = regexp.MustCompile(`(?is)<parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</parameter>`)
func parseXMLToolCalls(text string) []ParsedToolCall {
matches := xmlToolCallPattern.FindAllString(text, -1)
out := make([]ParsedToolCall, 0, len(matches)+1)
for _, block := range matches {
call, ok := parseSingleXMLToolCall(block)
if !ok {
continue
}
out = append(out, call)
}
if len(out) > 0 {
return out
}
if call, ok := parseFunctionCallTagStyle(text); ok {
return []ParsedToolCall{call}
}
if calls := parseAntmlFunctionCallStyles(text); len(calls) > 0 {
return calls
}
if call, ok := parseInvokeFunctionCallStyle(text); ok {
return []ParsedToolCall{call}
}
return nil
}
func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
inner := strings.TrimSpace(block)
inner = strings.TrimPrefix(inner, "<tool_call>")
inner = strings.TrimSuffix(inner, "</tool_call>")
inner = strings.TrimSpace(inner)
if strings.HasPrefix(inner, "{") {
var payload map[string]any
if err := json.Unmarshal([]byte(inner), &payload); err == nil {
name := strings.TrimSpace(asString(payload["tool"]))
if name == "" {
name = strings.TrimSpace(asString(payload["tool_name"]))
}
if name != "" {
input := map[string]any{}
if params, ok := payload["params"].(map[string]any); ok {
input = params
} else if params, ok := payload["parameters"].(map[string]any); ok {
input = params
}
return ParsedToolCall{Name: name, Input: input}, true
}
}
}
dec := xml.NewDecoder(strings.NewReader(block))
name := ""
params := map[string]any{}
inParams := false
inTool := false
for {
tok, err := dec.Token()
if err != nil {
break
}
switch t := tok.(type) {
case xml.StartElement:
tag := strings.ToLower(t.Name.Local)
switch tag {
case "tool":
inTool = true
for _, attr := range t.Attr {
if strings.EqualFold(strings.TrimSpace(attr.Name.Local), "name") && strings.TrimSpace(name) == "" {
name = strings.TrimSpace(attr.Value)
}
}
case "parameters":
inParams = true
case "tool_name", "name":
var v string
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
name = strings.TrimSpace(v)
}
case "input", "arguments", "argument", "args", "params":
var v string
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
if parsed := parseToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 {
for k, vv := range parsed {
params[k] = vv
}
}
}
default:
if inParams || inTool {
var v string
if err := dec.DecodeElement(&v, &t); err == nil {
params[t.Name.Local] = strings.TrimSpace(v)
}
}
}
case xml.EndElement:
tag := strings.ToLower(t.Name.Local)
if tag == "parameters" {
inParams = false
}
if tag == "tool" {
inTool = false
}
}
}
if strings.TrimSpace(name) == "" {
return ParsedToolCall{}, false
}
return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true
}
func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
m := functionCallPattern.FindStringSubmatch(text)
if len(m) < 2 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
if name == "" {
return ParsedToolCall{}, false
}
input := map[string]any{}
for _, pm := range functionParamPattern.FindAllStringSubmatch(text, -1) {
if len(pm) < 3 {
continue
}
key := strings.TrimSpace(pm[1])
val := strings.TrimSpace(pm[2])
if key != "" {
input[key] = val
}
}
return ParsedToolCall{Name: name, Input: input}, true
}
func parseAntmlFunctionCallStyles(text string) []ParsedToolCall {
matches := antmlFunctionCallPattern.FindAllStringSubmatch(text, -1)
if len(matches) == 0 {
return nil
}
out := make([]ParsedToolCall, 0, len(matches))
for _, m := range matches {
if call, ok := parseSingleAntmlFunctionCallMatch(m); ok {
out = append(out, call)
}
}
if len(out) == 0 {
return nil
}
return out
}
func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
if name == "" {
return ParsedToolCall{}, false
}
body := strings.TrimSpace(m[2])
input := map[string]any{}
if strings.HasPrefix(body, "{") {
if err := json.Unmarshal([]byte(body), &input); err == nil {
return ParsedToolCall{Name: name, Input: input}, true
}
}
if pm := antmlParametersPattern.FindStringSubmatch(body); len(pm) >= 2 {
if err := json.Unmarshal([]byte(strings.TrimSpace(pm[1])), &input); err == nil {
return ParsedToolCall{Name: name, Input: input}, true
}
}
for _, am := range antmlArgumentPattern.FindAllStringSubmatch(body, -1) {
if len(am) < 3 {
continue
}
k := strings.TrimSpace(am[1])
v := strings.TrimSpace(am[2])
if k != "" {
input[k] = v
}
}
return ParsedToolCall{Name: name, Input: input}, true
}
func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
m := invokeCallPattern.FindStringSubmatch(text)
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
if name == "" {
return ParsedToolCall{}, false
}
input := map[string]any{}
for _, pm := range invokeParamPattern.FindAllStringSubmatch(m[2], -1) {
if len(pm) < 3 {
continue
}
k := strings.TrimSpace(pm[1])
v := strings.TrimSpace(pm[2])
if k != "" {
input[k] = v
}
}
if len(input) == 0 {
if argsRaw := findMarkupTagValue(m[2], toolCallMarkupArgsTagNames, toolCallMarkupArgsPatternByTag); argsRaw != "" {
input = parseMarkupInput(argsRaw)
} else if kv := parseMarkupKVObject(m[2]); len(kv) > 0 {
input = kv
}
}
return ParsedToolCall{Name: name, Input: input}, true
}
func asString(v any) string {
s, _ := v.(string)
return s
}

View File

@@ -1,6 +1,9 @@
package util package util
import "testing" import (
"strings"
"testing"
)
func TestParseToolCalls(t *testing.T) { func TestParseToolCalls(t *testing.T) {
text := `prefix {"tool_calls":[{"name":"search","input":{"q":"golang"}}]} suffix` text := `prefix {"tool_calls":[{"name":"search","input":{"q":"golang"}}]} suffix`
@@ -19,8 +22,8 @@ func TestParseToolCalls(t *testing.T) {
func TestParseToolCallsFromFencedJSON(t *testing.T) { func TestParseToolCallsFromFencedJSON(t *testing.T) {
text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```" text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```"
calls := ParseToolCalls(text, []string{"search"}) calls := ParseToolCalls(text, []string{"search"})
if len(calls) != 0 { if len(calls) != 1 {
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls) t.Fatalf("expected fenced tool_call payload to be parsed, got %#v", calls)
} }
} }
@@ -46,6 +49,17 @@ func TestParseToolCallsRejectsUnknownToolName(t *testing.T) {
} }
} }
func TestParseToolCallsAllowsCaseInsensitiveToolNameAndCanonicalizes(t *testing.T) {
text := `{"tool_calls":[{"name":"Bash","input":{"command":"ls -al"}}]}`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "bash" {
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
}
}
func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) { func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) {
text := `{"tool_calls":[{"name":"unknown","input":{}}]}` text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
res := ParseToolCallsDetailed(text, []string{"search"}) res := ParseToolCallsDetailed(text, []string{"search"})
@@ -85,10 +99,10 @@ func TestFormatOpenAIToolCalls(t *testing.T) {
} }
} }
func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) { func TestParseStandaloneToolCallsSupportsMixedProsePayload(t *testing.T) {
mixed := `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` mixed := `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 0 { if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 1 {
t.Fatalf("expected standalone parser to ignore mixed prose, got %#v", calls) t.Fatalf("expected standalone parser to parse mixed prose payload, got %#v", calls)
} }
standalone := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` standalone := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
@@ -98,9 +112,408 @@ func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) {
} }
} }
func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) { func TestParseStandaloneToolCallsParsesFencedCodeBlock(t *testing.T) {
fenced := "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```" fenced := "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```"
if calls := ParseStandaloneToolCalls(fenced, []string{"search"}); len(calls) != 0 { if calls := ParseStandaloneToolCalls(fenced, []string{"search"}); len(calls) != 1 {
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls) t.Fatalf("expected fenced tool_call payload to be parsed, got %#v", calls)
}
}
func TestParseToolCallsAllowsQualifiedToolName(t *testing.T) {
text := `{"tool_calls":[{"name":"mcp.search_web","input":{"q":"golang"}}]}`
calls := ParseToolCalls(text, []string{"search_web"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "search_web" {
t.Fatalf("expected canonical tool name search_web, got %q", calls[0].Name)
}
}
func TestParseToolCallsAllowsPunctuationVariantToolName(t *testing.T) {
text := `{"tool_calls":[{"name":"read-file","input":{"path":"README.md"}}]}`
calls := ParseToolCalls(text, []string{"read_file"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "read_file" {
t.Fatalf("expected canonical tool name read_file, got %q", calls[0].Name)
}
}
func TestParseToolCallsSupportsClaudeXMLToolCall(t *testing.T) {
text := `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command><description>show cwd</description></parameters></tool_call>`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "bash" {
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
}
if calls[0].Input["command"] != "pwd" {
t.Fatalf("expected command argument, got %#v", calls[0].Input)
}
}
func TestParseToolCallsDetailedMarksXMLToolCallSyntax(t *testing.T) {
text := `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command></parameters></tool_call>`
res := ParseToolCallsDetailed(text, []string{"bash"})
if !res.SawToolCallSyntax {
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
}
if len(res.Calls) != 1 {
t.Fatalf("expected one parsed call, got %#v", res)
}
}
func TestParseToolCallsSupportsClaudeXMLJSONToolCall(t *testing.T) {
text := `<tool_call>{"tool":"Bash","params":{"command":"pwd","description":"show cwd"}}</tool_call>`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "bash" {
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
}
if calls[0].Input["command"] != "pwd" {
t.Fatalf("expected command argument, got %#v", calls[0].Input)
}
}
func TestParseToolCallsSupportsFunctionCallTagStyle(t *testing.T) {
text := `<function_call>Bash</function_call><function parameter name="command">ls -la</function parameter><function parameter name="description">list</function parameter>`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "bash" {
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
}
if calls[0].Input["command"] != "ls -la" {
t.Fatalf("expected command argument, got %#v", calls[0].Input)
}
}
func TestParseToolCallsSupportsAntmlFunctionCallStyle(t *testing.T) {
text := `<antml:function_calls><antml:function_call name="Bash">{"command":"pwd","description":"x"}</antml:function_call></antml:function_calls>`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "bash" {
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
}
if calls[0].Input["command"] != "pwd" {
t.Fatalf("expected command argument, got %#v", calls[0].Input)
}
}
func TestParseToolCallsSupportsAntmlArgumentStyle(t *testing.T) {
text := `<antml:function_calls><antml:function_call id="1" name="Bash"><antml:argument name="command">pwd</antml:argument><antml:argument name="description">x</antml:argument></antml:function_call></antml:function_calls>`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "bash" {
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
}
if calls[0].Input["command"] != "pwd" {
t.Fatalf("expected command argument, got %#v", calls[0].Input)
}
}
func TestParseToolCallsSupportsInvokeFunctionCallStyle(t *testing.T) {
text := `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter><parameter name="description">d</parameter></invoke></function_calls>`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "bash" {
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
}
if calls[0].Input["command"] != "pwd" {
t.Fatalf("expected command argument, got %#v", calls[0].Input)
}
}
func TestParseToolCallsSupportsNestedToolTagStyle(t *testing.T) {
text := `<tool_call><tool name="Bash"><command>pwd</command><description>show cwd</description></tool></tool_call>`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "bash" {
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
}
if calls[0].Input["command"] != "pwd" {
t.Fatalf("expected command argument, got %#v", calls[0].Input)
}
}
func TestParseToolCallsSupportsAntmlFunctionAttributeWithParametersTag(t *testing.T) {
text := `<antml:function_calls><antml:function_call id="x" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call></antml:function_calls>`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "bash" {
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
}
if calls[0].Input["command"] != "pwd" {
t.Fatalf("expected command argument, got %#v", calls[0].Input)
}
}
func TestParseToolCallsSupportsMultipleAntmlFunctionCalls(t *testing.T) {
text := `<antml:function_calls><antml:function_call id="1" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call><antml:function_call id="2" function="Read"><antml:parameters>{"file_path":"README.md"}</antml:parameters></antml:function_call></antml:function_calls>`
calls := ParseToolCalls(text, []string{"bash", "read"})
if len(calls) != 2 {
t.Fatalf("expected 2 calls, got %#v", calls)
}
if calls[0].Name != "bash" || calls[1].Name != "read" {
t.Fatalf("expected canonical names [bash read], got %#v", calls)
}
}
func TestParseToolCallsDoesNotAcceptMismatchedMarkupTags(t *testing.T) {
text := `<tool_call><name>read_file</function><arguments>{"path":"README.md"}</arguments></tool_call>`
calls := ParseToolCalls(text, []string{"read_file"})
if len(calls) != 0 {
t.Fatalf("expected mismatched tags to be rejected, got %#v", calls)
}
}
func TestRepairInvalidJSONBackslashes(t *testing.T) {
tests := []struct {
input string
expected string
}{
{`{"path": "C:\Users\name"}`, `{"path": "C:\\Users\name"}`},
{`{"cmd": "cd D:\git_codes"}`, `{"cmd": "cd D:\\git_codes"}`},
{`{"text": "line1\nline2"}`, `{"text": "line1\nline2"}`},
{`{"path": "D:\\back\\slash"}`, `{"path": "D:\\back\\slash"}`},
{`{"unicode": "\u2705"}`, `{"unicode": "\u2705"}`},
{`{"invalid_u": "\u123"}`, `{"invalid_u": "\\u123"}`},
}
for _, tt := range tests {
got := repairInvalidJSONBackslashes(tt.input)
if got != tt.expected {
t.Errorf("repairInvalidJSONBackslashes(%s) = %s; want %s", tt.input, got, tt.expected)
}
}
}
func TestRepairLooseJSON(t *testing.T) {
tests := []struct {
input string
expected string
}{
{`{tool_calls: [{"name": "search", "input": {"q": "go"}}]}`, `{"tool_calls": [{"name": "search", "input": {"q": "go"}}]}`},
{`{name: "search", input: {q: "go"}}`, `{"name": "search", "input": {"q": "go"}}`},
}
for _, tt := range tests {
got := RepairLooseJSON(tt.input)
if got != tt.expected {
t.Errorf("RepairLooseJSON(%s) = %s; want %s", tt.input, got, tt.expected)
}
}
}
func TestParseToolCallsWithUnquotedKeys(t *testing.T) {
text := `这里是列表:{tool_calls: [{"name": "todowrite", "input": {"todos": "test"}}]}`
availableTools := []string{"todowrite"}
parsed := ParseToolCalls(text, availableTools)
if len(parsed) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(parsed))
}
if parsed[0].Name != "todowrite" {
t.Errorf("expected tool todowrite, got %s", parsed[0].Name)
}
}
func TestParseToolCallsWithInvalidBackslashes(t *testing.T) {
// DeepSeek sometimes outputs Windows paths with single backslashes in JSON strings
// Note: using raw string to simulate what AI actually sends in the stream
text := `好的,执行以下命令:{"name": "execute_command", "input": "{\"command\": \"cd D:\git_codes && dir\"}"}`
availableTools := []string{"execute_command"}
parsed := ParseToolCalls(text, availableTools)
// If standard JSON fails, buildToolCallCandidates should still extract the object,
// and parseToolCallsPayload should repair it.
if len(parsed) != 1 {
// If it still fails, let's see why
candidates := buildToolCallCandidates(text)
t.Logf("Candidates: %v", candidates)
t.Fatalf("expected 1 tool call, got %d", len(parsed))
}
cmd, ok := parsed[0].Input["command"].(string)
if !ok {
t.Fatalf("expected command string in input, got %v", parsed[0].Input)
}
expected := "cd D:\\git_codes && dir"
if cmd != expected {
t.Errorf("expected command %q, got %q", expected, cmd)
}
}
func TestParseToolCallsWithDeepSeekHallucination(t *testing.T) {
// 模拟 DeepSeek 典型的幻觉输出:未加引号的键名 + 包含 Windows 路径的嵌套 JSON 字符串 + 漏掉列表的方括号
text := `检测到实施意图——实现经典算法。需在misc/目录创建Python文件。
关键约束:
1. Windows UTF-8编码处理
2. 必须用绝对路径导入
3. 禁止write覆盖已有文件misc/目录允许创建新文件)
将任务分解并委托:
- 研究8皇后算法模式并行探索
- 实现带可视化输出的解决方案unspecified-high
先创建todo列表追踪步骤。
{tool_calls: [{"name": "todowrite", "input": {"todos": {"content": "研究8皇后问题算法模式回溯法和输出格式", "status": "pending", "priority": "high"}, {"content": "在misc/目录创建8皇后Python脚本包含完整解决方案和可视化输出", "status": "pending", "priority": "high"}, {"content": "验证脚本正确性(运行测试)", "status": "pending", "priority": "medium"}}}]}`
availableTools := []string{"todowrite"}
parsed := ParseToolCalls(text, availableTools)
if len(parsed) != 1 {
cands := buildToolCallCandidates(text)
for i, c := range cands {
t.Logf("CAND %d: %s", i, c)
repaired := RepairLooseJSON(c)
t.Logf(" REPAIRED: %s", repaired)
}
t.Fatalf("expected 1 tool call, got %d. Candidates: %v", len(parsed), buildToolCallCandidates(text))
}
if parsed[0].Name != "todowrite" {
t.Errorf("expected tool name 'todowrite', got %q", parsed[0].Name)
}
todos, ok := parsed[0].Input["todos"].([]any)
if !ok {
t.Fatalf("expected 'todos' to be parsed as a list, got %T: %#v", parsed[0].Input["todos"], parsed[0].Input["todos"])
}
if len(todos) != 3 {
t.Errorf("expected 3 todo items, got %d", len(todos))
}
}
func TestParseToolCallsWithMixedWindowsPaths(t *testing.T) {
// 更复杂的案例:嵌套 JSON 字符串中的反斜杠未转义
text := `关键约束: 1. Windows UTF-8编码处理 2. 必须用绝对路径导入 D:\git_codes\ds2api\misc
{tool_calls: [{"name": "write_file", "input": "{\"path\": \"D:\\git_codes\\ds2api\\misc\\queens.py\", \"content\": \"print('hello')\"}"}]}`
availableTools := []string{"write_file"}
parsed := ParseToolCalls(text, availableTools)
if len(parsed) != 1 {
t.Fatalf("expected 1 tool call from mixed text with paths, got %d", len(parsed))
}
path, _ := parsed[0].Input["path"].(string)
// 在解析后的 Go map 中,反斜杠应该被还原
if !strings.Contains(path, "D:\\git_codes") && !strings.Contains(path, "D:/git_codes") {
t.Errorf("expected path to contain Windows style separators, got %q", path)
}
}
func TestParseToolCallInputRepairsControlCharsInPath(t *testing.T) {
in := `{"path":"D:\tmp\new\readme.txt","content":"line1\nline2"}`
parsed := parseToolCallInput(in)
path, ok := parsed["path"].(string)
if !ok {
t.Fatalf("expected path string in parsed input, got %#v", parsed["path"])
}
if path != `D:\tmp\new\readme.txt` {
t.Fatalf("expected repaired windows path, got %q", path)
}
content, ok := parsed["content"].(string)
if !ok {
t.Fatalf("expected content string in parsed input, got %#v", parsed["content"])
}
if content != "line1\nline2" {
t.Fatalf("expected non-path field to keep decoded escapes, got %q", content)
}
}
func TestRepairLooseJSONWithNestedObjects(t *testing.T) {
// 测试嵌套对象的修复DeepSeek 幻觉输出,每个元素内部包含嵌套 {}
// 注意:正则只支持单层嵌套,不支持更深层次的嵌套
tests := []struct {
name string
input string
expected string
}{
// 1. 单层嵌套对象(核心修复目标)
{
name: "单层嵌套 - 2个元素",
input: `"todos": {"content": "研究算法", "input": {"q": "8 queens"}}, {"content": "实现", "input": {"path": "queens.py"}}`,
expected: `"todos": [{"content": "研究算法", "input": {"q": "8 queens"}}, {"content": "实现", "input": {"path": "queens.py"}}]`,
},
// 2. 3个单层嵌套对象
{
name: "3个单层嵌套对象",
input: `"items": {"a": {"x":1}}, {"b": {"y":2}}, {"c": {"z":3}}`,
expected: `"items": [{"a": {"x":1}}, {"b": {"y":2}}, {"c": {"z":3}}]`,
},
// 3. 混合嵌套:有些字段是对象,有些是原始值
{
name: "混合嵌套 - 对象和原始值混合",
input: `"items": {"name": "test", "config": {"timeout": 30}}, {"name": "test2", "config": {"timeout": 60}}`,
expected: `"items": [{"name": "test", "config": {"timeout": 30}}, {"name": "test2", "config": {"timeout": 60}}]`,
},
// 4. 4个嵌套对象边界测试
{
name: "4个嵌套对象",
input: `"todos": {"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}`,
expected: `"todos": [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}]`,
},
// 5. DeepSeek 典型幻觉:无空格逗号分隔
{
name: "无空格逗号分隔",
input: `"results": {"name": "a"}, {"name": "b"}, {"name": "c"}`,
expected: `"results": [{"name": "a"}, {"name": "b"}, {"name": "c"}]`,
},
// 6. 嵌套数组(数组在对象内,不是深层嵌套)
{
name: "对象内包含数组",
input: `"data": {"items": [1,2,3]}, {"items": [4,5,6]}`,
expected: `"data": [{"items": [1,2,3]}, {"items": [4,5,6]}]`,
},
// 7. 真实的 DeepSeek 8皇后问题输出
{
name: "DeepSeek 8皇后真实输出",
input: `"todos": {"content": "研究8皇后算法", "status": "pending"}, {"content": "实现Python脚本", "status": "pending"}, {"content": "验证结果", "status": "pending"}`,
expected: `"todos": [{"content": "研究8皇后算法", "status": "pending"}, {"content": "实现Python脚本", "status": "pending"}, {"content": "验证结果", "status": "pending"}]`,
},
// 8. 简单无嵌套对象(回归测试)
{
name: "简单无嵌套对象",
input: `"items": {"a": 1}, {"b": 2}`,
expected: `"items": [{"a": 1}, {"b": 2}]`,
},
// 9. 更复杂的单层嵌套
{
name: "复杂单层嵌套",
input: `"functions": {"name": "execute", "input": {"command": "ls"}}, {"name": "read", "input": {"file": "a.txt"}}`,
expected: `"functions": [{"name": "execute", "input": {"command": "ls"}}, {"name": "read", "input": {"file": "a.txt"}}]`,
},
// 10. 5个嵌套对象
{
name: "5个嵌套对象",
input: `"tasks": {"id":1}, {"id":2}, {"id":3}, {"id":4}, {"id":5}`,
expected: `"tasks": [{"id":1}, {"id":2}, {"id":3}, {"id":4}, {"id":5}]`,
},
}
for _, tt := range tests {
got := RepairLooseJSON(tt.input)
if got != tt.expected {
t.Errorf("[%s] RepairLooseJSON with nested objects:\n input: %s\n got: %s\n expected: %s", tt.name, tt.input, got, tt.expected)
}
} }
} }

View File

@@ -0,0 +1,55 @@
package util
import (
"regexp"
"strings"
)
var textKVNamePattern = regexp.MustCompile(`(?is)function\.name:\s*([a-zA-Z0-9_\-.]+)`)
func parseTextKVToolCalls(text string) []ParsedToolCall {
var out []ParsedToolCall
matches := textKVNamePattern.FindAllStringSubmatchIndex(text, -1)
if len(matches) == 0 {
return nil
}
for i, match := range matches {
name := text[match[2]:match[3]]
offset := match[1]
endSearch := len(text)
if i+1 < len(matches) {
endSearch = matches[i+1][0]
}
searchArea := text[offset:endSearch]
argIdx := strings.Index(searchArea, "function.arguments:")
if argIdx < 0 {
continue
}
startIdx := offset + argIdx + len("function.arguments:")
braceIdx := strings.IndexByte(text[startIdx:endSearch], '{')
if braceIdx < 0 {
continue
}
actualStart := startIdx + braceIdx
objJson, _, ok := extractJSONObject(text, actualStart)
if !ok {
continue
}
input := parseToolCallInput(objJson)
out = append(out, ParsedToolCall{
Name: name,
Input: input,
})
}
if len(out) == 0 {
return nil
}
return out
}

Some files were not shown because too many files have changed in this diff Show More