mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-02 15:35:27 +08:00
Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c59a0b7799 | ||
|
|
bd72b91f27 | ||
|
|
9240f85246 | ||
|
|
ea4bd1e483 | ||
|
|
9e0de62707 | ||
|
|
128de290db | ||
|
|
286d266723 | ||
|
|
8aad1005b2 | ||
|
|
11b2f24fc2 | ||
|
|
d1f08cbb89 | ||
|
|
60e9d707d4 | ||
|
|
9b93badb57 | ||
|
|
892213071a | ||
|
|
5484d6e59d | ||
|
|
0ce3fd22a7 | ||
|
|
25e40cc3a6 | ||
|
|
af68d21095 | ||
|
|
1fafd25e86 | ||
|
|
5f8f28a943 | ||
|
|
94cf1bfcc7 | ||
|
|
13562cf521 | ||
|
|
75969e710d | ||
|
|
6c39c8e191 | ||
|
|
0e261ff0a0 | ||
|
|
fab326eca1 | ||
|
|
c033eceee7 | ||
|
|
a10e03ebe0 | ||
|
|
a6aa4a1839 |
7
API.md
7
API.md
@@ -284,6 +284,11 @@ data: [DONE]
|
||||
|
||||
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。
|
||||
|
||||
补充说明:
|
||||
|
||||
- **非代码块上下文**下,工具 JSON 即使与普通文本混合,也会按特征识别并产出可执行 tool call(前后普通文本仍可透传)。
|
||||
- Markdown fenced code block(例如 ```json ... ```)中的 `tool_calls` 仅视为示例文本,不会被执行。
|
||||
|
||||
---
|
||||
|
||||
### `GET /v1/models/{id}`
|
||||
@@ -301,7 +306,7 @@ OpenAI Responses 风格接口,兼容 `input` 或 `messages`。
|
||||
| `messages` | array | ❌ | 与 `input` 二选一 |
|
||||
| `instructions` | string | ❌ | 自动前置为 system 消息 |
|
||||
| `stream` | boolean | ❌ | 默认 `false` |
|
||||
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 |
|
||||
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略(含代码块示例豁免) |
|
||||
| `tool_choice` | string/object | ❌ | 支持 `auto`/`none`/`required` 与强制函数(`{"type":"function","name":"..."}`) |
|
||||
|
||||
**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。
|
||||
|
||||
@@ -106,6 +106,14 @@ flowchart LR
|
||||
可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。
|
||||
另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。
|
||||
|
||||
|
||||
#### Claude Code 接入避坑(实测)
|
||||
|
||||
- `ANTHROPIC_BASE_URL` 推荐直接指向 DS2API 根地址(例如 `http://127.0.0.1:5001`),Claude Code 会请求 `/v1/messages?beta=true`。
|
||||
- `ANTHROPIC_API_KEY` 需要与 `config.json` 中 `keys` 一致;建议同时保留常规 key 与 `sk-ant-*` 形态 key,兼容不同客户端校验习惯。
|
||||
- 若系统设置了代理,建议对 DS2API 地址配置 `NO_PROXY=127.0.0.1,localhost,<你的主机IP>`,避免本地回环请求被代理拦截。
|
||||
- 如遇“工具调用输出成文本、未执行”问题,请升级到包含 Claude 工具调用多格式解析(JSON/XML/ANTML/invoke)的版本。
|
||||
|
||||
### Gemini 接口
|
||||
|
||||
Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling(`functionDeclarations` → `functionCall` 输出)。
|
||||
|
||||
@@ -106,6 +106,14 @@ flowchart LR
|
||||
Override mapping via `claude_mapping` or `claude_model_mapping` in config.
|
||||
In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility.
|
||||
|
||||
|
||||
#### Claude Code integration pitfalls (validated)
|
||||
|
||||
- Set `ANTHROPIC_BASE_URL` to the DS2API root URL (for example `http://127.0.0.1:5001`). Claude Code sends requests to `/v1/messages?beta=true`.
|
||||
- `ANTHROPIC_API_KEY` must match an entry in `keys` from `config.json`. Keeping both a regular key and an `sk-ant-*` style key improves client compatibility.
|
||||
- If your environment has proxy variables, set `NO_PROXY=127.0.0.1,localhost,<your_host_ip>` for DS2API to avoid proxy interception of local traffic.
|
||||
- If tool calls are rendered as plain text and not executed, upgrade to a build that includes multi-format Claude tool-call parsing (JSON/XML/ANTML/invoke).
|
||||
|
||||
### Gemini Endpoint
|
||||
|
||||
The Gemini adapter maps model names to DeepSeek native models via `model_aliases` or built-in heuristics, supporting both `generateContent` and `streamGenerateContent` call patterns with full Tool Calling support (`functionDeclarations` → `functionCall` output).
|
||||
@@ -351,6 +359,7 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
|
||||
When `tools` is present in the request, DS2API performs anti-leak handling:
|
||||
|
||||
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
|
||||
- In non-code-block context, tool JSON may still be recognized even when mixed with normal prose; surrounding prose can remain as text output.
|
||||
2. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`)
|
||||
3. Tool names not declared in the `tools` schema are strictly rejected and will not be emitted as valid tool calls
|
||||
4. `responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream
|
||||
|
||||
@@ -51,7 +51,7 @@ DS2API 提供两个层级的测试:
|
||||
1. **Preflight 检查**:
|
||||
- `go test ./... -count=1`(单元测试)
|
||||
- `./tests/scripts/check-node-split-syntax.sh`(Node 拆分模块语法门禁)
|
||||
- `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js`(Node 流式拦截 + compat 单测)
|
||||
- `node --test`(如仓库存在 Node 单测文件时执行;当前默认以 Go 测试 + Node 语法门禁为主)
|
||||
- `npm run build --prefix webui`(WebUI 构建检查)
|
||||
|
||||
2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程
|
||||
|
||||
40
docs/toolcall-semantics.md
Normal file
40
docs/toolcall-semantics.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# Tool call parsing semantics (Go canonical spec)
|
||||
|
||||
This document defines the cross-runtime contract for `ParseToolCallsDetailed` / `parseToolCallsDetailed`.
|
||||
|
||||
## Output contract
|
||||
|
||||
- `calls`: accepted tool calls with normalized tool names.
|
||||
- `sawToolCallSyntax`: true when tool-call-like syntax is detected (`tool_calls`, `<tool_call>`, `<function_call>`, `<invoke>`) or a valid call is parsed.
|
||||
- `rejectedByPolicy`: true when parser extracted call syntax but all calls are rejected by allow-list policy.
|
||||
- `rejectedToolNames`: de-duplicated rejected tool names in first-seen order.
|
||||
|
||||
## Parse pipeline
|
||||
|
||||
1. Strip fenced code blocks for non-standalone parsing.
|
||||
2. Build candidates from:
|
||||
- full text,
|
||||
- fenced JSON snippets,
|
||||
- extracted JSON objects around `tool_calls`,
|
||||
- first `{` to last `}` object slice.
|
||||
3. Parse each candidate in order:
|
||||
- JSON payload parser (`tool_calls`, list, single call object),
|
||||
- markup parser (`<tool_call>`, `<function_call>`, `<invoke>`; supports attributes + nested fields).
|
||||
4. Stop at first candidate that yields at least one call.
|
||||
|
||||
## Name normalization policy
|
||||
|
||||
When matching parsed names against configured tools:
|
||||
|
||||
1. exact match,
|
||||
2. case-insensitive match,
|
||||
3. namespace tail match (`a.b.c` => `c`),
|
||||
4. loose alnum match (remove non `[a-z0-9]`, compare).
|
||||
|
||||
## Standalone mode
|
||||
|
||||
Standalone mode (`ParseStandaloneToolCallsDetailed`) parses the whole input directly (no candidate slicing), while still applying:
|
||||
|
||||
- example-context guard,
|
||||
- JSON then markup fallback,
|
||||
- the same allow-list normalization policy.
|
||||
@@ -315,3 +315,78 @@ func asString(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload string
|
||||
}{
|
||||
{name: "xml_tool_call", payload: `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command></parameters></tool_call>`},
|
||||
{name: "xml_json_tool_call", payload: `<tool_call>{"tool":"Bash","params":{"command":"pwd"}}</tool_call>`},
|
||||
{name: "nested_tool_tag_style", payload: `<tool_call><tool name="Bash"><command>pwd</command></tool></tool_call>`},
|
||||
{name: "function_tag_style", payload: `<function_call>Bash</function_call><function parameter name="command">pwd</function parameter>`},
|
||||
{name: "antml_argument_style", payload: `<antml:function_calls><antml:function_call id="1" name="Bash"><antml:argument name="command">pwd</antml:argument></antml:function_call></antml:function_calls>`},
|
||||
{name: "antml_function_attr_parameters", payload: `<antml:function_calls><antml:function_call id="1" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call></antml:function_calls>`},
|
||||
{name: "invoke_parameter_style", payload: `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></function_calls>`},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"`+strings.ReplaceAll(tc.payload, `"`, `\"`)+`"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"Bash"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" {
|
||||
foundToolUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundToolUse {
|
||||
t.Fatalf("expected tool_use block for format %s, body=%s", tc.name, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeDoesNotStopOnUnclosedFencedToolExample(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"Here is an example:\\n```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"Bash\\\",\\\"input\\\":{\\\"command\\\":\\\"pwd\\\"}}]}\"}",
|
||||
"data: {\"p\":\"response/content\",\"v\":\"\\n```\\nDo not execute it.\"}",
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "show example only"}}, false, false, []string{"Bash"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" {
|
||||
t.Fatalf("unexpected tool_use for fenced example, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
foundEndTurn := false
|
||||
for _, f := range findClaudeFrames(frames, "message_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["stop_reason"] == "end_turn" {
|
||||
foundEndTurn = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundEndTurn {
|
||||
t.Fatalf("expected stop_reason=end_turn, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,8 +125,11 @@ func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
|
||||
if !containsStr(prompt, "Search the web") {
|
||||
t.Fatalf("expected description in prompt")
|
||||
}
|
||||
if !containsStr(prompt, "tool_calls") {
|
||||
t.Fatalf("expected tool_calls instruction in prompt")
|
||||
if !containsStr(prompt, "tool_use") {
|
||||
t.Fatalf("expected tool_use instruction in prompt")
|
||||
}
|
||||
if containsStr(prompt, "tool_calls") {
|
||||
t.Fatalf("expected prompt to avoid tool_calls JSON instruction")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ func buildClaudeToolPrompt(tools []any) string {
|
||||
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
|
||||
}
|
||||
parts = append(parts,
|
||||
"When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
|
||||
"When you need a tool, respond with Claude-native tool use (tool_use) using the provided tool schema. Do not print tool-call JSON in text.",
|
||||
"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
|
||||
"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type claudeStreamRuntime struct {
|
||||
@@ -116,6 +117,18 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
||||
|
||||
s.text.WriteString(p.Text)
|
||||
if s.bufferToolContent {
|
||||
if hasUnclosedCodeFence(s.text.String()) {
|
||||
continue
|
||||
}
|
||||
detected := util.ParseToolCalls(s.text.String(), s.toolNames)
|
||||
if len(detected) > 0 {
|
||||
s.finalize("tool_use")
|
||||
return streamengine.ParsedDecision{
|
||||
ContentSeen: true,
|
||||
Stop: true,
|
||||
StopReason: streamengine.StopReason("tool_use_detected"),
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
s.closeThinkingBlock()
|
||||
@@ -144,3 +157,7 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
|
||||
func hasUnclosedCodeFence(text string) bool {
|
||||
return strings.Count(text, "```")%2 == 1
|
||||
}
|
||||
|
||||
@@ -99,7 +99,7 @@ func TestGeminiRoutesRegistered(t *testing.T) {
|
||||
|
||||
func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
|
||||
upstream := makeGeminiUpstreamResponse(
|
||||
`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
h := &Handler{
|
||||
@@ -143,6 +143,42 @@ func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
|
||||
upstream := makeGeminiUpstreamResponse(
|
||||
`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
h := &Handler{Store: testGeminiConfig{}, Auth: testGeminiAuth{}, DS: testGeminiDS{resp: upstream}}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
body := `{
|
||||
"contents":[{"role":"user","parts":[{"text":"call tool"}]}],
|
||||
"tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
|
||||
}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
candidates, _ := out["candidates"].([]any)
|
||||
c0, _ := candidates[0].(map[string]any)
|
||||
content, _ := c0["content"].(map[string]any)
|
||||
parts, _ := content["parts"].([]any)
|
||||
part0, _ := parts[0].(map[string]any)
|
||||
functionCall, _ := part0["functionCall"].(map[string]any)
|
||||
if functionCall["name"] != "eval_javascript" {
|
||||
t.Fatalf("expected functionCall name eval_javascript for mixed snippet, got %#v", functionCall)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamGenerateContentEmitsSSE(t *testing.T) {
|
||||
upstream := makeGeminiUpstreamResponse(
|
||||
`data: {"p":"response/content","v":"hello "}`,
|
||||
|
||||
@@ -513,8 +513,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -531,11 +531,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), `"tool_calls"`) {
|
||||
t.Fatalf("expected embedded tool json to remain text in strict mode, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -555,8 +552,8 @@ func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -573,11 +570,9 @@ func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
|
||||
if !strings.Contains(got, "我将调用工具。") {
|
||||
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected tool_calls example text preserved, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -596,8 +591,8 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T)
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -614,8 +609,45 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T)
|
||||
if !strings.Contains(got, "接下来我会继续说明。") {
|
||||
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected tool_calls example text preserved, got=%q", got)
|
||||
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamFencedToolCallSnippetRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "下面是调用示例:\n```json\n"),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\n仅示例,不要执行。"),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7f", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta for fenced snippet, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "```json") || !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected fenced tool snippet in content, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
@@ -640,8 +672,8 @@ func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -655,14 +687,11 @@ func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") || !strings.Contains(got, "{") {
|
||||
t.Fatalf("expected embedded tool json to remain in text, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(got, "后置正文C。") {
|
||||
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/prompt"
|
||||
)
|
||||
|
||||
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
|
||||
@@ -33,9 +34,9 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
||||
"role": "user",
|
||||
"content": formatToolResultForPrompt(msg),
|
||||
})
|
||||
case "user", "system":
|
||||
case "user", "system", "developer":
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": normalizeOpenAIContentForPrompt(msg["content"]),
|
||||
})
|
||||
default:
|
||||
@@ -47,7 +48,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
||||
role = "user"
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
})
|
||||
}
|
||||
@@ -77,7 +78,7 @@ func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) strin
|
||||
args = normalizeOpenAIArgumentsForPrompt(fn["arguments"])
|
||||
}
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
continue
|
||||
}
|
||||
if args == "" {
|
||||
args = normalizeOpenAIArgumentsForPrompt(call["arguments"])
|
||||
@@ -132,32 +133,7 @@ func formatToolResultForPrompt(msg map[string]any) string {
|
||||
}
|
||||
|
||||
func normalizeOpenAIContentForPrompt(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return x
|
||||
case []any:
|
||||
parts := make([]string, 0, len(x))
|
||||
for _, item := range x {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
t := strings.ToLower(strings.TrimSpace(asString(m["type"])))
|
||||
if t != "text" && t != "output_text" && t != "input_text" {
|
||||
continue
|
||||
}
|
||||
if text := asString(m["text"]); text != "" {
|
||||
parts = append(parts, text)
|
||||
continue
|
||||
}
|
||||
if text := asString(m["content"]); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
default:
|
||||
return marshalToPromptString(v)
|
||||
}
|
||||
return prompt.NormalizeContent(v)
|
||||
}
|
||||
|
||||
func normalizeOpenAIArgumentsForPrompt(v any) string {
|
||||
@@ -189,6 +165,14 @@ func marshalToPromptString(v any) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func normalizeOpenAIRoleForPrompt(role string) string {
|
||||
role = strings.ToLower(strings.TrimSpace(role))
|
||||
if role == "developer" {
|
||||
return "system"
|
||||
}
|
||||
return role
|
||||
}
|
||||
|
||||
func asString(v any) string {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
|
||||
@@ -193,3 +193,90 @@ func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t *
|
||||
t.Fatalf("expected original concatenated arguments in tool history, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsMissingNameAreDropped(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"id": "call_missing_name",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"arguments": `{"path":"README.MD"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 0 {
|
||||
t.Fatalf("expected nameless assistant tool_calls to be dropped, got %#v", normalized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_AssistantNilContentDoesNotInjectNullLiteral(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": nil,
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"id": "call_screenshot",
|
||||
"function": map[string]any{
|
||||
"name": "send_file_to_user",
|
||||
"arguments": `{"file_path":"/tmp/a.png"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if strings.Contains(content, "<|Assistant|>null") || strings.HasPrefix(strings.TrimSpace(content), "null") {
|
||||
t.Fatalf("unexpected null literal injected into assistant tool history: %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "function.name: send_file_to_user") {
|
||||
t.Fatalf("expected tool history block preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_DeveloperRoleMapsToSystem(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{"role": "developer", "content": "必须先走工具调用"},
|
||||
map[string]any{"role": "user", "content": "你好"},
|
||||
}
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 2 {
|
||||
t.Fatalf("expected 2 normalized messages, got %d", len(normalized))
|
||||
}
|
||||
if normalized[0]["role"] != "system" {
|
||||
t.Fatalf("expected developer role converted to system, got %#v", normalized[0]["role"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_AssistantArrayContentFallbackWhenTextEmpty(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "", "content": "工具说明文本"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if content != "工具说明文本" {
|
||||
t.Fatalf("expected content fallback text preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
}
|
||||
}
|
||||
@@ -51,7 +51,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
|
||||
role = "user"
|
||||
}
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
}
|
||||
case "function_call_output", "tool_result":
|
||||
|
||||
@@ -94,6 +94,16 @@ func (s *responsesStreamRuntime) closeMessageItem() {
|
||||
outputIndex := s.ensureMessageOutputIndex()
|
||||
text := s.visibleText.String()
|
||||
if s.messagePartAdded {
|
||||
s.sendEvent(
|
||||
"response.output_text.done",
|
||||
openaifmt.BuildResponsesTextDonePayload(
|
||||
s.responseID,
|
||||
itemID,
|
||||
outputIndex,
|
||||
0,
|
||||
text,
|
||||
),
|
||||
)
|
||||
s.sendEvent(
|
||||
"response.content_part.done",
|
||||
openaifmt.BuildResponsesContentPartDonePayload(
|
||||
|
||||
@@ -226,6 +226,40 @@ func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("hello") + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.output_text.done") {
|
||||
t.Fatalf("expected response.output_text.done payload, body=%s", body)
|
||||
}
|
||||
textDoneIdx := strings.Index(body, "event: response.output_text.done")
|
||||
partDoneIdx := strings.Index(body, "event: response.content_part.done")
|
||||
if textDoneIdx < 0 || partDoneIdx < 0 {
|
||||
t.Fatalf("expected output_text.done + content_part.done, body=%s", body)
|
||||
}
|
||||
if textDoneIdx > partDoneIdx {
|
||||
t.Fatalf("expected output_text.done before content_part.done, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
@@ -288,12 +322,8 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *te
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
|
||||
if len(addedPayloads) != 1 {
|
||||
t.Fatalf("expected only one message output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
|
||||
}
|
||||
item, _ := addedPayloads[0]["item"].(map[string]any)
|
||||
if asString(item["type"]) != "message" {
|
||||
t.Fatalf("expected only message output item in strict mode, got %#v", item)
|
||||
if len(addedPayloads) < 1 {
|
||||
t.Fatalf("expected at least one output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
|
||||
}
|
||||
|
||||
completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
@@ -302,15 +332,22 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *te
|
||||
}
|
||||
responseObj, _ := completedPayload["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
hasMessage := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if asString(m["type"]) == "message" {
|
||||
hasMessage = true
|
||||
}
|
||||
if asString(m["type"]) == "function_call" {
|
||||
t.Fatalf("did not expect function_call output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
}
|
||||
if !hasMessage {
|
||||
t.Fatalf("expected message output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
|
||||
@@ -15,19 +15,9 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
}
|
||||
events := make([]toolStreamEvent, 0, 2)
|
||||
if len(state.pendingToolCalls) > 0 {
|
||||
pending := state.pending.String()
|
||||
if strings.TrimSpace(pending) != "" {
|
||||
content := state.pendingToolRaw + pending
|
||||
state.pending.Reset()
|
||||
state.pendingToolRaw = ""
|
||||
state.pendingToolCalls = nil
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
} else {
|
||||
// Wait for either more non-whitespace content (demote to plain text)
|
||||
// or stream flush (promote to executable tool calls).
|
||||
return events
|
||||
}
|
||||
events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
|
||||
state.pendingToolRaw = ""
|
||||
state.pendingToolCalls = nil
|
||||
}
|
||||
|
||||
for {
|
||||
@@ -45,7 +35,14 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
if len(calls) > 0 {
|
||||
state.pendingToolRaw = captured
|
||||
if prefix != "" {
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if suffix != "" {
|
||||
state.pending.WriteString(suffix)
|
||||
}
|
||||
_ = captured
|
||||
state.pendingToolCalls = calls
|
||||
continue
|
||||
}
|
||||
@@ -211,11 +208,6 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
if insideCodeFence(state.recentTextTail + prefixPart) {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
// Strict mode: only standalone tool payloads are executable. If the
|
||||
// payload is wrapped by non-whitespace prose, keep it as plain text.
|
||||
if strings.TrimSpace(state.recentTextTail) != "" || strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "" {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames)
|
||||
if len(parsed.Calls) == 0 {
|
||||
if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {
|
||||
|
||||
@@ -73,22 +73,31 @@ func TestGoCompatToolcallFixtures(t *testing.T) {
|
||||
mustLoadJSON(t, fixturePath, &fixture)
|
||||
|
||||
var expected struct {
|
||||
Calls []util.ParsedToolCall `json:"calls"`
|
||||
Calls []util.ParsedToolCall `json:"calls"`
|
||||
SawToolCallSyntax bool `json:"sawToolCallSyntax"`
|
||||
RejectedByPolicy bool `json:"rejectedByPolicy"`
|
||||
RejectedToolNames []string `json:"rejectedToolNames"`
|
||||
}
|
||||
mustLoadJSON(t, expectedPath, &expected)
|
||||
|
||||
var got []util.ParsedToolCall
|
||||
var got util.ToolCallParseResult
|
||||
switch strings.ToLower(strings.TrimSpace(fixture.Mode)) {
|
||||
case "standalone":
|
||||
got = util.ParseStandaloneToolCalls(fixture.Text, fixture.ToolNames)
|
||||
got = util.ParseStandaloneToolCallsDetailed(fixture.Text, fixture.ToolNames)
|
||||
default:
|
||||
got = util.ParseToolCalls(fixture.Text, fixture.ToolNames)
|
||||
got = util.ParseToolCallsDetailed(fixture.Text, fixture.ToolNames)
|
||||
}
|
||||
if len(got) == 0 && len(expected.Calls) == 0 {
|
||||
continue
|
||||
if got.Calls == nil {
|
||||
got.Calls = []util.ParsedToolCall{}
|
||||
}
|
||||
if !reflect.DeepEqual(got, expected.Calls) {
|
||||
t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected.Calls)
|
||||
if got.RejectedToolNames == nil {
|
||||
got.RejectedToolNames = []string{}
|
||||
}
|
||||
if !reflect.DeepEqual(got.Calls, expected.Calls) ||
|
||||
got.SawToolCallSyntax != expected.SawToolCallSyntax ||
|
||||
got.RejectedByPolicy != expected.RejectedByPolicy ||
|
||||
!reflect.DeepEqual(got.RejectedToolNames, expected.RejectedToolNames) {
|
||||
t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,6 +71,19 @@ func BuildResponsesTextDeltaPayload(responseID, itemID string, outputIndex, cont
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func BuildResponsesTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_text.done",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"content_index": contentIndex,
|
||||
"text": text,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.reasoning.delta",
|
||||
|
||||
@@ -8,8 +8,11 @@ const {
|
||||
stripFencedCodeBlocks,
|
||||
buildToolCallCandidates,
|
||||
parseToolCallsPayload,
|
||||
parseMarkupToolCalls,
|
||||
} = require('./parse_payload');
|
||||
|
||||
const TOOL_NAME_LOOSE_PATTERN = /[^a-z0-9]+/g;
|
||||
|
||||
function extractToolNames(tools) {
|
||||
if (!Array.isArray(tools) || tools.length === 0) {
|
||||
return [];
|
||||
@@ -41,12 +44,15 @@ function parseToolCallsDetailed(text, toolNames) {
|
||||
if (!toStringSafe(sanitized)) {
|
||||
return result;
|
||||
}
|
||||
result.sawToolCallSyntax = sanitized.toLowerCase().includes('tool_calls');
|
||||
result.sawToolCallSyntax = looksLikeToolCallSyntax(sanitized);
|
||||
|
||||
const candidates = buildToolCallCandidates(sanitized);
|
||||
let parsed = [];
|
||||
for (const c of candidates) {
|
||||
parsed = parseToolCallsPayload(c);
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseMarkupToolCalls(c);
|
||||
}
|
||||
if (parsed.length > 0) {
|
||||
result.sawToolCallSyntax = true;
|
||||
break;
|
||||
@@ -73,15 +79,17 @@ function parseStandaloneToolCallsDetailed(text, toolNames) {
|
||||
if (!trimmed) {
|
||||
return result;
|
||||
}
|
||||
if (trimmed.includes('```')) {
|
||||
return result;
|
||||
}
|
||||
if (looksLikeToolExampleContext(trimmed)) {
|
||||
return result;
|
||||
}
|
||||
result.sawToolCallSyntax = trimmed.toLowerCase().includes('tool_calls');
|
||||
if (!trimmed.startsWith('{') && !trimmed.startsWith('[')) {
|
||||
return result;
|
||||
result.sawToolCallSyntax = looksLikeToolCallSyntax(trimmed);
|
||||
let parsed = parseToolCallsPayload(trimmed);
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseMarkupToolCalls(trimmed);
|
||||
}
|
||||
|
||||
const parsed = parseToolCallsPayload(trimmed);
|
||||
if (parsed.length === 0) {
|
||||
return result;
|
||||
}
|
||||
@@ -146,7 +154,7 @@ function filterToolCallsDetailed(parsed, toolNames) {
|
||||
if (allowed.has(tc.name)) {
|
||||
matchedName = tc.name;
|
||||
} else {
|
||||
matchedName = allowedCanonical.get(tc.name.toLowerCase()) || '';
|
||||
matchedName = resolveAllowedToolName(tc.name, allowed, allowedCanonical);
|
||||
}
|
||||
if (!matchedName) {
|
||||
if (!seenRejected.has(tc.name)) {
|
||||
@@ -163,6 +171,45 @@ function filterToolCallsDetailed(parsed, toolNames) {
|
||||
return { calls, rejectedToolNames: rejected };
|
||||
}
|
||||
|
||||
function resolveAllowedToolName(name, allowed, allowedCanonical) {
|
||||
const normalizedName = toStringSafe(name).trim();
|
||||
if (!normalizedName) {
|
||||
return '';
|
||||
}
|
||||
if (allowed.has(normalizedName)) {
|
||||
return normalizedName;
|
||||
}
|
||||
const lower = normalizedName.toLowerCase();
|
||||
if (allowedCanonical.has(lower)) {
|
||||
return allowedCanonical.get(lower);
|
||||
}
|
||||
const idx = lower.lastIndexOf('.');
|
||||
if (idx >= 0 && idx < lower.length - 1) {
|
||||
const tail = lower.slice(idx + 1);
|
||||
if (allowedCanonical.has(tail)) {
|
||||
return allowedCanonical.get(tail);
|
||||
}
|
||||
}
|
||||
const loose = lower.replace(TOOL_NAME_LOOSE_PATTERN, '');
|
||||
if (!loose) {
|
||||
return '';
|
||||
}
|
||||
for (const [candidateLower, canonical] of allowedCanonical.entries()) {
|
||||
if (candidateLower.replace(TOOL_NAME_LOOSE_PATTERN, '') === loose) {
|
||||
return canonical;
|
||||
}
|
||||
}
|
||||
return '';
|
||||
}
|
||||
|
||||
function looksLikeToolCallSyntax(text) {
|
||||
const lower = toStringSafe(text).toLowerCase();
|
||||
return lower.includes('tool_calls')
|
||||
|| lower.includes('<tool_call')
|
||||
|| lower.includes('<function_call')
|
||||
|| lower.includes('<invoke');
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
extractToolNames,
|
||||
parseToolCalls,
|
||||
|
||||
@@ -1,6 +1,23 @@
|
||||
'use strict';
|
||||
|
||||
const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
|
||||
const TOOL_CALL_MARKUP_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?(tool_call|function_call|invoke)\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
|
||||
const TOOL_CALL_MARKUP_SELFCLOSE_PATTERN = /<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)\/>/gi;
|
||||
const TOOL_CALL_MARKUP_KV_PATTERN = /<(?:[a-z0-9_:-]+:)?([a-z0-9_.-]+)\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
|
||||
const TOOL_CALL_MARKUP_ATTR_PATTERN = /(name|function|tool)\s*=\s*"([^"]+)"/i;
|
||||
const TOOL_CALL_MARKUP_NAME_PATTERNS = [
|
||||
/<(?:[a-z0-9_:-]+:)?name\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?name>/i,
|
||||
/<(?:[a-z0-9_:-]+:)?function\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?function>/i,
|
||||
];
|
||||
const TOOL_CALL_MARKUP_ARGS_PATTERNS = [
|
||||
/<(?:[a-z0-9_:-]+:)?input\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?input>/i,
|
||||
/<(?:[a-z0-9_:-]+:)?arguments\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?arguments>/i,
|
||||
/<(?:[a-z0-9_:-]+:)?argument\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?argument>/i,
|
||||
/<(?:[a-z0-9_:-]+:)?parameters\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameters>/i,
|
||||
/<(?:[a-z0-9_:-]+:)?parameter\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameter>/i,
|
||||
/<(?:[a-z0-9_:-]+:)?args\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?args>/i,
|
||||
/<(?:[a-z0-9_:-]+:)?params\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?params>/i,
|
||||
];
|
||||
|
||||
const {
|
||||
toStringSafe,
|
||||
@@ -103,6 +120,112 @@ function parseToolCallsPayload(payload) {
|
||||
return one ? [one] : [];
|
||||
}
|
||||
|
||||
function parseMarkupToolCalls(text) {
|
||||
const raw = toStringSafe(text).trim();
|
||||
if (!raw) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
for (const m of raw.matchAll(TOOL_CALL_MARKUP_BLOCK_PATTERN)) {
|
||||
const parsed = parseMarkupSingleToolCall(toStringSafe(m[2]).trim(), toStringSafe(m[3]).trim());
|
||||
if (parsed) {
|
||||
out.push(parsed);
|
||||
}
|
||||
}
|
||||
for (const m of raw.matchAll(TOOL_CALL_MARKUP_SELFCLOSE_PATTERN)) {
|
||||
const parsed = parseMarkupSingleToolCall(toStringSafe(m[1]).trim(), '');
|
||||
if (parsed) {
|
||||
out.push(parsed);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseMarkupSingleToolCall(attrs, inner) {
|
||||
const embedded = parseToolCallsPayload(inner);
|
||||
if (embedded.length > 0) {
|
||||
return embedded[0];
|
||||
}
|
||||
let name = '';
|
||||
const attrMatch = attrs.match(TOOL_CALL_MARKUP_ATTR_PATTERN);
|
||||
if (attrMatch && attrMatch[2]) {
|
||||
name = toStringSafe(attrMatch[2]).trim();
|
||||
}
|
||||
if (!name) {
|
||||
name = stripTagText(findMarkupTagValue(inner, TOOL_CALL_MARKUP_NAME_PATTERNS));
|
||||
}
|
||||
if (!name) {
|
||||
return null;
|
||||
}
|
||||
|
||||
let input = {};
|
||||
const argsRaw = findMarkupTagValue(inner, TOOL_CALL_MARKUP_ARGS_PATTERNS);
|
||||
if (argsRaw) {
|
||||
input = parseMarkupInput(argsRaw);
|
||||
} else {
|
||||
const kv = parseMarkupKVObject(inner);
|
||||
if (Object.keys(kv).length > 0) {
|
||||
input = kv;
|
||||
}
|
||||
}
|
||||
return { name, input };
|
||||
}
|
||||
|
||||
function parseMarkupInput(raw) {
|
||||
const s = toStringSafe(raw).trim();
|
||||
if (!s) {
|
||||
return {};
|
||||
}
|
||||
const parsed = parseToolCallInput(s);
|
||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed) && Object.keys(parsed).length > 0) {
|
||||
return parsed;
|
||||
}
|
||||
const kv = parseMarkupKVObject(s);
|
||||
if (Object.keys(kv).length > 0) {
|
||||
return kv;
|
||||
}
|
||||
return { _raw: stripTagText(s) };
|
||||
}
|
||||
|
||||
function parseMarkupKVObject(text) {
|
||||
const raw = toStringSafe(text).trim();
|
||||
if (!raw) {
|
||||
return {};
|
||||
}
|
||||
const out = {};
|
||||
for (const m of raw.matchAll(TOOL_CALL_MARKUP_KV_PATTERN)) {
|
||||
const key = toStringSafe(m[1]).trim();
|
||||
if (!key) {
|
||||
continue;
|
||||
}
|
||||
const valueRaw = stripTagText(m[2]);
|
||||
if (!valueRaw) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
out[key] = JSON.parse(valueRaw);
|
||||
} catch (_err) {
|
||||
out[key] = valueRaw;
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function stripTagText(text) {
|
||||
return toStringSafe(text).replace(/<[^>]+>/g, ' ').trim();
|
||||
}
|
||||
|
||||
function findMarkupTagValue(text, patterns) {
|
||||
const source = toStringSafe(text);
|
||||
for (const p of patterns) {
|
||||
const m = source.match(p);
|
||||
if (m && m[1]) {
|
||||
return toStringSafe(m[1]);
|
||||
}
|
||||
}
|
||||
return '';
|
||||
}
|
||||
|
||||
function parseToolCallList(v) {
|
||||
if (!Array.isArray(v)) {
|
||||
return [];
|
||||
@@ -193,4 +316,5 @@ module.exports = {
|
||||
stripFencedCodeBlocks,
|
||||
buildToolCallCandidates,
|
||||
parseToolCallsPayload,
|
||||
parseMarkupToolCalls,
|
||||
};
|
||||
|
||||
@@ -21,22 +21,14 @@ function processToolSieveChunk(state, chunk, toolNames) {
|
||||
}
|
||||
const events = [];
|
||||
|
||||
if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
|
||||
const pending = state.pending || '';
|
||||
if (pending.trim() !== '') {
|
||||
const content = (state.pendingToolRaw || '') + pending;
|
||||
state.pending = '';
|
||||
state.pendingToolRaw = '';
|
||||
state.pendingToolCalls = [];
|
||||
noteText(state, content);
|
||||
events.push({ type: 'text', text: content });
|
||||
} else {
|
||||
return events;
|
||||
}
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
|
||||
events.push({ type: 'tool_calls', calls: state.pendingToolCalls });
|
||||
state.pendingToolRaw = '';
|
||||
state.pendingToolCalls = [];
|
||||
continue;
|
||||
}
|
||||
if (state.capturing) {
|
||||
if (state.pending) {
|
||||
state.capture += state.pending;
|
||||
|
||||
@@ -51,6 +51,9 @@ func MessagesPrepare(messages []map[string]any) string {
|
||||
}
|
||||
|
||||
func NormalizeContent(v any) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return x
|
||||
@@ -64,11 +67,11 @@ func NormalizeContent(v any) string {
|
||||
typeStr, _ := m["type"].(string)
|
||||
typeStr = strings.ToLower(strings.TrimSpace(typeStr))
|
||||
if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" {
|
||||
if txt, ok := m["text"].(string); ok {
|
||||
if txt, ok := m["text"].(string); ok && txt != "" {
|
||||
parts = append(parts, txt)
|
||||
continue
|
||||
}
|
||||
if txt, ok := m["content"].(string); ok {
|
||||
if txt, ok := m["content"].(string); ok && txt != "" {
|
||||
parts = append(parts, txt)
|
||||
}
|
||||
}
|
||||
|
||||
32
internal/prompt/messages_test.go
Normal file
32
internal/prompt/messages_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package prompt
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeContentNilReturnsEmpty(t *testing.T) {
|
||||
if got := NormalizeContent(nil); got != "" {
|
||||
t.Fatalf("expected empty string for nil content, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPrepareNilContentNoNullLiteral(t *testing.T) {
|
||||
messages := []map[string]any{
|
||||
{"role": "assistant", "content": nil},
|
||||
{"role": "user", "content": "ok"},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if got == "" {
|
||||
t.Fatalf("expected non-empty output")
|
||||
}
|
||||
if got == "null" {
|
||||
t.Fatalf("expected no null literal output, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
|
||||
got := NormalizeContent([]any{
|
||||
map[string]any{"type": "text", "text": "", "content": "from-content"},
|
||||
})
|
||||
if got != "from-content" {
|
||||
t.Fatalf("expected fallback to content when text is empty, got %q", got)
|
||||
}
|
||||
}
|
||||
161
internal/util/toolcalls_markup.go
Normal file
161
internal/util/toolcalls_markup.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var toolCallMarkupTagNames = []string{"tool_call", "function_call", "invoke"}
|
||||
var toolCallMarkupTagPatternByName = map[string]*regexp.Regexp{
|
||||
"tool_call": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?tool_call\b([^>]*)>(.*?)</(?:[a-z0-9_:-]+:)?tool_call>`),
|
||||
"function_call": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function_call\b([^>]*)>(.*?)</(?:[a-z0-9_:-]+:)?function_call>`),
|
||||
"invoke": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)>(.*?)</(?:[a-z0-9_:-]+:)?invoke>`),
|
||||
}
|
||||
var toolCallMarkupSelfClosingPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)/>`)
|
||||
var toolCallMarkupKVPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?([a-z0-9_\-.]+)\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?([a-z0-9_\-.]+)>`)
|
||||
var toolCallMarkupAttrPattern = regexp.MustCompile(`(?is)(name|function|tool)\s*=\s*"([^"]+)"`)
|
||||
var anyTagPattern = regexp.MustCompile(`(?is)<[^>]+>`)
|
||||
var toolCallMarkupNameTagNames = []string{"name", "function"}
|
||||
var toolCallMarkupNamePatternByTag = map[string]*regexp.Regexp{
|
||||
"name": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?name\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?name>`),
|
||||
"function": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?function>`),
|
||||
}
|
||||
var toolCallMarkupArgsTagNames = []string{"input", "arguments", "argument", "parameters", "parameter", "args", "params"}
|
||||
var toolCallMarkupArgsPatternByTag = map[string]*regexp.Regexp{
|
||||
"input": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?input\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?input>`),
|
||||
"arguments": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?arguments\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?arguments>`),
|
||||
"argument": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?argument\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?argument>`),
|
||||
"parameters": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?parameters\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?parameters>`),
|
||||
"parameter": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?parameter\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?parameter>`),
|
||||
"args": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?args\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?args>`),
|
||||
"params": regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?params\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?params>`),
|
||||
}
|
||||
|
||||
func parseMarkupToolCalls(text string) []ParsedToolCall {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]ParsedToolCall, 0)
|
||||
for _, tagName := range toolCallMarkupTagNames {
|
||||
pattern := toolCallMarkupTagPatternByName[tagName]
|
||||
for _, m := range pattern.FindAllStringSubmatch(trimmed, -1) {
|
||||
if len(m) < 3 {
|
||||
continue
|
||||
}
|
||||
attrs := strings.TrimSpace(m[1])
|
||||
inner := strings.TrimSpace(m[2])
|
||||
if parsed := parseMarkupSingleToolCall(attrs, inner); parsed.Name != "" {
|
||||
out = append(out, parsed)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, m := range toolCallMarkupSelfClosingPattern.FindAllStringSubmatch(trimmed, -1) {
|
||||
if len(m) < 2 {
|
||||
continue
|
||||
}
|
||||
if parsed := parseMarkupSingleToolCall(strings.TrimSpace(m[1]), ""); parsed.Name != "" {
|
||||
out = append(out, parsed)
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall {
|
||||
if parsed := parseToolCallsPayload(inner); len(parsed) > 0 {
|
||||
return parsed[0]
|
||||
}
|
||||
|
||||
name := ""
|
||||
if m := toolCallMarkupAttrPattern.FindStringSubmatch(attrs); len(m) >= 3 {
|
||||
name = strings.TrimSpace(m[2])
|
||||
}
|
||||
if name == "" {
|
||||
name = findMarkupTagValue(inner, toolCallMarkupNameTagNames, toolCallMarkupNamePatternByTag)
|
||||
}
|
||||
if name == "" {
|
||||
return ParsedToolCall{}
|
||||
}
|
||||
|
||||
input := map[string]any{}
|
||||
if argsRaw := findMarkupTagValue(inner, toolCallMarkupArgsTagNames, toolCallMarkupArgsPatternByTag); argsRaw != "" {
|
||||
input = parseMarkupInput(argsRaw)
|
||||
} else if kv := parseMarkupKVObject(inner); len(kv) > 0 {
|
||||
input = kv
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}
|
||||
}
|
||||
|
||||
func parseMarkupInput(raw string) map[string]any {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
|
||||
return parsed
|
||||
}
|
||||
if kv := parseMarkupKVObject(raw); len(kv) > 0 {
|
||||
return kv
|
||||
}
|
||||
return map[string]any{"_raw": stripTagText(raw)}
|
||||
}
|
||||
|
||||
func parseMarkupKVObject(text string) map[string]any {
|
||||
matches := toolCallMarkupKVPattern.FindAllStringSubmatch(strings.TrimSpace(text), -1)
|
||||
if len(matches) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := map[string]any{}
|
||||
for _, m := range matches {
|
||||
if len(m) < 4 {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(m[1])
|
||||
endKey := strings.TrimSpace(m[3])
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.EqualFold(key, endKey) {
|
||||
continue
|
||||
}
|
||||
value := strings.TrimSpace(stripTagText(m[2]))
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
var jsonValue any
|
||||
if json.Unmarshal([]byte(value), &jsonValue) == nil {
|
||||
out[key] = jsonValue
|
||||
continue
|
||||
}
|
||||
out[key] = value
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func stripTagText(text string) string {
|
||||
return strings.TrimSpace(anyTagPattern.ReplaceAllString(text, ""))
|
||||
}
|
||||
|
||||
func findMarkupTagValue(text string, tagNames []string, patternByTag map[string]*regexp.Regexp) string {
|
||||
for _, tag := range tagNames {
|
||||
pattern := patternByTag[tag]
|
||||
if pattern == nil {
|
||||
continue
|
||||
}
|
||||
if m := pattern.FindStringSubmatch(text); len(m) >= 2 {
|
||||
value := strings.TrimSpace(m[1])
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -2,9 +2,12 @@ package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`)
|
||||
|
||||
type ParsedToolCall struct {
|
||||
Name string `json:"name"`
|
||||
Input map[string]any `json:"input"`
|
||||
@@ -30,19 +33,30 @@ func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallPa
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return result
|
||||
}
|
||||
result.SawToolCallSyntax = strings.Contains(strings.ToLower(text), "tool_calls")
|
||||
result.SawToolCallSyntax = looksLikeToolCallSyntax(text)
|
||||
|
||||
candidates := buildToolCallCandidates(text)
|
||||
var parsed []ParsedToolCall
|
||||
for _, candidate := range candidates {
|
||||
if tc := parseToolCallsPayload(candidate); len(tc) > 0 {
|
||||
tc := parseToolCallsPayload(candidate)
|
||||
if len(tc) == 0 {
|
||||
tc = parseXMLToolCalls(candidate)
|
||||
}
|
||||
if len(tc) == 0 {
|
||||
tc = parseMarkupToolCalls(candidate)
|
||||
}
|
||||
if len(tc) > 0 {
|
||||
parsed = tc
|
||||
result.SawToolCallSyntax = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(parsed) == 0 {
|
||||
return result
|
||||
parsed = parseXMLToolCalls(text)
|
||||
if len(parsed) == 0 {
|
||||
return result
|
||||
}
|
||||
result.SawToolCallSyntax = true
|
||||
}
|
||||
|
||||
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
||||
@@ -65,17 +79,21 @@ func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string)
|
||||
if looksLikeToolExampleContext(trimmed) {
|
||||
return result
|
||||
}
|
||||
result.SawToolCallSyntax = strings.Contains(strings.ToLower(trimmed), "tool_calls")
|
||||
result.SawToolCallSyntax = looksLikeToolCallSyntax(trimmed)
|
||||
candidates := []string{trimmed}
|
||||
for _, candidate := range candidates {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(candidate, "{") && !strings.HasPrefix(candidate, "[") {
|
||||
continue
|
||||
parsed := parseToolCallsPayload(candidate)
|
||||
if len(parsed) == 0 {
|
||||
parsed = parseXMLToolCalls(candidate)
|
||||
}
|
||||
if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 {
|
||||
if len(parsed) == 0 {
|
||||
parsed = parseMarkupToolCalls(candidate)
|
||||
}
|
||||
if len(parsed) > 0 {
|
||||
result.SawToolCallSyntax = true
|
||||
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
||||
result.Calls = calls
|
||||
@@ -103,32 +121,32 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
|
||||
}
|
||||
if len(allowed) == 0 {
|
||||
rejectedSet := map[string]struct{}{}
|
||||
rejected := make([]string, 0, len(parsed))
|
||||
for _, tc := range parsed {
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := rejectedSet[tc.Name]; ok {
|
||||
continue
|
||||
}
|
||||
rejectedSet[tc.Name] = struct{}{}
|
||||
}
|
||||
rejected := make([]string, 0, len(rejectedSet))
|
||||
for name := range rejectedSet {
|
||||
rejected = append(rejected, name)
|
||||
rejected = append(rejected, tc.Name)
|
||||
}
|
||||
return nil, rejected
|
||||
}
|
||||
out := make([]ParsedToolCall, 0, len(parsed))
|
||||
rejectedSet := map[string]struct{}{}
|
||||
rejected := make([]string, 0)
|
||||
for _, tc := range parsed {
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
matchedName := ""
|
||||
if _, ok := allowed[tc.Name]; ok {
|
||||
matchedName = tc.Name
|
||||
} else if canonical, ok := allowedCanonical[strings.ToLower(tc.Name)]; ok {
|
||||
matchedName = canonical
|
||||
}
|
||||
matchedName := resolveAllowedToolName(tc.Name, allowed, allowedCanonical)
|
||||
if matchedName == "" {
|
||||
rejectedSet[tc.Name] = struct{}{}
|
||||
if _, ok := rejectedSet[tc.Name]; !ok {
|
||||
rejectedSet[tc.Name] = struct{}{}
|
||||
rejected = append(rejected, tc.Name)
|
||||
}
|
||||
continue
|
||||
}
|
||||
tc.Name = matchedName
|
||||
@@ -137,13 +155,34 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
|
||||
}
|
||||
out = append(out, tc)
|
||||
}
|
||||
rejected := make([]string, 0, len(rejectedSet))
|
||||
for name := range rejectedSet {
|
||||
rejected = append(rejected, name)
|
||||
}
|
||||
return out, rejected
|
||||
}
|
||||
|
||||
func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
|
||||
if _, ok := allowed[name]; ok {
|
||||
return name
|
||||
}
|
||||
lower := strings.ToLower(strings.TrimSpace(name))
|
||||
if canonical, ok := allowedCanonical[lower]; ok {
|
||||
return canonical
|
||||
}
|
||||
if idx := strings.LastIndex(lower, "."); idx >= 0 && idx < len(lower)-1 {
|
||||
if canonical, ok := allowedCanonical[lower[idx+1:]]; ok {
|
||||
return canonical
|
||||
}
|
||||
}
|
||||
loose := toolNameLoosePattern.ReplaceAllString(lower, "")
|
||||
if loose == "" {
|
||||
return ""
|
||||
}
|
||||
for candidateLower, canonical := range allowedCanonical {
|
||||
if toolNameLoosePattern.ReplaceAllString(candidateLower, "") == loose {
|
||||
return canonical
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseToolCallsPayload(payload string) []ParsedToolCall {
|
||||
var decoded any
|
||||
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
|
||||
@@ -163,6 +202,14 @@ func parseToolCallsPayload(payload string) []ParsedToolCall {
|
||||
return nil
|
||||
}
|
||||
|
||||
func looksLikeToolCallSyntax(text string) bool {
|
||||
lower := strings.ToLower(text)
|
||||
return strings.Contains(lower, "tool_calls") ||
|
||||
strings.Contains(lower, "<tool_call") ||
|
||||
strings.Contains(lower, "<function_call") ||
|
||||
strings.Contains(lower, "<invoke")
|
||||
}
|
||||
|
||||
func parseToolCallList(v any) []ParsedToolCall {
|
||||
items, ok := v.([]any)
|
||||
if !ok {
|
||||
|
||||
235
internal/util/toolcalls_parse_markup.go
Normal file
235
internal/util/toolcalls_parse_markup.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var xmlToolCallPattern = regexp.MustCompile(`(?is)<tool_call>\s*(.*?)\s*</tool_call>`)
|
||||
var functionCallPattern = regexp.MustCompile(`(?is)<function_call>\s*([^<]+?)\s*</function_call>`)
|
||||
var functionParamPattern = regexp.MustCompile(`(?is)<function\s+parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</function\s+parameter>`)
|
||||
var antmlFunctionCallPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?function_call[^>]*(?:name|function)="([^"]+)"[^>]*>\s*(.*?)\s*</(?:[a-z0-9_]+:)?function_call>`)
|
||||
var antmlArgumentPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?argument\s+name="([^"]+)"\s*>\s*(.*?)\s*</(?:[a-z0-9_]+:)?argument>`)
|
||||
var antmlParametersPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?parameters\s*>\s*(\{.*?\})\s*</(?:[a-z0-9_]+:)?parameters>`)
|
||||
var invokeCallPattern = regexp.MustCompile(`(?is)<invoke\s+name="([^"]+)"\s*>(.*?)</invoke>`)
|
||||
var invokeParamPattern = regexp.MustCompile(`(?is)<parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</parameter>`)
|
||||
|
||||
func parseXMLToolCalls(text string) []ParsedToolCall {
|
||||
matches := xmlToolCallPattern.FindAllString(text, -1)
|
||||
out := make([]ParsedToolCall, 0, len(matches)+1)
|
||||
for _, block := range matches {
|
||||
call, ok := parseSingleXMLToolCall(block)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, call)
|
||||
}
|
||||
if len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
if call, ok := parseFunctionCallTagStyle(text); ok {
|
||||
return []ParsedToolCall{call}
|
||||
}
|
||||
if calls := parseAntmlFunctionCallStyles(text); len(calls) > 0 {
|
||||
return calls
|
||||
}
|
||||
if call, ok := parseInvokeFunctionCallStyle(text); ok {
|
||||
return []ParsedToolCall{call}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
inner := strings.TrimSpace(block)
|
||||
inner = strings.TrimPrefix(inner, "<tool_call>")
|
||||
inner = strings.TrimSuffix(inner, "</tool_call>")
|
||||
inner = strings.TrimSpace(inner)
|
||||
if strings.HasPrefix(inner, "{") {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(inner), &payload); err == nil {
|
||||
name := strings.TrimSpace(asString(payload["tool"]))
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(asString(payload["tool_name"]))
|
||||
}
|
||||
if name != "" {
|
||||
input := map[string]any{}
|
||||
if params, ok := payload["params"].(map[string]any); ok {
|
||||
input = params
|
||||
} else if params, ok := payload["parameters"].(map[string]any); ok {
|
||||
input = params
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dec := xml.NewDecoder(strings.NewReader(block))
|
||||
name := ""
|
||||
params := map[string]any{}
|
||||
inParams := false
|
||||
inTool := false
|
||||
for {
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
switch t := tok.(type) {
|
||||
case xml.StartElement:
|
||||
tag := strings.ToLower(t.Name.Local)
|
||||
switch tag {
|
||||
case "tool":
|
||||
inTool = true
|
||||
for _, attr := range t.Attr {
|
||||
if strings.EqualFold(strings.TrimSpace(attr.Name.Local), "name") && strings.TrimSpace(name) == "" {
|
||||
name = strings.TrimSpace(attr.Value)
|
||||
}
|
||||
}
|
||||
case "parameters":
|
||||
inParams = true
|
||||
case "tool_name", "name":
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
|
||||
name = strings.TrimSpace(v)
|
||||
}
|
||||
case "input", "arguments", "argument", "args", "params":
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
|
||||
if parsed := parseToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 {
|
||||
for k, vv := range parsed {
|
||||
params[k] = vv
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
if inParams || inTool {
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &t); err == nil {
|
||||
params[t.Name.Local] = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
case xml.EndElement:
|
||||
tag := strings.ToLower(t.Name.Local)
|
||||
if tag == "parameters" {
|
||||
inParams = false
|
||||
}
|
||||
if tag == "tool" {
|
||||
inTool = false
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true
|
||||
}
|
||||
|
||||
func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
|
||||
m := functionCallPattern.FindStringSubmatch(text)
|
||||
if len(m) < 2 {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
name := strings.TrimSpace(m[1])
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
input := map[string]any{}
|
||||
for _, pm := range functionParamPattern.FindAllStringSubmatch(text, -1) {
|
||||
if len(pm) < 3 {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(pm[1])
|
||||
val := strings.TrimSpace(pm[2])
|
||||
if key != "" {
|
||||
input[key] = val
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
|
||||
func parseAntmlFunctionCallStyles(text string) []ParsedToolCall {
|
||||
matches := antmlFunctionCallPattern.FindAllStringSubmatch(text, -1)
|
||||
if len(matches) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]ParsedToolCall, 0, len(matches))
|
||||
for _, m := range matches {
|
||||
if call, ok := parseSingleAntmlFunctionCallMatch(m); ok {
|
||||
out = append(out, call)
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
|
||||
if len(m) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
name := strings.TrimSpace(m[1])
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
body := strings.TrimSpace(m[2])
|
||||
input := map[string]any{}
|
||||
if strings.HasPrefix(body, "{") {
|
||||
if err := json.Unmarshal([]byte(body), &input); err == nil {
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
}
|
||||
if pm := antmlParametersPattern.FindStringSubmatch(body); len(pm) >= 2 {
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(pm[1])), &input); err == nil {
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
}
|
||||
for _, am := range antmlArgumentPattern.FindAllStringSubmatch(body, -1) {
|
||||
if len(am) < 3 {
|
||||
continue
|
||||
}
|
||||
k := strings.TrimSpace(am[1])
|
||||
v := strings.TrimSpace(am[2])
|
||||
if k != "" {
|
||||
input[k] = v
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
|
||||
func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
|
||||
m := invokeCallPattern.FindStringSubmatch(text)
|
||||
if len(m) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
name := strings.TrimSpace(m[1])
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
input := map[string]any{}
|
||||
for _, pm := range invokeParamPattern.FindAllStringSubmatch(m[2], -1) {
|
||||
if len(pm) < 3 {
|
||||
continue
|
||||
}
|
||||
k := strings.TrimSpace(pm[1])
|
||||
v := strings.TrimSpace(pm[2])
|
||||
if k != "" {
|
||||
input[k] = v
|
||||
}
|
||||
}
|
||||
if len(input) == 0 {
|
||||
if argsRaw := findMarkupTagValue(m[2], toolCallMarkupArgsTagNames, toolCallMarkupArgsPatternByTag); argsRaw != "" {
|
||||
input = parseMarkupInput(argsRaw)
|
||||
} else if kv := parseMarkupKVObject(m[2]); len(kv) > 0 {
|
||||
input = kv
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
|
||||
func asString(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
@@ -115,3 +115,167 @@ func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) {
|
||||
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsAllowsQualifiedToolName(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"mcp.search_web","input":{"q":"golang"}}]}`
|
||||
calls := ParseToolCalls(text, []string{"search_web"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "search_web" {
|
||||
t.Fatalf("expected canonical tool name search_web, got %q", calls[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsAllowsPunctuationVariantToolName(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"read-file","input":{"path":"README.md"}}]}`
|
||||
calls := ParseToolCalls(text, []string{"read_file"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "read_file" {
|
||||
t.Fatalf("expected canonical tool name read_file, got %q", calls[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsClaudeXMLToolCall(t *testing.T) {
|
||||
text := `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command><description>show cwd</description></parameters></tool_call>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "bash" {
|
||||
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["command"] != "pwd" {
|
||||
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsDetailedMarksXMLToolCallSyntax(t *testing.T) {
|
||||
text := `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command></parameters></tool_call>`
|
||||
res := ParseToolCallsDetailed(text, []string{"bash"})
|
||||
if !res.SawToolCallSyntax {
|
||||
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
|
||||
}
|
||||
if len(res.Calls) != 1 {
|
||||
t.Fatalf("expected one parsed call, got %#v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsClaudeXMLJSONToolCall(t *testing.T) {
|
||||
text := `<tool_call>{"tool":"Bash","params":{"command":"pwd","description":"show cwd"}}</tool_call>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "bash" {
|
||||
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["command"] != "pwd" {
|
||||
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsFunctionCallTagStyle(t *testing.T) {
|
||||
text := `<function_call>Bash</function_call><function parameter name="command">ls -la</function parameter><function parameter name="description">list</function parameter>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "bash" {
|
||||
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["command"] != "ls -la" {
|
||||
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsAntmlFunctionCallStyle(t *testing.T) {
|
||||
text := `<antml:function_calls><antml:function_call name="Bash">{"command":"pwd","description":"x"}</antml:function_call></antml:function_calls>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "bash" {
|
||||
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["command"] != "pwd" {
|
||||
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsAntmlArgumentStyle(t *testing.T) {
|
||||
text := `<antml:function_calls><antml:function_call id="1" name="Bash"><antml:argument name="command">pwd</antml:argument><antml:argument name="description">x</antml:argument></antml:function_call></antml:function_calls>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "bash" {
|
||||
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["command"] != "pwd" {
|
||||
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsInvokeFunctionCallStyle(t *testing.T) {
|
||||
text := `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter><parameter name="description">d</parameter></invoke></function_calls>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "bash" {
|
||||
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["command"] != "pwd" {
|
||||
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsNestedToolTagStyle(t *testing.T) {
|
||||
text := `<tool_call><tool name="Bash"><command>pwd</command><description>show cwd</description></tool></tool_call>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "bash" {
|
||||
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["command"] != "pwd" {
|
||||
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsAntmlFunctionAttributeWithParametersTag(t *testing.T) {
|
||||
text := `<antml:function_calls><antml:function_call id="x" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call></antml:function_calls>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "bash" {
|
||||
t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["command"] != "pwd" {
|
||||
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsMultipleAntmlFunctionCalls(t *testing.T) {
|
||||
text := `<antml:function_calls><antml:function_call id="1" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call><antml:function_call id="2" function="Read"><antml:parameters>{"file_path":"README.md"}</antml:parameters></antml:function_call></antml:function_calls>`
|
||||
calls := ParseToolCalls(text, []string{"bash", "read"})
|
||||
if len(calls) != 2 {
|
||||
t.Fatalf("expected 2 calls, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "bash" || calls[1].Name != "read" {
|
||||
t.Fatalf("expected canonical names [bash read], got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsDoesNotAcceptMismatchedMarkupTags(t *testing.T) {
|
||||
text := `<tool_call><name>read_file</function><arguments>{"path":"README.md"}</arguments></tool_call>`
|
||||
calls := ParseToolCalls(text, []string{"read_file"})
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected mismatched tags to be rejected, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
{
|
||||
"calls": []
|
||||
}
|
||||
"calls": [],
|
||||
"sawToolCallSyntax": true,
|
||||
"rejectedByPolicy": true,
|
||||
"rejectedToolNames": [
|
||||
"unknown_tool"
|
||||
]
|
||||
}
|
||||
@@ -6,5 +6,8 @@
|
||||
"path": "README.MD"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"sawToolCallSyntax": true,
|
||||
"rejectedByPolicy": false,
|
||||
"rejectedToolNames": []
|
||||
}
|
||||
@@ -1,3 +1,6 @@
|
||||
{
|
||||
"calls": []
|
||||
}
|
||||
"calls": [],
|
||||
"sawToolCallSyntax": false,
|
||||
"rejectedByPolicy": false,
|
||||
"rejectedToolNames": []
|
||||
}
|
||||
13
tests/compat/expected/toolcalls_function_call_tag.json
Normal file
13
tests/compat/expected/toolcalls_function_call_tag.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"calls": [
|
||||
{
|
||||
"name": "read_file",
|
||||
"input": {
|
||||
"path": "README.MD"
|
||||
}
|
||||
}
|
||||
],
|
||||
"sawToolCallSyntax": true,
|
||||
"rejectedByPolicy": false,
|
||||
"rejectedToolNames": []
|
||||
}
|
||||
13
tests/compat/expected/toolcalls_invoke_attr.json
Normal file
13
tests/compat/expected/toolcalls_invoke_attr.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"calls": [
|
||||
{
|
||||
"name": "read_file",
|
||||
"input": {
|
||||
"path": "README.MD"
|
||||
}
|
||||
}
|
||||
],
|
||||
"sawToolCallSyntax": true,
|
||||
"rejectedByPolicy": false,
|
||||
"rejectedToolNames": []
|
||||
}
|
||||
13
tests/compat/expected/toolcalls_loose_normalize.json
Normal file
13
tests/compat/expected/toolcalls_loose_normalize.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"calls": [
|
||||
{
|
||||
"name": "read_file",
|
||||
"input": {
|
||||
"path": "README.MD"
|
||||
}
|
||||
}
|
||||
],
|
||||
"sawToolCallSyntax": true,
|
||||
"rejectedByPolicy": false,
|
||||
"rejectedToolNames": []
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"calls": [
|
||||
{
|
||||
"name": "read_file",
|
||||
"input": {
|
||||
"path": "README.MD"
|
||||
}
|
||||
}
|
||||
],
|
||||
"sawToolCallSyntax": true,
|
||||
"rejectedByPolicy": false,
|
||||
"rejectedToolNames": []
|
||||
}
|
||||
@@ -1,3 +1,6 @@
|
||||
{
|
||||
"calls": []
|
||||
}
|
||||
"calls": [],
|
||||
"sawToolCallSyntax": false,
|
||||
"rejectedByPolicy": false,
|
||||
"rejectedToolNames": []
|
||||
}
|
||||
@@ -1,3 +1,6 @@
|
||||
{
|
||||
"calls": []
|
||||
}
|
||||
"calls": [],
|
||||
"sawToolCallSyntax": true,
|
||||
"rejectedByPolicy": false,
|
||||
"rejectedToolNames": []
|
||||
}
|
||||
@@ -6,5 +6,8 @@
|
||||
"path": "README.MD"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"sawToolCallSyntax": true,
|
||||
"rejectedByPolicy": false,
|
||||
"rejectedToolNames": []
|
||||
}
|
||||
@@ -1,3 +1,8 @@
|
||||
{
|
||||
"calls": []
|
||||
}
|
||||
"calls": [],
|
||||
"sawToolCallSyntax": true,
|
||||
"rejectedByPolicy": true,
|
||||
"rejectedToolNames": [
|
||||
"unknown_tool"
|
||||
]
|
||||
}
|
||||
13
tests/compat/expected/toolcalls_xml_tool_call.json
Normal file
13
tests/compat/expected/toolcalls_xml_tool_call.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"calls": [
|
||||
{
|
||||
"name": "read_file",
|
||||
"input": {
|
||||
"path": "README.MD"
|
||||
}
|
||||
}
|
||||
],
|
||||
"sawToolCallSyntax": true,
|
||||
"rejectedByPolicy": false,
|
||||
"rejectedToolNames": []
|
||||
}
|
||||
6
tests/compat/fixtures/toolcalls/function_call_tag.json
Normal file
6
tests/compat/fixtures/toolcalls/function_call_tag.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"text": "<function_call><function>read_file</function><parameters>{\"path\":\"README.MD\"}</parameters></function_call>",
|
||||
"tool_names": [
|
||||
"read_file"
|
||||
]
|
||||
}
|
||||
6
tests/compat/fixtures/toolcalls/invoke_attr.json
Normal file
6
tests/compat/fixtures/toolcalls/invoke_attr.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"text": "<invoke name=\"read_file\"><argument>{\"path\":\"README.MD\"}</argument></invoke>",
|
||||
"tool_names": [
|
||||
"read_file"
|
||||
]
|
||||
}
|
||||
6
tests/compat/fixtures/toolcalls/loose_normalize.json
Normal file
6
tests/compat/fixtures/toolcalls/loose_normalize.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"text": "{\"tool_calls\":[{\"name\":\"read-file\",\"input\":{\"path\":\"README.MD\"}}]}",
|
||||
"tool_names": [
|
||||
"read_file"
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"text": "{\"tool_calls\":[{\"name\":\"company.fs.read_file\",\"input\":{\"path\":\"README.MD\"}}]}",
|
||||
"tool_names": [
|
||||
"read_file"
|
||||
]
|
||||
}
|
||||
6
tests/compat/fixtures/toolcalls/xml_tool_call.json
Normal file
6
tests/compat/fixtures/toolcalls/xml_tool_call.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"text": "<tool_call><name>read_file</name><arguments>{\"path\":\"README.MD\"}</arguments></tool_call>",
|
||||
"tool_names": [
|
||||
"read_file"
|
||||
]
|
||||
}
|
||||
@@ -6,7 +6,7 @@ const fs = require('node:fs');
|
||||
const path = require('node:path');
|
||||
|
||||
const chatStream = require('../../api/chat-stream.js');
|
||||
const { parseToolCalls, parseStandaloneToolCalls } = require('../../internal/js/helpers/stream-tool-sieve.js');
|
||||
const { parseToolCallsDetailed, parseStandaloneToolCallsDetailed } = require('../../internal/js/helpers/stream-tool-sieve.js');
|
||||
|
||||
const { parseChunkForContent, estimateTokens } = chatStream.__test;
|
||||
|
||||
@@ -44,9 +44,12 @@ test('js compat: toolcall fixtures', () => {
|
||||
const fixture = readJSON(path.join(fixtureDir, file));
|
||||
const expected = readJSON(path.join(expectedDir, `toolcalls_${name}.json`));
|
||||
const mode = typeof fixture.mode === 'string' ? fixture.mode.trim().toLowerCase() : '';
|
||||
const parser = mode === 'standalone' ? parseStandaloneToolCalls : parseToolCalls;
|
||||
const parser = mode === 'standalone' ? parseStandaloneToolCallsDetailed : parseToolCallsDetailed;
|
||||
const got = parser(fixture.text, fixture.tool_names || []);
|
||||
assert.deepEqual(got, expected.calls, `${name}: calls mismatch`);
|
||||
assert.deepEqual(got.calls, expected.calls, `${name}: calls mismatch`);
|
||||
assert.equal(got.sawToolCallSyntax, expected.sawToolCallSyntax, `${name}: sawToolCallSyntax mismatch`);
|
||||
assert.equal(got.rejectedByPolicy, expected.rejectedByPolicy, `${name}: rejectedByPolicy mismatch`);
|
||||
assert.deepEqual(got.rejectedToolNames, expected.rejectedToolNames, `${name}: rejectedToolNames mismatch`);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -109,7 +109,23 @@ test('parseStandaloneToolCalls ignores fenced code block tool_call examples', ()
|
||||
assert.equal(calls.length, 0);
|
||||
});
|
||||
|
||||
test('sieve keeps late key convergence payload as plain text in strict mode', () => {
|
||||
|
||||
test('sieve emits tool_calls in the same chunk processing tick once payload is complete', () => {
|
||||
const state = createToolSieveState();
|
||||
const first = processToolSieveChunk(state, '{"', ['read_file']);
|
||||
const second = processToolSieveChunk(
|
||||
state,
|
||||
'tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}',
|
||||
['read_file'],
|
||||
);
|
||||
const firstCalls = first.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []);
|
||||
const secondCalls = second.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []);
|
||||
assert.equal(firstCalls.length, 0);
|
||||
assert.equal(secondCalls.length, 1);
|
||||
assert.equal(secondCalls[0].name, 'read_file');
|
||||
});
|
||||
|
||||
test('sieve emits tool_calls when late key convergence forms a complete payload', () => {
|
||||
const events = runSieve(
|
||||
[
|
||||
'{"',
|
||||
@@ -119,12 +135,11 @@ test('sieve keeps late key convergence payload as plain text in strict mode', ()
|
||||
['read_file'],
|
||||
);
|
||||
const leakedText = collectText(events);
|
||||
const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0);
|
||||
const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0);
|
||||
assert.equal(hasToolCall || hasToolDelta, false);
|
||||
assert.equal(leakedText.includes('{'), true);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), true);
|
||||
const finalCalls = events.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []);
|
||||
assert.equal(finalCalls.length, 1);
|
||||
assert.equal(finalCalls[0].name, 'read_file');
|
||||
assert.equal(leakedText.includes('后置正文C。'), true);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
});
|
||||
|
||||
test('sieve keeps embedded invalid tool-like json as normal text to avoid stream stalls', () => {
|
||||
@@ -249,3 +264,9 @@ test('formatOpenAIStreamToolCalls reuses ids with the same idStore', () => {
|
||||
assert.equal(second.length, 1);
|
||||
assert.equal(first[0].id, second[0].id);
|
||||
});
|
||||
|
||||
test('parseToolCalls rejects mismatched markup tags', () => {
|
||||
const payload = '<tool_call><name>read_file</function><arguments>{"path":"README.md"}</arguments></tool_call>';
|
||||
const calls = parseToolCalls(payload, ['read_file']);
|
||||
assert.equal(calls.length, 0);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user