mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-03 16:05:26 +08:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
27eb73d48b | ||
|
|
685b5011e4 | ||
|
|
15e9eb3639 | ||
|
|
f18e6b9b11 | ||
|
|
40ebc8e942 | ||
|
|
fa3e6d040d | ||
|
|
458e4469e5 | ||
|
|
72c8e7e9f9 | ||
|
|
b9c8e90d98 | ||
|
|
36fcba1280 | ||
|
|
801b5abce3 | ||
|
|
206c3d5479 | ||
|
|
b2903c35ed | ||
|
|
b26dc8b7de | ||
|
|
63271aea8c | ||
|
|
516da04bcd | ||
|
|
9f7b671e5e |
@@ -152,7 +152,8 @@ OpenAI Chat / Responses 在标准化后、current input file 之前,会默认
|
||||
|
||||
工具调用正例现在优先示范官方 DSML 风格:`<|DSML|tool_calls>` → `<|DSML|invoke name="...">` → `<|DSML|parameter name="...">`。
|
||||
兼容层仍接受旧式纯 `<tool_calls>` wrapper,但提示词会优先要求模型输出官方 DSML 标签,并强调不能只输出 closing wrapper 而漏掉 opening tag。需要注意:这是“兼容 DSML 外壳,内部仍以 XML 解析语义为准”,不是原生 DSML 全链路实现;DSML 标签会在解析入口归一化回现有 XML 标签后继续走同一套 parser。
|
||||
数组参数使用 `<item>...</item>` 子节点表示;当某个参数体只包含 item 子节点时,Go / Node 解析器会把它还原成数组,避免 `questions` / `options` 这类 schema 中要求 array 的参数被误解析成 `{ "item": ... }` 对象。若模型把完整结构化 XML fragment 误包进 CDATA,兼容层会在保护 `content` / `command` 等原文字段的前提下,尝试把非原文字段中的 CDATA XML fragment 还原成 object / array。
|
||||
数组参数使用 `<item>...</item>` 子节点表示;当某个参数体只包含 item 子节点时,Go / Node 解析器会把它还原成数组,避免 `questions` / `options` 这类 schema 中要求 array 的参数被误解析成 `{ "item": ... }` 对象。若模型把完整结构化 XML fragment 误包进 CDATA,兼容层会在保护 `content` / `command` 等原文字段的前提下,尝试把非原文字段中的 CDATA XML fragment 还原成 object / array。不过,如果 CDATA 只是单个平面的 XML/HTML 标签,例如 `<b>urgent</b>` 这种行内标记,兼容层会保留原始字符串,不会强行升成 object / array;只有明显表示结构的 CDATA 片段,例如多兄弟节点、嵌套子节点或 `item` 列表,才会触发结构化恢复。
|
||||
在 assistant 最终回包阶段,如果某个 tool 参数在声明 schema 中明确是 `string`,兼容层会在把解析后的 `tool_calls` / `function_call` 重新序列化成 OpenAI / Responses / Claude 可见参数前,递归把该路径上的 number / bool / object / array 统一转成字符串;其中 object / array 会压成紧凑 JSON 字符串。这个保护只对 schema 明确声明为 string 的路径生效,不会改写本来就是 `number` / `boolean` / `object` / `array` 的参数。这样可以兼容 DeepSeek 输出了结构化片段、但上游客户端工具 schema 又严格要求字符串参数的场景(例如 `content`、`prompt`、`path`、`taskId` 等)。
|
||||
正例中的工具名只会来自当前请求实际声明的工具;如果当前请求没有足够的已知工具形态,就省略对应的单工具、多工具或嵌套示例,避免把不可用工具名写进 prompt。
|
||||
对执行类工具,脚本内容必须进入执行参数本身:`Bash` / `execute_command` 使用 `command`,`exec_command` 使用 `cmd`;不要把脚本示范成 `path` / `content` 文件写入参数。
|
||||
|
||||
@@ -242,7 +243,7 @@ OpenAI 文件相关实现:
|
||||
|
||||
兼容层现在只保留 `current_input_file` 这一种拆分方式;旧的 `history_split` 已废弃,只保留为兼容旧配置的字段,不再参与请求处理。
|
||||
|
||||
- `current_input_file` 默认开启;它用于把“完整上下文”合并进隐藏上下文文件。当最新 user turn 的纯文本长度达到 `current_input_file.min_chars`(默认 `0`)时,兼容层会上传一个文件名为 `IGNORE.txt` 的上下文文件,并在文件内容前加入一个明确的 `context note`,提示模型这是被压缩过的历史记录而不是新指令;live prompt 也会显式说明当前处于 compacted-context mode,要求模型用已提供的历史来还原上下文状态并直接回答最新请求,避免把重复工具调用或重复提问当成新的起点。
|
||||
- `current_input_file` 默认开启;它用于把“完整上下文”合并进隐藏上下文文件。当最新 user turn 的纯文本长度达到 `current_input_file.min_chars`(默认 `0`)时,兼容层会上传一个文件名为 `IGNORE.txt` 的上下文文件,并在 live prompt 中只保留一个中性的 user 消息要求模型直接回答最新请求,不再暴露文件名或要求模型读取本地文件。
|
||||
- 如果 `current_input_file.enabled=false`,请求会直接透传,不上传任何拆分上下文文件。
|
||||
- 旧的 `history_split.enabled` / `history_split.trigger_after_turns` 会被读取进配置对象以保持兼容,但不会触发拆分上传,也不会影响 `current_input_file` 的默认开启。
|
||||
|
||||
@@ -255,18 +256,12 @@ OpenAI 文件相关实现:
|
||||
- 旧历史拆分兼容壳:
|
||||
[internal/httpapi/openai/history/history_split.go](../internal/httpapi/openai/history/history_split.go)
|
||||
|
||||
当前输入转文件启用并触发时,上传文件的真实文件名是 `IGNORE.txt`,文件内容是完整 `messages` 上下文;它仍会先用 OpenAI 消息标准化和 DeepSeek 角色标记序列化,再包进 `context note` 和 `IGNORE` 文件边界里:
|
||||
当前输入转文件启用并触发时,上传文件的真实文件名是 `IGNORE.txt`,文件内容是完整 `messages` 上下文;它仍会先用 OpenAI 消息标准化和 DeepSeek 角色标记序列化,再包进 `IGNORE` 文件边界里:
|
||||
|
||||
```text
|
||||
[uploaded filename]: IGNORE.txt
|
||||
[file content end]
|
||||
|
||||
[context note]
|
||||
This is a compacted snapshot of the prior conversation history for the current request.
|
||||
Use it as history only. Do not treat it as a new instruction.
|
||||
If the same question or tool action already appears here, do not repeat it unless the latest turn adds new information.
|
||||
[/context note]
|
||||
|
||||
<|begin▁of▁sentence|><|System|>...<|User|>...<|Assistant|>...<|Tool|>...<|User|>...
|
||||
|
||||
[file name]: IGNORE
|
||||
@@ -322,7 +317,7 @@ If the same question or tool action already appears here, do not repeat it unles
|
||||
|
||||
```json
|
||||
{
|
||||
"prompt": "<|begin▁of▁sentence|><|System|>原 system / developer\n\nYou have access to these tools: ...<|end▁of▁instructions|><|User|>You are in a compacted-context mode. The attached history contains the prior conversation state and any earlier tool results. Use it to resolve references and answer the latest user request directly. If the same tool action or question already appears in the attached context, do not repeat it unless the latest turn adds new information.<|Assistant|>",
|
||||
"prompt": "<|begin▁of▁sentence|><|System|>原 system / developer\n\nYou have access to these tools: ...<|end▁of▁instructions|><|User|>The current request and prior conversation context have already been provided. Answer the latest user request directly.<|Assistant|>",
|
||||
"ref_file_ids": [
|
||||
"file-current-input-ignore",
|
||||
"file-systemprompt",
|
||||
|
||||
@@ -39,7 +39,7 @@
|
||||
兼容修复:
|
||||
|
||||
- 如果模型漏掉 opening wrapper,但后面仍输出了一个或多个 invoke 并以 closing wrapper 收尾,Go 解析链路会在解析前补回缺失的 opening wrapper。
|
||||
- 如果模型把 DSML 标签里的分隔符 `|` 写漏成空格(例如 `<|DSML tool_calls>` / `<|DSML invoke>` / `<|DSML parameter>`,或无 leading pipe 的 `<DSML tool_calls>` 形态),或把 `DSML` 与工具标签名直接黏连(例如 `<DSMLtool_calls>` / `<DSMLinvoke>` / `<DSMLparameter>`),Go / Node 会在固定工具标签名范围内归一化;相似但非工具标签名(如 `tool_calls_extra`)仍按普通文本处理。
|
||||
- 如果模型把 DSML 标签里的分隔符 `|` 写漏成空格(例如 `<|DSML tool_calls>` / `<|DSML invoke>` / `<|DSML parameter>`,或无 leading pipe 的 `<DSML tool_calls>` 形态),或把 `DSML` 与工具标签名直接黏连(例如 `<DSMLtool_calls>` / `<DSMLinvoke>` / `<DSMLparameter>`),或把最前面的 pipe 误写成全宽竖线(例如 `<|DSML|tool_calls>` / `<|DSML|invoke>` / `<|DSML|parameter>`),Go / Node 会在固定工具标签名范围内归一化;相似但非工具标签名(如 `tool_calls_extra`)仍按普通文本处理。
|
||||
- 这是一个针对常见模型失误的窄修复,不改变推荐输出格式;prompt 仍要求模型直接输出完整 DSML 外壳。
|
||||
- 裸 `<invoke ...>` / `<parameter ...>` 不会被当成“已支持的工具语法”;只有 `tool_calls` wrapper 或可修复的缺失 opening wrapper 才会进入工具调用路径。
|
||||
|
||||
@@ -53,7 +53,7 @@
|
||||
|
||||
在流式链路中(Go / Node 一致):
|
||||
|
||||
- DSML `<|DSML|tool_calls>` wrapper、兼容变体(`<dsml|tool_calls>`、`<|tool_calls>`、`<|tool_calls>`)、窄容错空格分隔形态(如 `<|DSML tool_calls>`)、黏连形态(如 `<DSMLtool_calls>`)和 canonical `<tool_calls>` wrapper 都会进入结构化捕获
|
||||
- DSML `<|DSML|tool_calls>` wrapper、兼容变体(`<dsml|tool_calls>`、`<|tool_calls>`、`<|tool_calls>`、`<|DSML|tool_calls>`)、窄容错空格分隔形态(如 `<|DSML tool_calls>`)、黏连形态(如 `<DSMLtool_calls>`)和 canonical `<tool_calls>` wrapper 都会进入结构化捕获
|
||||
- 如果流里直接从 invoke 开始,但后面补上了 closing wrapper,Go 流式筛分也会按缺失 opening wrapper 的修复路径尝试恢复
|
||||
- 已识别成功的工具调用不会再次回流到普通文本
|
||||
- 不符合新格式的块不会执行,并继续按原样文本透传
|
||||
@@ -64,7 +64,7 @@
|
||||
|
||||
另外,`<parameter>` 的值如果本身是合法 JSON 字面量,也会按结构化值解析,而不是一律保留为字符串。例如 `123`、`true`、`null`、`[1,2]`、`{"a":1}` 都会还原成对应的 number / boolean / null / array / object。
|
||||
结构化 XML 参数也会还原为 JSON 结构:如果参数体只包含一个或多个 `<item>...</item>` 子节点,会输出数组;嵌套对象里的 item-only 字段也同样按数组处理。例如 `<parameter name="questions"><item><question>...</question></item></parameter>` 会输出 `{"questions":[{"question":"..."}]}`,而不是 `{"questions":{"item":...}}`。
|
||||
如果模型误把完整结构化 XML fragment 放进 CDATA,Go / Node 会先保护明显的原文字段(如 `content` / `command` / `prompt` / `old_string` / `new_string`),其余参数会尝试把 CDATA 内的完整 XML fragment 还原成 object / array;常见的 `<br>` 分隔符会按换行归一化后再解析。
|
||||
如果模型误把完整结构化 XML fragment 放进 CDATA,Go / Node 会先保护明显的原文字段(如 `content` / `command` / `prompt` / `old_string` / `new_string`),其余参数会尝试把 CDATA 内的完整 XML fragment 还原成 object / array;常见的 `<br>` 分隔符会按换行归一化后再解析。但如果 CDATA 只是单个平面的 XML/HTML 标签,例如 `<b>urgent</b>` 这种行内标记,兼容层会把它保留为原始字符串,而不会强行升成 object / array;只有明显表示结构的 CDATA 片段,例如多兄弟节点、嵌套子节点或 `item` 列表,才会触发结构化恢复。
|
||||
|
||||
## 4) 输出结构
|
||||
|
||||
|
||||
@@ -6,12 +6,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string, toolsRaw any) map[string]any {
|
||||
detected := toolcall.ParseAssistantToolCallsDetailed(finalText, finalThinking, toolNames)
|
||||
return BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls)
|
||||
return BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||
}
|
||||
|
||||
func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall) map[string]any {
|
||||
func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall, toolsRaw any) map[string]any {
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
@@ -19,7 +19,7 @@ func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThi
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected)
|
||||
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, toolsRaw)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,19 +9,19 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string, toolsRaw any) map[string]any {
|
||||
// Strict mode: only standalone, structured tool-call payloads are treated
|
||||
// as executable tool calls.
|
||||
detected := toolcall.ParseAssistantToolCallsDetailed(finalText, finalThinking, toolNames)
|
||||
return BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText, detected.Calls)
|
||||
return BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||
}
|
||||
|
||||
func BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall) map[string]any {
|
||||
func BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall, toolsRaw any) map[string]any {
|
||||
exposedOutputText := finalText
|
||||
output := make([]any, 0, 2)
|
||||
if len(detected) > 0 {
|
||||
exposedOutputText = ""
|
||||
output = append(output, toResponsesFunctionCallItems(detected)...)
|
||||
output = append(output, toResponsesFunctionCallItems(detected, toolsRaw)...)
|
||||
} else {
|
||||
content := make([]any, 0, 2)
|
||||
if finalThinking != "" {
|
||||
@@ -74,12 +74,13 @@ func BuildResponseObjectFromItems(responseID, model, finalPrompt, finalThinking,
|
||||
}
|
||||
}
|
||||
|
||||
func toResponsesFunctionCallItems(toolCalls []toolcall.ParsedToolCall) []any {
|
||||
func toResponsesFunctionCallItems(toolCalls []toolcall.ParsedToolCall, toolsRaw any) []any {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(toolCalls, toolsRaw)
|
||||
out := make([]any, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
for _, tc := range normalizedCalls {
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/toolcall"
|
||||
)
|
||||
|
||||
func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
|
||||
@@ -13,6 +16,7 @@ func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
|
||||
"",
|
||||
"```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```",
|
||||
[]string{"search"},
|
||||
nil,
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
@@ -42,6 +46,7 @@ func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
|
||||
"internal thinking content",
|
||||
"",
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
@@ -75,6 +80,7 @@ func TestBuildResponseObjectPromotesToolCallFromThinkingWhenTextEmpty(t *testing
|
||||
`<tool_calls><invoke name="search"><parameter name="q">from-thinking</parameter></invoke></tool_calls>`,
|
||||
"",
|
||||
[]string{"search"},
|
||||
nil,
|
||||
)
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
@@ -86,3 +92,88 @@ func TestBuildResponseObjectPromotesToolCallFromThinkingWhenTextEmpty(t *testing
|
||||
t.Fatalf("expected function_call output, got %#v", first["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildChatCompletionWithToolCallsCoercesSchemaDeclaredStringArguments(t *testing.T) {
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "Write",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
"taskId": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
obj := BuildChatCompletionWithToolCalls(
|
||||
"chat_test",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"",
|
||||
"",
|
||||
[]toolcall.ParsedToolCall{{
|
||||
Name: "Write",
|
||||
Input: map[string]any{
|
||||
"content": map[string]any{"message": "hi"},
|
||||
"taskId": 1,
|
||||
},
|
||||
}},
|
||||
toolsRaw,
|
||||
)
|
||||
choices, _ := obj["choices"].([]map[string]any)
|
||||
message, _ := choices[0]["message"].(map[string]any)
|
||||
toolCalls, _ := message["tool_calls"].([]map[string]any)
|
||||
fn, _ := toolCalls[0]["function"].(map[string]any)
|
||||
args := map[string]any{}
|
||||
if err := json.Unmarshal([]byte(fn["arguments"].(string)), &args); err != nil {
|
||||
t.Fatalf("decode arguments failed: %v", err)
|
||||
}
|
||||
if args["content"] != `{"message":"hi"}` {
|
||||
t.Fatalf("expected content stringified by schema, got %#v", args["content"])
|
||||
}
|
||||
if args["taskId"] != "1" {
|
||||
t.Fatalf("expected taskId stringified by schema, got %#v", args["taskId"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectWithToolCallsCoercesSchemaDeclaredStringArguments(t *testing.T) {
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "Write",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
obj := BuildResponseObjectWithToolCalls(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"",
|
||||
"",
|
||||
[]toolcall.ParsedToolCall{{
|
||||
Name: "Write",
|
||||
Input: map[string]any{"content": []any{"a", 1}},
|
||||
}},
|
||||
toolsRaw,
|
||||
)
|
||||
output, _ := obj["output"].([]any)
|
||||
first, _ := output[0].(map[string]any)
|
||||
args := map[string]any{}
|
||||
if err := json.Unmarshal([]byte(first["arguments"].(string)), &args); err != nil {
|
||||
t.Fatalf("decode response arguments failed: %v", err)
|
||||
}
|
||||
if args["content"] != `["a",1]` {
|
||||
t.Fatalf("expected response content stringified by schema, got %#v", args["content"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ func TestHandleStreamContextCancelledMarksHistoryStopped(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
resp := makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, `data: [DONE]`)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-stop", "deepseek-v4-flash", "prompt", false, false, nil, session)
|
||||
h.handleStream(rec, req, resp, "cid-stop", "deepseek-v4-flash", "prompt", false, false, nil, nil, session)
|
||||
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
@@ -317,9 +317,9 @@ func TestChatCompletionsCurrentInputFilePersistsNeutralPrompt(t *testing.T) {
|
||||
t.Fatalf("expected IGNORE.txt upload, got %q", ds.uploadCalls[0].Filename)
|
||||
}
|
||||
if len(full.Messages) != 1 {
|
||||
t.Fatalf("expected compacted-context prompt to be the only persisted message, got %#v", full.Messages)
|
||||
t.Fatalf("expected neutral prompt to be the only persisted message, got %#v", full.Messages)
|
||||
}
|
||||
if !strings.Contains(full.Messages[0].Content, promptcompat.BuildOpenAICurrentInputContextPrompt()) {
|
||||
t.Fatalf("expected compacted-context prompt to be persisted, got %#v", full.Messages[0])
|
||||
if !strings.Contains(full.Messages[0].Content, "Answer the latest user request directly.") {
|
||||
t.Fatalf("expected neutral prompt to be persisted, got %#v", full.Messages[0])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ type chatStreamRuntime struct {
|
||||
model string
|
||||
finalPrompt string
|
||||
toolNames []string
|
||||
toolsRaw any
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
@@ -61,6 +62,7 @@ func newChatStreamRuntime(
|
||||
searchEnabled bool,
|
||||
stripReferenceMarkers bool,
|
||||
toolNames []string,
|
||||
toolsRaw any,
|
||||
bufferToolContent bool,
|
||||
emitEarlyToolDeltas bool,
|
||||
) *chatStreamRuntime {
|
||||
@@ -73,6 +75,7 @@ func newChatStreamRuntime(
|
||||
model: model,
|
||||
finalPrompt: finalPrompt,
|
||||
toolNames: toolNames,
|
||||
toolsRaw: toolsRaw,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
stripReferenceMarkers: stripReferenceMarkers,
|
||||
@@ -142,7 +145,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
|
||||
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
|
||||
finishReason = "tool_calls"
|
||||
delta := map[string]any{
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs),
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs, s.toolsRaw),
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
@@ -164,7 +167,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
tcDelta["role"] = "assistant"
|
||||
@@ -320,7 +323,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
tcDelta["role"] = "assistant"
|
||||
|
||||
@@ -26,14 +26,14 @@ type chatNonStreamResult struct {
|
||||
responseMessageID int
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||
attempts := 0
|
||||
currentResp := resp
|
||||
usagePrompt := finalPrompt
|
||||
accumulatedThinking := ""
|
||||
accumulatedToolDetectionThinking := ""
|
||||
for {
|
||||
result, ok := h.collectChatNonStreamAttempt(w, currentResp, completionID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
result, ok := h.collectChatNonStreamAttempt(w, currentResp, completionID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
|
||||
result.toolDetectionThinking = accumulatedToolDetectionThinking
|
||||
detected := detectAssistantToolCalls(result.text, result.thinking, result.toolDetectionThinking, toolNames)
|
||||
result.detectedCalls = len(detected.Calls)
|
||||
result.body = openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, result.thinking, result.text, detected.Calls)
|
||||
result.body = openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, result.thinking, result.text, detected.Calls, toolsRaw)
|
||||
result.finishReason = chatFinishReason(result.body)
|
||||
if !shouldRetryChatNonStream(result, attempts) {
|
||||
h.finishChatNonStreamResult(w, result, attempts, usagePrompt, historySession)
|
||||
@@ -72,7 +72,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.Response, completionID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) (chatNonStreamResult, bool) {
|
||||
func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.Response, completionID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (chatNonStreamResult, bool) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -88,7 +88,7 @@ func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.
|
||||
finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks)
|
||||
}
|
||||
detected := detectAssistantToolCalls(finalText, finalThinking, finalToolDetectionThinking, toolNames)
|
||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls)
|
||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||
return chatNonStreamResult{
|
||||
thinking: finalThinking,
|
||||
toolDetectionThinking: finalToolDetectionThinking,
|
||||
@@ -139,8 +139,8 @@ func shouldRetryChatNonStream(result chatNonStreamResult, attempts int) bool {
|
||||
strings.TrimSpace(result.text) == ""
|
||||
}
|
||||
|
||||
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, historySession)
|
||||
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw, historySession)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -182,7 +182,7 @@ func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
|
||||
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -207,7 +207,7 @@ func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Res
|
||||
}
|
||||
streamRuntime := newChatStreamRuntime(
|
||||
w, rc, canFlush, completionID, time.Now().Unix(), model, finalPrompt,
|
||||
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames,
|
||||
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw,
|
||||
len(toolNames) > 0, h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
|
||||
)
|
||||
return streamRuntime, initialType, true
|
||||
|
||||
@@ -144,8 +144,8 @@ func filterIncrementalToolCallDeltasByAllowed(deltas []toolstream.ToolCallDelta,
|
||||
return shared.FilterIncrementalToolCallDeltasByAllowed(deltas, seenNames)
|
||||
}
|
||||
|
||||
func formatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string) []map[string]any {
|
||||
return shared.FormatFinalStreamToolCallsWithStableIDs(calls, ids)
|
||||
func formatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string, toolsRaw any) []map[string]any {
|
||||
return shared.FormatFinalStreamToolCallsWithStableIDs(calls, ids, toolsRaw)
|
||||
}
|
||||
|
||||
func detectAssistantToolCalls(text, exposedThinking, detectionThinking string, toolNames []string) toolcall.ToolCallParseResult {
|
||||
|
||||
@@ -109,10 +109,10 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if stdReq.Stream {
|
||||
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
|
||||
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
|
||||
return
|
||||
}
|
||||
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
|
||||
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
|
||||
}
|
||||
|
||||
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
|
||||
@@ -148,7 +148,7 @@ func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAu
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -176,7 +176,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
|
||||
writeUpstreamEmptyOutputError(w, finalText, finalThinking, result.ContentFilter)
|
||||
return
|
||||
}
|
||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls)
|
||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||
finishReason := "stop"
|
||||
if choices, ok := respBody["choices"].([]map[string]any); ok && len(choices) > 0 {
|
||||
if fr, _ := choices[0]["finish_reason"].(string); strings.TrimSpace(fr) != "" {
|
||||
@@ -189,7 +189,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -230,6 +230,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
searchEnabled,
|
||||
stripReferenceMarkers,
|
||||
toolNames,
|
||||
toolsRaw,
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
)
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, resp, "cid-empty", "deepseek-v4-flash", "prompt", false, false, nil, nil)
|
||||
h.handleNonStream(rec, resp, "cid-empty", "deepseek-v4-flash", "prompt", false, false, nil, nil, nil)
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected status 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -112,7 +112,7 @@ func TestHandleNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutp
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, resp, "cid-empty-filtered", "deepseek-v4-flash", "prompt", false, false, nil, nil)
|
||||
h.handleNonStream(rec, resp, "cid-empty-filtered", "deepseek-v4-flash", "prompt", false, false, nil, nil, nil)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400 for filtered upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -131,7 +131,7 @@ func TestHandleNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, resp, "cid-thinking-only", "deepseek-v4-pro", "prompt", true, false, nil, nil)
|
||||
h.handleNonStream(rec, resp, "cid-thinking-only", "deepseek-v4-pro", "prompt", true, false, nil, nil, nil)
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected status 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -150,7 +150,7 @@ func TestHandleNonStreamPromotesThinkingToolCallsWhenTextEmpty(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, resp, "cid-thinking-tool", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil)
|
||||
h.handleNonStream(rec, resp, "cid-thinking-tool", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil, nil)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -181,7 +181,7 @@ func TestHandleNonStreamPromotesHiddenThinkingDSMLToolCallsWhenTextEmpty(t *test
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, resp, "cid-hidden-thinking-tool", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil)
|
||||
h.handleNonStream(rec, resp, "cid-hidden-thinking-tool", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil, nil)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for hidden thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -211,7 +211,7 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid6", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil)
|
||||
h.handleStream(rec, req, resp, "cid6", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -248,7 +248,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid10", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil)
|
||||
h.handleStream(rec, req, resp, "cid10", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -282,7 +282,7 @@ func TestHandleStreamPromotesThinkingToolCallsOnFinalizeWithoutMidstreamIntercep
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-thinking-stream", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil)
|
||||
h.handleStream(rec, req, resp, "cid-thinking-stream", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -319,7 +319,7 @@ func TestHandleStreamPromotesHiddenThinkingDSMLToolCallsOnFinalize(t *testing.T)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-hidden-thinking-stream", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil)
|
||||
h.handleStream(rec, req, resp, "cid-hidden-thinking-stream", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -353,7 +353,7 @@ func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-multi", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, nil)
|
||||
h.handleStream(rec, req, resp, "cid-multi", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, nil, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -390,3 +390,64 @@ func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing
|
||||
t.Fatalf("expected distinct tool call ids across blocks, got %#v body=%s", ids, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamCoercesSchemaDeclaredStringArgumentsOnFinalize(t *testing.T) {
|
||||
h := &Handler{}
|
||||
line := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{"p": "response/content", "v": v})
|
||||
return "data: " + string(b)
|
||||
}
|
||||
resp := makeSSEHTTPResponse(
|
||||
line(`<tool_calls><invoke name="Write">{"input":{"content":{"message":"hi"},"taskId":1}}</invoke></tool_calls>`),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "Write",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
"taskId": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-string-protect", "deepseek-v4-flash", "prompt", false, false, []string{"Write"}, toolsRaw, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
if len(toolCalls) == 0 {
|
||||
continue
|
||||
}
|
||||
call, _ := toolCalls[0].(map[string]any)
|
||||
fn, _ := call["function"].(map[string]any)
|
||||
args := map[string]any{}
|
||||
if err := json.Unmarshal([]byte(asString(fn["arguments"])), &args); err != nil {
|
||||
t.Fatalf("decode streamed tool arguments failed: %v", err)
|
||||
}
|
||||
if args["content"] != `{"message":"hi"}` {
|
||||
t.Fatalf("expected streamed content stringified by schema, got %#v", args["content"])
|
||||
}
|
||||
if args["taskId"] != "1" {
|
||||
t.Fatalf("expected streamed taskId stringified by schema, got %#v", args["taskId"])
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected at least one streamed tool call delta, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"ds2api/internal/auth"
|
||||
dsclient "ds2api/internal/deepseek/client"
|
||||
"ds2api/internal/promptcompat"
|
||||
)
|
||||
|
||||
func TestIsVercelStreamPrepareRequest(t *testing.T) {
|
||||
@@ -131,8 +130,8 @@ func TestHandleVercelStreamPrepareAppliesCurrentInputFile(t *testing.T) {
|
||||
t.Fatalf("expected payload object, got %#v", body["payload"])
|
||||
}
|
||||
promptText, _ := payload["prompt"].(string)
|
||||
if !strings.Contains(promptText, promptcompat.BuildOpenAICurrentInputContextPrompt()) {
|
||||
t.Fatalf("expected compacted-context prompt, got %s", promptText)
|
||||
if !strings.Contains(promptText, "Answer the latest user request directly.") {
|
||||
t.Fatalf("expected neutral prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
|
||||
t.Fatalf("expected original turns hidden from prompt, got %s", promptText)
|
||||
|
||||
@@ -26,3 +26,31 @@ func TestReplaceCitationMarkersWithLinksKeepsUnknownIndex(t *testing.T) {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceCitationMarkersWithLinksSupportsReferenceMarker(t *testing.T) {
|
||||
raw := "新闻摘要[reference:1],详情[reference:2]。"
|
||||
links := map[int]string{
|
||||
1: "https://example.com/r1",
|
||||
2: "https://example.com/r2",
|
||||
}
|
||||
|
||||
got := replaceCitationMarkersWithLinks(raw, links)
|
||||
want := "新闻摘要[1](https://example.com/r1),详情[2](https://example.com/r2)。"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceCitationMarkersWithLinksSupportsReferenceZeroBased(t *testing.T) {
|
||||
raw := "来源[reference:0] 与 [reference:1]。"
|
||||
links := map[int]string{
|
||||
1: "https://example.com/first",
|
||||
2: "https://example.com/second",
|
||||
}
|
||||
|
||||
got := replaceCitationMarkersWithLinks(raw, links)
|
||||
want := "来源[0](https://example.com/first) 与 [1](https://example.com/second)。"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,5 +84,5 @@ func latestUserInputForFile(messages []any) (int, string) {
|
||||
}
|
||||
|
||||
func currentInputFilePrompt() string {
|
||||
return promptcompat.BuildOpenAICurrentInputContextPrompt()
|
||||
return "The current request and prior conversation context have already been provided. Answer the latest user request directly."
|
||||
}
|
||||
|
||||
@@ -67,9 +67,6 @@ func TestBuildOpenAICurrentInputContextTranscriptUsesInjectedFileWrapper(t *test
|
||||
if !strings.HasPrefix(transcript, "[file content end]\n\n") {
|
||||
t.Fatalf("expected injected file wrapper prefix, got %q", transcript)
|
||||
}
|
||||
if !strings.Contains(transcript, "[context note]") || !strings.Contains(transcript, "compacted snapshot of the prior conversation history") {
|
||||
t.Fatalf("expected compacted context note in transcript, got %q", transcript)
|
||||
}
|
||||
if !strings.Contains(transcript, "<|begin▁of▁sentence|>") {
|
||||
t.Fatalf("expected serialized conversation markers, got %q", transcript)
|
||||
}
|
||||
@@ -299,8 +296,8 @@ func TestApplyCurrentInputFileUploadsFirstTurnWithInjectedWrapper(t *testing.T)
|
||||
if strings.Contains(out.FinalPrompt, "CURRENT_USER_INPUT.txt") || strings.Contains(out.FinalPrompt, "IGNORE.txt") || strings.Contains(out.FinalPrompt, "Read that file") {
|
||||
t.Fatalf("expected live prompt not to instruct file reads, got %s", out.FinalPrompt)
|
||||
}
|
||||
if !strings.Contains(out.FinalPrompt, promptcompat.BuildOpenAICurrentInputContextPrompt()) {
|
||||
t.Fatalf("expected compacted-context instruction in live prompt, got %s", out.FinalPrompt)
|
||||
if !strings.Contains(out.FinalPrompt, "Answer the latest user request directly.") {
|
||||
t.Fatalf("expected neutral continuation instruction in live prompt, got %s", out.FinalPrompt)
|
||||
}
|
||||
if len(out.RefFileIDs) != 1 || out.RefFileIDs[0] != "file-inline-1" {
|
||||
t.Fatalf("expected current input file id in ref_file_ids, got %#v", out.RefFileIDs)
|
||||
@@ -348,10 +345,10 @@ func TestApplyCurrentInputFileUploadsFullContextFile(t *testing.T) {
|
||||
}
|
||||
}
|
||||
if strings.Contains(out.FinalPrompt, "first user turn") || strings.Contains(out.FinalPrompt, "latest user turn") || strings.Contains(out.FinalPrompt, "CURRENT_USER_INPUT.txt") || strings.Contains(out.FinalPrompt, "IGNORE.txt") || strings.Contains(out.FinalPrompt, "Read that file") {
|
||||
t.Fatalf("expected live prompt to stay in compacted-context mode, got %s", out.FinalPrompt)
|
||||
t.Fatalf("expected live prompt to use only a neutral continuation instruction, got %s", out.FinalPrompt)
|
||||
}
|
||||
if !strings.Contains(out.FinalPrompt, promptcompat.BuildOpenAICurrentInputContextPrompt()) {
|
||||
t.Fatalf("expected compacted-context instruction in live prompt, got %s", out.FinalPrompt)
|
||||
if !strings.Contains(out.FinalPrompt, "Answer the latest user request directly.") {
|
||||
t.Fatalf("expected neutral continuation instruction in live prompt, got %s", out.FinalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -431,8 +428,8 @@ func TestChatCompletionsCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *t
|
||||
t.Fatal("expected completion payload to be captured")
|
||||
}
|
||||
promptText, _ := ds.completionReq["prompt"].(string)
|
||||
if !strings.Contains(promptText, promptcompat.BuildOpenAICurrentInputContextPrompt()) {
|
||||
t.Fatalf("expected compacted-context prompt, got %s", promptText)
|
||||
if !strings.Contains(promptText, "Answer the latest user request directly.") {
|
||||
t.Fatalf("expected neutral completion prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
|
||||
t.Fatalf("expected prompt to hide original turns, got %s", promptText)
|
||||
@@ -477,8 +474,8 @@ func TestResponsesCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *testing
|
||||
t.Fatal("expected completion payload to be captured")
|
||||
}
|
||||
promptText, _ := ds.completionReq["prompt"].(string)
|
||||
if !strings.Contains(promptText, promptcompat.BuildOpenAICurrentInputContextPrompt()) {
|
||||
t.Fatalf("expected compacted-context prompt, got %s", promptText)
|
||||
if !strings.Contains(promptText, "Answer the latest user request directly.") {
|
||||
t.Fatalf("expected neutral completion prompt, got %s", promptText)
|
||||
}
|
||||
if strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
|
||||
t.Fatalf("expected prompt to hide original turns, got %s", promptText)
|
||||
@@ -613,7 +610,7 @@ func TestCurrentInputFileWorksAcrossAutoDeleteModes(t *testing.T) {
|
||||
t.Fatalf("expected completion payload for mode=%s", mode)
|
||||
}
|
||||
promptText, _ := ds.completionReq["prompt"].(string)
|
||||
if !strings.Contains(promptText, promptcompat.BuildOpenAICurrentInputContextPrompt()) || strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
|
||||
if !strings.Contains(promptText, "Answer the latest user request directly.") || strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
|
||||
t.Fatalf("unexpected prompt for mode=%s: %s", mode, promptText)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -27,14 +27,14 @@ type responsesNonStreamResult struct {
|
||||
responseMessageID int
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
attempts := 0
|
||||
currentResp := resp
|
||||
usagePrompt := finalPrompt
|
||||
accumulatedThinking := ""
|
||||
accumulatedToolDetectionThinking := ""
|
||||
for {
|
||||
result, ok := h.collectResponsesNonStreamAttempt(w, currentResp, responseID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
result, ok := h.collectResponsesNonStreamAttempt(w, currentResp, responseID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx c
|
||||
result.thinking = accumulatedThinking
|
||||
result.toolDetectionThinking = accumulatedToolDetectionThinking
|
||||
result.parsed = detectAssistantToolCalls(result.text, result.thinking, result.toolDetectionThinking, toolNames)
|
||||
result.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, result.thinking, result.text, result.parsed.Calls)
|
||||
result.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, result.thinking, result.text, result.parsed.Calls, toolsRaw)
|
||||
|
||||
if !shouldRetryResponsesNonStream(result, attempts) {
|
||||
h.finishResponsesNonStreamResult(w, result, attempts, owner, responseID, toolChoice, traceID)
|
||||
@@ -68,7 +68,7 @@ func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx c
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *http.Response, responseID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) (responsesNonStreamResult, bool) {
|
||||
func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *http.Response, responseID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (responsesNonStreamResult, bool) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -84,7 +84,7 @@ func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *
|
||||
sanitizedText = replaceCitationMarkersWithLinks(sanitizedText, result.CitationLinks)
|
||||
}
|
||||
textParsed := detectAssistantToolCalls(sanitizedText, sanitizedThinking, toolDetectionThinking, toolNames)
|
||||
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, sanitizedThinking, sanitizedText, textParsed.Calls)
|
||||
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, sanitizedThinking, sanitizedText, textParsed.Calls, toolsRaw)
|
||||
return responsesNonStreamResult{
|
||||
thinking: sanitizedThinking,
|
||||
toolDetectionThinking: toolDetectionThinking,
|
||||
@@ -123,8 +123,8 @@ func shouldRetryResponsesNonStream(result responsesNonStreamResult, attempts int
|
||||
strings.TrimSpace(result.text) == ""
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, resp, owner, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolChoice, traceID)
|
||||
func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, resp, owner, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, traceID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -165,7 +165,7 @@ func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) (*responsesStreamRuntime, string, bool) {
|
||||
func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) (*responsesStreamRuntime, string, bool) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -184,7 +184,7 @@ func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, resp *htt
|
||||
}
|
||||
streamRuntime := newResponsesStreamRuntime(
|
||||
w, rc, canFlush, responseID, model, finalPrompt, thinkingEnabled, searchEnabled,
|
||||
h.compatStripReferenceMarkers(), toolNames, len(toolNames) > 0,
|
||||
h.compatStripReferenceMarkers(), toolNames, toolsRaw, len(toolNames) > 0,
|
||||
h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
|
||||
toolChoice, traceID, func(obj map[string]any) {
|
||||
h.getResponseStore().put(owner, responseID, obj)
|
||||
|
||||
@@ -115,13 +115,13 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
if stdReq.Stream {
|
||||
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||
return
|
||||
}
|
||||
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -148,12 +148,12 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
return
|
||||
}
|
||||
|
||||
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, sanitizedThinking, sanitizedText, textParsed.Calls)
|
||||
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, sanitizedThinking, sanitizedText, textParsed.Calls, toolsRaw)
|
||||
h.getResponseStore().put(owner, responseID, responseObj)
|
||||
writeJSON(w, http.StatusOK, responseObj)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -186,6 +186,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
searchEnabled,
|
||||
stripReferenceMarkers,
|
||||
toolNames,
|
||||
toolsRaw,
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
toolChoice,
|
||||
|
||||
@@ -22,6 +22,7 @@ type responsesStreamRuntime struct {
|
||||
model string
|
||||
finalPrompt string
|
||||
toolNames []string
|
||||
toolsRaw any
|
||||
traceID string
|
||||
toolChoice promptcompat.ToolChoicePolicy
|
||||
|
||||
@@ -72,6 +73,7 @@ func newResponsesStreamRuntime(
|
||||
searchEnabled bool,
|
||||
stripReferenceMarkers bool,
|
||||
toolNames []string,
|
||||
toolsRaw any,
|
||||
bufferToolContent bool,
|
||||
emitEarlyToolDeltas bool,
|
||||
toolChoice promptcompat.ToolChoicePolicy,
|
||||
@@ -89,6 +91,7 @@ func newResponsesStreamRuntime(
|
||||
searchEnabled: searchEnabled,
|
||||
stripReferenceMarkers: stripReferenceMarkers,
|
||||
toolNames: toolNames,
|
||||
toolsRaw: toolsRaw,
|
||||
bufferToolContent: bufferToolContent,
|
||||
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
||||
streamToolCallIDs: map[int]string{},
|
||||
|
||||
@@ -220,7 +220,8 @@ func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolstream
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []toolcall.ParsedToolCall) {
|
||||
for idx, tc := range calls {
|
||||
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, s.toolsRaw)
|
||||
for idx, tc := range normalizedCalls {
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -109,7 +109,8 @@ func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, fin
|
||||
}
|
||||
}
|
||||
|
||||
for idx, tc := range calls {
|
||||
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, s.toolsRaw)
|
||||
for idx, tc := range normalizedCalls {
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T)
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.reasoning.delta") {
|
||||
@@ -57,7 +57,7 @@ func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testin
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.output_text.done") {
|
||||
t.Fatalf("expected response.output_text.done payload, body=%s", body)
|
||||
@@ -91,7 +91,7 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
|
||||
deltaPayload, ok := extractSSEEventPayload(body, "response.output_text.delta")
|
||||
@@ -130,7 +130,7 @@ func TestHandleResponsesStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
doneEvents := extractSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
@@ -183,7 +183,7 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
||||
Mode: promptcompat.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, nil, policy, "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
@@ -213,7 +213,7 @@ func TestHandleResponsesStreamFailsWhenUpstreamHasOnlyThinking(t *testing.T) {
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
@@ -251,7 +251,7 @@ func TestHandleResponsesStreamPromotesThinkingToolCallsOnFinalizeWithoutMidstrea
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.reasoning.delta") {
|
||||
@@ -288,7 +288,7 @@ func TestHandleResponsesStreamPromotesHiddenThinkingDSMLToolCallsOnFinalize(t *t
|
||||
Mode: promptcompat.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, nil, policy, "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "event: response.reasoning.delta") {
|
||||
@@ -317,7 +317,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, nil, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -344,7 +344,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayloadWhe
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", true, false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", true, false, []string{"read_file"}, nil, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -366,7 +366,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T)
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -388,7 +388,7 @@ func TestHandleResponsesNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWi
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for filtered empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -410,7 +410,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testin
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -432,7 +432,7 @@ func TestHandleResponsesNonStreamPromotesThinkingToolCallsWhenTextEmpty(t *testi
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -462,7 +462,7 @@ func TestHandleResponsesNonStreamPromotesHiddenThinkingDSMLToolCallsWhenTextEmpt
|
||||
Mode: promptcompat.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, nil, policy, "")
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for hidden thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -480,6 +480,53 @@ func TestHandleResponsesNonStreamPromotesHiddenThinkingDSMLToolCallsWhenTextEmpt
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamCoercesSchemaDeclaredStringArguments(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "Write",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
"taskId": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{"p": "response/content", "v": v})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
streamBody := sseLine(`<tool_calls><invoke name="Write">{"input":{"content":{"message":"hi"},"taskId":1}}</invoke></tool_calls>`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_string_protect", "deepseek-v4-flash", "prompt", false, false, []string{"Write"}, toolsRaw, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
|
||||
payload, ok := extractSSEEventPayload(rec.Body.String(), "response.function_call_arguments.done")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.function_call_arguments.done payload, body=%s", rec.Body.String())
|
||||
}
|
||||
args := map[string]any{}
|
||||
if err := json.Unmarshal([]byte(asString(payload["arguments"])), &args); err != nil {
|
||||
t.Fatalf("decode streamed response arguments failed: %v", err)
|
||||
}
|
||||
if args["content"] != `{"message":"hi"}` {
|
||||
t.Fatalf("expected response content stringified by schema, got %#v", args["content"])
|
||||
}
|
||||
if args["taskId"] != "1" {
|
||||
t.Fatalf("expected response taskId stringified by schema, got %#v", args["taskId"])
|
||||
}
|
||||
}
|
||||
|
||||
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
matched := false
|
||||
|
||||
@@ -7,22 +7,27 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var citationMarkerPattern = regexp.MustCompile(`(?i)\[citation:\s*(\d+)\]`)
|
||||
var citationMarkerPattern = regexp.MustCompile(`(?i)\[(citation|reference):\s*(\d+)\]`)
|
||||
|
||||
func ReplaceCitationMarkersWithLinks(text string, links map[int]string) string {
|
||||
if strings.TrimSpace(text) == "" || len(links) == 0 {
|
||||
return text
|
||||
}
|
||||
zeroBased := strings.Contains(strings.ToLower(text), "[reference:0]")
|
||||
return citationMarkerPattern.ReplaceAllStringFunc(text, func(match string) string {
|
||||
sub := citationMarkerPattern.FindStringSubmatch(match)
|
||||
if len(sub) < 2 {
|
||||
if len(sub) < 3 {
|
||||
return match
|
||||
}
|
||||
idx, err := strconv.Atoi(strings.TrimSpace(sub[1]))
|
||||
if err != nil || idx <= 0 {
|
||||
idx, err := strconv.Atoi(strings.TrimSpace(sub[2]))
|
||||
if err != nil || idx < 0 {
|
||||
return match
|
||||
}
|
||||
url := strings.TrimSpace(links[idx])
|
||||
lookupIdx := idx
|
||||
if zeroBased {
|
||||
lookupIdx = idx + 1
|
||||
}
|
||||
url := strings.TrimSpace(links[lookupIdx])
|
||||
if url == "" {
|
||||
return match
|
||||
}
|
||||
|
||||
@@ -70,12 +70,13 @@ func FilterIncrementalToolCallDeltasByAllowed(deltas []toolstream.ToolCallDelta,
|
||||
return out
|
||||
}
|
||||
|
||||
func FormatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string) []map[string]any {
|
||||
func FormatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string, toolsRaw any) []map[string]any {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for i, c := range calls {
|
||||
for i, c := range normalizedCalls {
|
||||
callID := ""
|
||||
if ids != nil {
|
||||
callID = strings.TrimSpace(ids[i])
|
||||
|
||||
@@ -530,6 +530,7 @@ function findPartialToolMarkupStart(text) {
|
||||
'<|tool_calls', '<|invoke', '<|parameter',
|
||||
'<|tool_calls', '<|invoke', '<|parameter',
|
||||
'<|dsml|tool_calls', '<|dsml|invoke', '<|dsml|parameter',
|
||||
'<|dsml|tool_calls', '<|dsml|invoke', '<|dsml|parameter',
|
||||
'<dsmltool_calls', '<dsmlinvoke', '<dsmlparameter',
|
||||
'<dsml tool_calls', '<dsml invoke', '<dsml parameter',
|
||||
'<dsml|tool_calls', '<dsml|invoke', '<dsml|parameter',
|
||||
@@ -812,6 +813,9 @@ function parseStructuredCDATAParameterValue(paramName, raw) {
|
||||
if (!normalized.includes('<') || !normalized.includes('>')) {
|
||||
return { ok: false, value: null };
|
||||
}
|
||||
if (!cdataFragmentLooksExplicitlyStructured(normalized)) {
|
||||
return { ok: false, value: null };
|
||||
}
|
||||
const parsed = parseMarkupInput(normalized);
|
||||
if (Array.isArray(parsed)) {
|
||||
return { ok: true, value: parsed };
|
||||
@@ -826,6 +830,21 @@ function normalizeCDATAForStructuredParse(raw) {
|
||||
return unescapeHtml(toStringSafe(raw).replace(/<br\s*\/?>/gi, '\n').trim());
|
||||
}
|
||||
|
||||
function cdataFragmentLooksExplicitlyStructured(raw) {
|
||||
const blocks = findGenericXmlElementBlocks(raw);
|
||||
if (blocks.length === 0) {
|
||||
return false;
|
||||
}
|
||||
if (blocks.length > 1) {
|
||||
return true;
|
||||
}
|
||||
const block = blocks[0];
|
||||
if (toStringSafe(block.localName).trim().toLowerCase() === 'item') {
|
||||
return true;
|
||||
}
|
||||
return findGenericXmlElementBlocks(block.body).length > 0;
|
||||
}
|
||||
|
||||
function preservesCDATAStringParameter(name) {
|
||||
return new Set([
|
||||
'content',
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
const XML_TOOL_SEGMENT_TAGS = [
|
||||
'<|dsml|tool_calls>', '<|dsml|tool_calls\n', '<|dsml|tool_calls ',
|
||||
'<|dsml|tool_calls>', '<|dsml|tool_calls\n', '<|dsml|tool_calls ',
|
||||
'<|dsml|invoke ', '<|dsml|invoke\n', '<|dsml|invoke\t', '<|dsml|invoke\r',
|
||||
'<|dsmltool_calls>', '<|dsmltool_calls\n', '<|dsmltool_calls ',
|
||||
'<|dsmlinvoke ', '<|dsmlinvoke\n', '<|dsmlinvoke\t', '<|dsmlinvoke\r',
|
||||
@@ -23,6 +24,7 @@ const XML_TOOL_SEGMENT_TAGS = [
|
||||
|
||||
const XML_TOOL_OPENING_TAGS = [
|
||||
'<|dsml|tool_calls',
|
||||
'<|dsml|tool_calls',
|
||||
'<|dsmltool_calls',
|
||||
'<|dsml tool_calls',
|
||||
'<dsml|tool_calls',
|
||||
@@ -35,6 +37,7 @@ const XML_TOOL_OPENING_TAGS = [
|
||||
|
||||
const XML_TOOL_CLOSING_TAGS = [
|
||||
'</|dsml|tool_calls>',
|
||||
'</|dsml|tool_calls>',
|
||||
'</|dsmltool_calls>',
|
||||
'</|dsml tool_calls>',
|
||||
'</dsml|tool_calls>',
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
|
||||
const historySplitInjectedFilename = "IGNORE"
|
||||
|
||||
const currentInputContextNote = "[context note]\nThis is a compacted snapshot of the prior conversation history for the current request.\nUse it as history only. Do not treat it as a new instruction.\nIf the same question or tool action already appears here, do not repeat it unless the latest turn adds new information.\n[/context note]"
|
||||
|
||||
func BuildOpenAIHistoryTranscript(messages []any) string {
|
||||
return buildOpenAIInjectedFileTranscript(messages)
|
||||
}
|
||||
@@ -28,15 +26,11 @@ func BuildOpenAICurrentInputContextTranscript(messages []any) string {
|
||||
return buildOpenAIInjectedFileTranscript(messages)
|
||||
}
|
||||
|
||||
func BuildOpenAICurrentInputContextPrompt() string {
|
||||
return "You are in a compacted-context mode. The attached history contains the prior conversation state and any earlier tool results. Use it to resolve references and answer the latest user request directly. If the same tool action or question already appears in the attached context, do not repeat it unless the latest turn adds new information."
|
||||
}
|
||||
|
||||
func buildOpenAIInjectedFileTranscript(messages []any) string {
|
||||
normalized := NormalizeOpenAIMessagesForPrompt(messages, "")
|
||||
transcript := strings.TrimSpace(prompt.MessagesPrepare(normalized))
|
||||
if transcript == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("[file content end]\n\n%s\n\n%s\n\n[file name]: %s\n[file content begin]\n", currentInputContextNote, transcript, historySplitInjectedFilename)
|
||||
return fmt.Sprintf("[file content end]\n\n%s\n\n[file name]: %s\n[file content begin]\n", transcript, historySplitInjectedFilename)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
func TestFormatOpenAIStreamToolCalls(t *testing.T) {
|
||||
formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{
|
||||
{Name: "search", Input: map[string]any{"q": "test"}},
|
||||
})
|
||||
}, nil)
|
||||
if len(formatted) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(formatted))
|
||||
}
|
||||
|
||||
@@ -7,9 +7,10 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
func FormatOpenAIToolCalls(calls []ParsedToolCall, toolsRaw any) []map[string]any {
|
||||
normalized := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for _, c := range calls {
|
||||
for _, c := range normalized {
|
||||
args, _ := json.Marshal(c.Input)
|
||||
out = append(out, map[string]any{
|
||||
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
@@ -23,9 +24,10 @@ func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
return out
|
||||
}
|
||||
|
||||
func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
func FormatOpenAIStreamToolCalls(calls []ParsedToolCall, toolsRaw any) []map[string]any {
|
||||
normalized := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for i, c := range calls {
|
||||
for i, c := range normalized {
|
||||
args, _ := json.Marshal(c.Input)
|
||||
out = append(out, map[string]any{
|
||||
"index": i,
|
||||
|
||||
@@ -2,6 +2,7 @@ package toolcall
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"html"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -350,6 +351,9 @@ func parseStructuredCDATAParameterValue(paramName, raw string) (any, bool) {
|
||||
if !strings.Contains(normalized, "<") || !strings.Contains(normalized, ">") {
|
||||
return nil, false
|
||||
}
|
||||
if !cdataFragmentLooksExplicitlyStructured(normalized) {
|
||||
return nil, false
|
||||
}
|
||||
parsed, ok := parseXMLFragmentValue(normalized)
|
||||
if !ok {
|
||||
return nil, false
|
||||
@@ -375,6 +379,65 @@ func normalizeCDATAForStructuredParse(raw string) string {
|
||||
return html.UnescapeString(strings.TrimSpace(normalized))
|
||||
}
|
||||
|
||||
// Preserve flat CDATA fragments as strings. Only recover structure when the
|
||||
// fragment clearly encodes a data shape: multiple sibling elements, nested
|
||||
// child elements, or an explicit item list.
|
||||
func cdataFragmentLooksExplicitlyStructured(raw string) bool {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
dec := xml.NewDecoder(strings.NewReader("<root>" + trimmed + "</root>"))
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
start, ok := tok.(xml.StartElement)
|
||||
if !ok || !strings.EqualFold(start.Name.Local, "root") {
|
||||
return false
|
||||
}
|
||||
|
||||
depth := 0
|
||||
directChildren := 0
|
||||
firstChildName := ""
|
||||
firstChildHasNested := false
|
||||
|
||||
for {
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
switch t := tok.(type) {
|
||||
case xml.StartElement:
|
||||
if depth == 0 {
|
||||
directChildren++
|
||||
if directChildren == 1 {
|
||||
firstChildName = strings.ToLower(strings.TrimSpace(t.Name.Local))
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
} else if directChildren == 1 && depth == 1 {
|
||||
firstChildHasNested = true
|
||||
}
|
||||
depth++
|
||||
case xml.EndElement:
|
||||
if strings.EqualFold(t.Name.Local, "root") {
|
||||
if directChildren != 1 {
|
||||
return false
|
||||
}
|
||||
if firstChildName == "item" {
|
||||
return true
|
||||
}
|
||||
return firstChildHasNested
|
||||
}
|
||||
if depth > 0 {
|
||||
depth--
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func preservesCDATAStringParameter(name string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(name)) {
|
||||
case "content", "file_content", "text", "prompt", "query", "command", "cmd", "script", "code", "old_string", "new_string", "pattern", "path", "file_path":
|
||||
|
||||
266
internal/toolcall/toolcalls_schema_normalize.go
Normal file
266
internal/toolcall/toolcalls_schema_normalize.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package toolcall
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func NormalizeParsedToolCallsForSchemas(calls []ParsedToolCall, toolsRaw any) []ParsedToolCall {
|
||||
if len(calls) == 0 {
|
||||
return calls
|
||||
}
|
||||
schemas := buildToolSchemaIndex(toolsRaw)
|
||||
if len(schemas) == 0 {
|
||||
return calls
|
||||
}
|
||||
|
||||
var changedAny bool
|
||||
out := make([]ParsedToolCall, len(calls))
|
||||
for i, call := range calls {
|
||||
out[i] = call
|
||||
schema, ok := schemas[strings.ToLower(strings.TrimSpace(call.Name))]
|
||||
if !ok || call.Input == nil {
|
||||
continue
|
||||
}
|
||||
normalized, changed := normalizeToolValueWithSchema(call.Input, schema)
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
changedAny = true
|
||||
if input, ok := normalized.(map[string]any); ok {
|
||||
out[i].Input = input
|
||||
}
|
||||
}
|
||||
if !changedAny {
|
||||
return calls
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildToolSchemaIndex(toolsRaw any) map[string]any {
|
||||
tools, ok := toolsRaw.([]any)
|
||||
if !ok || len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(tools))
|
||||
for _, item := range tools {
|
||||
tool, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, schema := extractToolNameAndSchema(tool)
|
||||
if name == "" || schema == nil {
|
||||
continue
|
||||
}
|
||||
out[strings.ToLower(name)] = schema
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func extractToolNameAndSchema(tool map[string]any) (string, any) {
|
||||
name := strings.TrimSpace(asStringValue(tool["name"]))
|
||||
schema := tool["parameters"]
|
||||
if schema == nil {
|
||||
schema = tool["input_schema"]
|
||||
}
|
||||
if fn, ok := tool["function"].(map[string]any); ok {
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(asStringValue(fn["name"]))
|
||||
}
|
||||
if schema == nil {
|
||||
schema = fn["parameters"]
|
||||
}
|
||||
if schema == nil {
|
||||
schema = fn["input_schema"]
|
||||
}
|
||||
}
|
||||
return name, schema
|
||||
}
|
||||
|
||||
func normalizeToolValueWithSchema(value any, schema any) (any, bool) {
|
||||
if value == nil || schema == nil {
|
||||
return value, false
|
||||
}
|
||||
schemaMap, ok := schema.(map[string]any)
|
||||
if !ok || len(schemaMap) == 0 {
|
||||
return value, false
|
||||
}
|
||||
if shouldCoerceSchemaToString(schemaMap) {
|
||||
return stringifySchemaValue(value)
|
||||
}
|
||||
if looksLikeObjectSchema(schemaMap) {
|
||||
obj, ok := value.(map[string]any)
|
||||
if !ok || len(obj) == 0 {
|
||||
return value, false
|
||||
}
|
||||
properties, _ := schemaMap["properties"].(map[string]any)
|
||||
additional := schemaMap["additionalProperties"]
|
||||
changed := false
|
||||
out := make(map[string]any, len(obj))
|
||||
for key, current := range obj {
|
||||
next := current
|
||||
var fieldChanged bool
|
||||
if propSchema, ok := properties[key]; ok {
|
||||
next, fieldChanged = normalizeToolValueWithSchema(current, propSchema)
|
||||
} else if additional != nil {
|
||||
next, fieldChanged = normalizeToolValueWithSchema(current, additional)
|
||||
}
|
||||
out[key] = next
|
||||
changed = changed || fieldChanged
|
||||
}
|
||||
if !changed {
|
||||
return value, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
if looksLikeArraySchema(schemaMap) {
|
||||
arr, ok := value.([]any)
|
||||
if !ok || len(arr) == 0 {
|
||||
return value, false
|
||||
}
|
||||
itemsSchema := schemaMap["items"]
|
||||
if itemsSchema == nil {
|
||||
return value, false
|
||||
}
|
||||
changed := false
|
||||
out := make([]any, len(arr))
|
||||
switch itemSchemas := itemsSchema.(type) {
|
||||
case []any:
|
||||
for i, item := range arr {
|
||||
if i >= len(itemSchemas) {
|
||||
out[i] = item
|
||||
continue
|
||||
}
|
||||
next, itemChanged := normalizeToolValueWithSchema(item, itemSchemas[i])
|
||||
out[i] = next
|
||||
changed = changed || itemChanged
|
||||
}
|
||||
default:
|
||||
for i, item := range arr {
|
||||
next, itemChanged := normalizeToolValueWithSchema(item, itemsSchema)
|
||||
out[i] = next
|
||||
changed = changed || itemChanged
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return value, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
return value, false
|
||||
}
|
||||
|
||||
func shouldCoerceSchemaToString(schema map[string]any) bool {
|
||||
if schema == nil {
|
||||
return false
|
||||
}
|
||||
if isStringConst(schema["const"]) {
|
||||
return true
|
||||
}
|
||||
if isStringEnum(schema["enum"]) {
|
||||
return true
|
||||
}
|
||||
switch v := schema["type"].(type) {
|
||||
case string:
|
||||
return strings.EqualFold(strings.TrimSpace(v), "string")
|
||||
case []any:
|
||||
return isOnlyStringLikeTypes(v)
|
||||
case []string:
|
||||
items := make([]any, 0, len(v))
|
||||
for _, item := range v {
|
||||
items = append(items, item)
|
||||
}
|
||||
return isOnlyStringLikeTypes(items)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func looksLikeObjectSchema(schema map[string]any) bool {
|
||||
if schema == nil {
|
||||
return false
|
||||
}
|
||||
if typ, ok := schema["type"].(string); ok && strings.EqualFold(strings.TrimSpace(typ), "object") {
|
||||
return true
|
||||
}
|
||||
if _, ok := schema["properties"].(map[string]any); ok {
|
||||
return true
|
||||
}
|
||||
_, hasAdditional := schema["additionalProperties"]
|
||||
return hasAdditional
|
||||
}
|
||||
|
||||
func looksLikeArraySchema(schema map[string]any) bool {
|
||||
if schema == nil {
|
||||
return false
|
||||
}
|
||||
if typ, ok := schema["type"].(string); ok && strings.EqualFold(strings.TrimSpace(typ), "array") {
|
||||
return true
|
||||
}
|
||||
_, hasItems := schema["items"]
|
||||
return hasItems
|
||||
}
|
||||
|
||||
func isOnlyStringLikeTypes(values []any) bool {
|
||||
if len(values) == 0 {
|
||||
return false
|
||||
}
|
||||
hasString := false
|
||||
for _, item := range values {
|
||||
typ, ok := item.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(typ)) {
|
||||
case "string":
|
||||
hasString = true
|
||||
case "null":
|
||||
continue
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return hasString
|
||||
}
|
||||
|
||||
func isStringConst(v any) bool {
|
||||
_, ok := v.(string)
|
||||
return ok
|
||||
}
|
||||
|
||||
func isStringEnum(v any) bool {
|
||||
values, ok := v.([]any)
|
||||
if !ok || len(values) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, item := range values {
|
||||
if _, ok := item.(string); !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func stringifySchemaValue(value any) (any, bool) {
|
||||
if value == nil {
|
||||
return value, false
|
||||
}
|
||||
if s, ok := value.(string); ok {
|
||||
return s, false
|
||||
}
|
||||
b, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return value, false
|
||||
}
|
||||
return string(b), true
|
||||
}
|
||||
|
||||
func asStringValue(v any) string {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
return ""
|
||||
}
|
||||
112
internal/toolcall/toolcalls_schema_normalize_test.go
Normal file
112
internal/toolcall/toolcalls_schema_normalize_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package toolcall
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeParsedToolCallsForSchemasCoercesDeclaredStringFieldsRecursively(t *testing.T) {
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "TaskUpdate",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"taskId": map[string]any{"type": "string"},
|
||||
"payload": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
"tags": map[string]any{
|
||||
"type": "array",
|
||||
"items": map[string]any{"type": "string"},
|
||||
},
|
||||
"count": map[string]any{"type": "number"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
calls := []ParsedToolCall{{
|
||||
Name: "TaskUpdate",
|
||||
Input: map[string]any{
|
||||
"taskId": 1,
|
||||
"payload": map[string]any{
|
||||
"content": map[string]any{"text": "hello"},
|
||||
"tags": []any{1, true, map[string]any{"k": "v"}},
|
||||
"count": 2,
|
||||
},
|
||||
},
|
||||
}}
|
||||
|
||||
got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized call, got %#v", got)
|
||||
}
|
||||
if got[0].Input["taskId"] != "1" {
|
||||
t.Fatalf("expected taskId coerced to string, got %#v", got[0].Input["taskId"])
|
||||
}
|
||||
payload, ok := got[0].Input["payload"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected payload object, got %#v", got[0].Input["payload"])
|
||||
}
|
||||
if payload["content"] != `{"text":"hello"}` {
|
||||
t.Fatalf("expected nested content coerced to json string, got %#v", payload["content"])
|
||||
}
|
||||
if payload["count"] != 2 {
|
||||
t.Fatalf("expected non-string count unchanged, got %#v", payload["count"])
|
||||
}
|
||||
tags, ok := payload["tags"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected tags slice, got %#v", payload["tags"])
|
||||
}
|
||||
wantTags := []any{"1", "true", `{"k":"v"}`}
|
||||
if !reflect.DeepEqual(tags, wantTags) {
|
||||
t.Fatalf("unexpected normalized tags: got %#v want %#v", tags, wantTags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeParsedToolCallsForSchemasSupportsDirectToolSchemaShape(t *testing.T) {
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"name": "Write",
|
||||
"input_schema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
calls := []ParsedToolCall{{Name: "Write", Input: map[string]any{"content": []any{"a", 1}}}}
|
||||
got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||
if got[0].Input["content"] != `["a",1]` {
|
||||
t.Fatalf("expected direct-schema content coerced to string, got %#v", got[0].Input["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeParsedToolCallsForSchemasLeavesAmbiguousUnionUnchanged(t *testing.T) {
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "TaskUpdate",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"taskId": map[string]any{"type": []any{"string", "integer"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
calls := []ParsedToolCall{{Name: "TaskUpdate", Input: map[string]any{"taskId": 1}}}
|
||||
got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||
if got[0].Input["taskId"] != 1 {
|
||||
t.Fatalf("expected ambiguous union to stay unchanged, got %#v", got[0].Input["taskId"])
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
func TestFormatOpenAIToolCalls(t *testing.T) {
|
||||
formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}})
|
||||
formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}}, nil)
|
||||
if len(formatted) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(formatted))
|
||||
}
|
||||
@@ -53,6 +53,21 @@ func TestParseToolCallsSupportsDSMLShellWithCanonicalExampleInCDATA(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsPreservesSimpleCDATAInlineMarkupAsText(t *testing.T) {
|
||||
text := `<tool_calls><invoke name="Write"><parameter name="description"><![CDATA[<b>urgent</b>]]></parameter></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"Write"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
got, ok := calls[0].Input["description"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected description to remain a string, got %#v", calls[0].Input["description"])
|
||||
}
|
||||
if got != "<b>urgent</b>" {
|
||||
t.Fatalf("expected inline markup CDATA to stay raw, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsTreatsUnclosedCDATAAsText(t *testing.T) {
|
||||
text := `<tool_calls><invoke name="Write"><parameter name="content"><![CDATA[hello world</parameter></invoke></tool_calls>`
|
||||
res := ParseToolCallsDetailed(text, []string{"Write"})
|
||||
@@ -218,6 +233,21 @@ func TestParseToolCallsTreatsCDATAItemOnlyBodyAsArray(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsTreatsSingleItemCDATAAsArray(t *testing.T) {
|
||||
text := `<tool_calls><invoke name="TodoWrite"><parameter name="todos"><![CDATA[<item>one</item>]]></parameter></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"TodoWrite"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected one TodoWrite call, got %#v", calls)
|
||||
}
|
||||
items, ok := calls[0].Input["todos"].([]any)
|
||||
if !ok || len(items) != 1 {
|
||||
t.Fatalf("expected single-item CDATA body to parse as array, got %#v", calls[0].Input["todos"])
|
||||
}
|
||||
if got, ok := items[0].(string); !ok || got != "one" {
|
||||
t.Fatalf("expected single item value to stay intact, got %#v", items[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsTreatsCDATAObjectFragmentAsObject(t *testing.T) {
|
||||
payload := `<question><![CDATA[Pick one]]></question><options><item><label><![CDATA[A]]></label></item><item><label><![CDATA[B]]></label></item></options>`
|
||||
text := `<tool_calls><invoke name="AskUserQuestion"><parameter name="questions"><![CDATA[` + payload + `]]></parameter></invoke></tool_calls>`
|
||||
|
||||
@@ -154,6 +154,7 @@ func findPartialXMLToolTagStart(s string) int {
|
||||
"<|tool_calls", "<|invoke", "<|parameter",
|
||||
"<|tool_calls", "<|invoke", "<|parameter",
|
||||
"<|dsml|tool_calls", "<|dsml|invoke", "<|dsml|parameter",
|
||||
"<|dsml|tool_calls", "<|dsml|invoke", "<|dsml|parameter",
|
||||
"<dsmltool_calls", "<dsmlinvoke", "<dsmlparameter",
|
||||
"<dsml tool_calls", "<dsml invoke", "<dsml parameter",
|
||||
"<dsml|tool_calls", "<dsml|invoke", "<dsml|parameter",
|
||||
|
||||
@@ -15,6 +15,7 @@ var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)((?:<tool_calls\b|<\|dsml
|
||||
// xmlToolTagsToDetect is the set of XML tag prefixes used by findToolSegmentStart.
|
||||
var xmlToolTagsToDetect = []string{
|
||||
"<|dsml|tool_calls>", "<|dsml|tool_calls\n", "<|dsml|tool_calls ",
|
||||
"<|dsml|tool_calls>", "<|dsml|tool_calls\n", "<|dsml|tool_calls ",
|
||||
"<|dsml|invoke ", "<|dsml|invoke\n", "<|dsml|invoke\t", "<|dsml|invoke\r",
|
||||
"<|dsmltool_calls>", "<|dsmltool_calls\n", "<|dsmltool_calls ",
|
||||
"<|dsmlinvoke ", "<|dsmlinvoke\n", "<|dsmlinvoke\t", "<|dsmlinvoke\r",
|
||||
|
||||
@@ -745,6 +745,51 @@ func TestProcessToolSieveFullwidthPipeVariantDoesNotLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Test <|DSML|tool_calls> with DSML invoke/parameter tags should buffer the
|
||||
// wrapper instead of leaking it before the block is complete.
|
||||
func TestProcessToolSieveFullwidthDSMLPrefixVariantDoesNotLeak(t *testing.T) {
|
||||
var state State
|
||||
chunks := []string{
|
||||
"<|DSML|tool",
|
||||
"_calls>\n",
|
||||
"<|DSML|invoke name=\"Bash\">\n",
|
||||
"<|DSML|parameter name=\"command\"><![CDATA[ls -la /Users/aq/Desktop/myproject/ds2api/]]></|DSML|parameter>\n",
|
||||
"<|DSML|parameter name=\"description\"><![CDATA[List project root contents]]></|DSML|parameter>\n",
|
||||
"</|DSML|invoke>\n",
|
||||
"<|DSML|invoke name=\"Bash\">\n",
|
||||
"<|DSML|parameter name=\"command\"><![CDATA[cat /Users/aq/Desktop/myproject/ds2api/package.json 2>/dev/null || echo \"No package.json found\"]]></|DSML|parameter>\n",
|
||||
"<|DSML|parameter name=\"description\"><![CDATA[Check for existing package.json]]></|DSML|parameter>\n",
|
||||
"</|DSML|invoke>\n",
|
||||
"</|DSML|tool_calls>",
|
||||
}
|
||||
var events []Event
|
||||
for _, c := range chunks {
|
||||
events = append(events, ProcessChunk(&state, c, []string{"Bash"})...)
|
||||
}
|
||||
events = append(events, Flush(&state, []string{"Bash"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
var toolCalls int
|
||||
var names []string
|
||||
for _, evt := range events {
|
||||
textContent.WriteString(evt.Content)
|
||||
for _, call := range evt.ToolCalls {
|
||||
toolCalls++
|
||||
names = append(names, call.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if toolCalls != 2 {
|
||||
t.Fatalf("expected two tool calls from fullwidth DSML prefix variant, got %d events=%#v", toolCalls, events)
|
||||
}
|
||||
if len(names) != 2 || names[0] != "Bash" || names[1] != "Bash" {
|
||||
t.Fatalf("expected two Bash tool calls, got %v", names)
|
||||
}
|
||||
if textContent.Len() != 0 {
|
||||
t.Fatalf("expected fullwidth DSML prefix variant not to leak text, got %q", textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Test <DSML|tool_calls> with <|DSML|invoke> (DSML prefix without leading pipe on wrapper).
|
||||
func TestProcessToolSieveDSMLPrefixVariantDoesNotLeak(t *testing.T) {
|
||||
var state State
|
||||
|
||||
@@ -20,7 +20,7 @@ func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking,
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected)
|
||||
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, nil)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
promptTokens := EstimateTokens(finalPrompt)
|
||||
|
||||
@@ -104,6 +104,13 @@ test('parseToolCalls keeps canonical XML examples inside DSML CDATA', () => {
|
||||
assert.deepEqual(calls[0].input, { path: 'notes.md', content });
|
||||
});
|
||||
|
||||
test('parseToolCalls preserves simple inline markup inside CDATA as text', () => {
|
||||
const payload = '<tool_calls><invoke name="Write"><parameter name="description"><![CDATA[<b>urgent</b>]]></parameter></invoke></tool_calls>';
|
||||
const calls = parseToolCalls(payload, ['Write']);
|
||||
assert.equal(calls.length, 1);
|
||||
assert.equal(calls[0].input.description, '<b>urgent</b>');
|
||||
});
|
||||
|
||||
test('parseToolCalls recovers when CDATA never closes inside a valid wrapper', () => {
|
||||
const payload = '<tool_calls><invoke name="Write"><parameter name="content"><![CDATA[hello world</parameter></invoke></tool_calls>';
|
||||
const calls = parseToolCalls(payload, ['Write']);
|
||||
@@ -174,6 +181,13 @@ test('parseToolCalls treats CDATA item-only body as array', () => {
|
||||
]);
|
||||
});
|
||||
|
||||
test('parseToolCalls treats single-item CDATA body as array', () => {
|
||||
const payload = '<tool_calls><invoke name="TodoWrite"><parameter name="todos"><![CDATA[<item>one</item>]]></parameter></invoke></tool_calls>';
|
||||
const calls = parseToolCalls(payload, ['TodoWrite']);
|
||||
assert.equal(calls.length, 1);
|
||||
assert.deepEqual(calls[0].input.todos, ['one']);
|
||||
});
|
||||
|
||||
test('parseToolCalls treats CDATA object fragment as object', () => {
|
||||
const fragment = '<question><![CDATA[Pick one]]></question><options><item><label><![CDATA[A]]></label></item><item><label><![CDATA[B]]></label></item></options>';
|
||||
const payload = `<tool_calls><invoke name="AskUserQuestion"><parameter name="questions"><![CDATA[${fragment}]]></parameter></invoke></tool_calls>`;
|
||||
@@ -400,6 +414,31 @@ test('sieve emits tool_calls when DSML tag spans multiple chunks', () => {
|
||||
assert.equal(finalCalls[0].name, 'read_file');
|
||||
});
|
||||
|
||||
test('sieve emits tool_calls when fullwidth DSML prefix variant spans multiple chunks', () => {
|
||||
const events = runSieve(
|
||||
[
|
||||
'<|DSML|tool',
|
||||
'_calls>\n',
|
||||
'<|DSML|invoke name="Bash">\n',
|
||||
'<|DSML|parameter name="command"><![CDATA[ls -la /Users/aq/Desktop/myproject/ds2api/]]></|DSML|parameter>\n',
|
||||
'<|DSML|parameter name="description"><![CDATA[List project root contents]]></|DSML|parameter>\n',
|
||||
'</|DSML|invoke>\n',
|
||||
'<|DSML|invoke name="Bash">\n',
|
||||
'<|DSML|parameter name="command"><![CDATA[cat /Users/aq/Desktop/myproject/ds2api/package.json 2>/dev/null || echo "No package.json found"]]></|DSML|parameter>\n',
|
||||
'<|DSML|parameter name="description"><![CDATA[Check for existing package.json]]></|DSML|parameter>\n',
|
||||
'</|DSML|invoke>\n',
|
||||
'</|DSML|tool_calls>',
|
||||
],
|
||||
['Bash'],
|
||||
);
|
||||
const leakedText = collectText(events);
|
||||
const finalCalls = events.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []);
|
||||
assert.equal(leakedText, '');
|
||||
assert.equal(finalCalls.length, 2);
|
||||
assert.equal(finalCalls[0].name, 'Bash');
|
||||
assert.equal(finalCalls[1].name, 'Bash');
|
||||
});
|
||||
|
||||
test('sieve keeps long XML tool calls buffered until the closing tag arrives', () => {
|
||||
const longContent = 'x'.repeat(4096);
|
||||
const splitAt = longContent.length / 2;
|
||||
|
||||
Reference in New Issue
Block a user