mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-10 03:07:41 +08:00
Compare commits
10 Commits
v4.1.1-2
...
v4.1.2_bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f18e6b9b11 | ||
|
|
40ebc8e942 | ||
|
|
fa3e6d040d | ||
|
|
458e4469e5 | ||
|
|
72c8e7e9f9 | ||
|
|
b9c8e90d98 | ||
|
|
36fcba1280 | ||
|
|
801b5abce3 | ||
|
|
206c3d5479 | ||
|
|
b2903c35ed |
@@ -153,6 +153,7 @@ OpenAI Chat / Responses 在标准化后、current input file 之前,会默认
|
|||||||
工具调用正例现在优先示范官方 DSML 风格:`<|DSML|tool_calls>` → `<|DSML|invoke name="...">` → `<|DSML|parameter name="...">`。
|
工具调用正例现在优先示范官方 DSML 风格:`<|DSML|tool_calls>` → `<|DSML|invoke name="...">` → `<|DSML|parameter name="...">`。
|
||||||
兼容层仍接受旧式纯 `<tool_calls>` wrapper,但提示词会优先要求模型输出官方 DSML 标签,并强调不能只输出 closing wrapper 而漏掉 opening tag。需要注意:这是“兼容 DSML 外壳,内部仍以 XML 解析语义为准”,不是原生 DSML 全链路实现;DSML 标签会在解析入口归一化回现有 XML 标签后继续走同一套 parser。
|
兼容层仍接受旧式纯 `<tool_calls>` wrapper,但提示词会优先要求模型输出官方 DSML 标签,并强调不能只输出 closing wrapper 而漏掉 opening tag。需要注意:这是“兼容 DSML 外壳,内部仍以 XML 解析语义为准”,不是原生 DSML 全链路实现;DSML 标签会在解析入口归一化回现有 XML 标签后继续走同一套 parser。
|
||||||
数组参数使用 `<item>...</item>` 子节点表示;当某个参数体只包含 item 子节点时,Go / Node 解析器会把它还原成数组,避免 `questions` / `options` 这类 schema 中要求 array 的参数被误解析成 `{ "item": ... }` 对象。若模型把完整结构化 XML fragment 误包进 CDATA,兼容层会在保护 `content` / `command` 等原文字段的前提下,尝试把非原文字段中的 CDATA XML fragment 还原成 object / array。不过,如果 CDATA 只是单个平面的 XML/HTML 标签,例如 `<b>urgent</b>` 这种行内标记,兼容层会保留原始字符串,不会强行升成 object / array;只有明显表示结构的 CDATA 片段,例如多兄弟节点、嵌套子节点或 `item` 列表,才会触发结构化恢复。
|
数组参数使用 `<item>...</item>` 子节点表示;当某个参数体只包含 item 子节点时,Go / Node 解析器会把它还原成数组,避免 `questions` / `options` 这类 schema 中要求 array 的参数被误解析成 `{ "item": ... }` 对象。若模型把完整结构化 XML fragment 误包进 CDATA,兼容层会在保护 `content` / `command` 等原文字段的前提下,尝试把非原文字段中的 CDATA XML fragment 还原成 object / array。不过,如果 CDATA 只是单个平面的 XML/HTML 标签,例如 `<b>urgent</b>` 这种行内标记,兼容层会保留原始字符串,不会强行升成 object / array;只有明显表示结构的 CDATA 片段,例如多兄弟节点、嵌套子节点或 `item` 列表,才会触发结构化恢复。
|
||||||
|
在 assistant 最终回包阶段,如果某个 tool 参数在声明 schema 中明确是 `string`,兼容层会在把解析后的 `tool_calls` / `function_call` 重新序列化成 OpenAI / Responses / Claude 可见参数前,递归把该路径上的 number / bool / object / array 统一转成字符串;其中 object / array 会压成紧凑 JSON 字符串。这个保护只对 schema 明确声明为 string 的路径生效,不会改写本来就是 `number` / `boolean` / `object` / `array` 的参数。这样可以兼容 DeepSeek 输出了结构化片段、但上游客户端工具 schema 又严格要求字符串参数的场景(例如 `content`、`prompt`、`path`、`taskId` 等)。
|
||||||
正例中的工具名只会来自当前请求实际声明的工具;如果当前请求没有足够的已知工具形态,就省略对应的单工具、多工具或嵌套示例,避免把不可用工具名写进 prompt。
|
正例中的工具名只会来自当前请求实际声明的工具;如果当前请求没有足够的已知工具形态,就省略对应的单工具、多工具或嵌套示例,避免把不可用工具名写进 prompt。
|
||||||
对执行类工具,脚本内容必须进入执行参数本身:`Bash` / `execute_command` 使用 `command`,`exec_command` 使用 `cmd`;不要把脚本示范成 `path` / `content` 文件写入参数。
|
对执行类工具,脚本内容必须进入执行参数本身:`Bash` / `execute_command` 使用 `command`,`exec_command` 使用 `cmd`;不要把脚本示范成 `path` / `content` 文件写入参数。
|
||||||
|
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string, toolsRaw any) map[string]any {
|
||||||
detected := toolcall.ParseAssistantToolCallsDetailed(finalText, finalThinking, toolNames)
|
detected := toolcall.ParseAssistantToolCallsDetailed(finalText, finalThinking, toolNames)
|
||||||
return BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls)
|
return BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall) map[string]any {
|
func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall, toolsRaw any) map[string]any {
|
||||||
finishReason := "stop"
|
finishReason := "stop"
|
||||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||||
if strings.TrimSpace(finalThinking) != "" {
|
if strings.TrimSpace(finalThinking) != "" {
|
||||||
@@ -19,7 +19,7 @@ func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThi
|
|||||||
}
|
}
|
||||||
if len(detected) > 0 {
|
if len(detected) > 0 {
|
||||||
finishReason = "tool_calls"
|
finishReason = "tool_calls"
|
||||||
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected)
|
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, toolsRaw)
|
||||||
messageObj["content"] = nil
|
messageObj["content"] = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,19 +9,19 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string, toolsRaw any) map[string]any {
|
||||||
// Strict mode: only standalone, structured tool-call payloads are treated
|
// Strict mode: only standalone, structured tool-call payloads are treated
|
||||||
// as executable tool calls.
|
// as executable tool calls.
|
||||||
detected := toolcall.ParseAssistantToolCallsDetailed(finalText, finalThinking, toolNames)
|
detected := toolcall.ParseAssistantToolCallsDetailed(finalText, finalThinking, toolNames)
|
||||||
return BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText, detected.Calls)
|
return BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall) map[string]any {
|
func BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall, toolsRaw any) map[string]any {
|
||||||
exposedOutputText := finalText
|
exposedOutputText := finalText
|
||||||
output := make([]any, 0, 2)
|
output := make([]any, 0, 2)
|
||||||
if len(detected) > 0 {
|
if len(detected) > 0 {
|
||||||
exposedOutputText = ""
|
exposedOutputText = ""
|
||||||
output = append(output, toResponsesFunctionCallItems(detected)...)
|
output = append(output, toResponsesFunctionCallItems(detected, toolsRaw)...)
|
||||||
} else {
|
} else {
|
||||||
content := make([]any, 0, 2)
|
content := make([]any, 0, 2)
|
||||||
if finalThinking != "" {
|
if finalThinking != "" {
|
||||||
@@ -74,12 +74,13 @@ func BuildResponseObjectFromItems(responseID, model, finalPrompt, finalThinking,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toResponsesFunctionCallItems(toolCalls []toolcall.ParsedToolCall) []any {
|
func toResponsesFunctionCallItems(toolCalls []toolcall.ParsedToolCall, toolsRaw any) []any {
|
||||||
if len(toolCalls) == 0 {
|
if len(toolCalls) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(toolCalls, toolsRaw)
|
||||||
out := make([]any, 0, len(toolCalls))
|
out := make([]any, 0, len(toolCalls))
|
||||||
for _, tc := range toolCalls {
|
for _, tc := range normalizedCalls {
|
||||||
if strings.TrimSpace(tc.Name) == "" {
|
if strings.TrimSpace(tc.Name) == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"ds2api/internal/toolcall"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
|
func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
|
||||||
@@ -13,6 +16,7 @@ func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
|
|||||||
"",
|
"",
|
||||||
"```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```",
|
"```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```",
|
||||||
[]string{"search"},
|
[]string{"search"},
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputText, _ := obj["output_text"].(string)
|
outputText, _ := obj["output_text"].(string)
|
||||||
@@ -42,6 +46,7 @@ func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
|
|||||||
"internal thinking content",
|
"internal thinking content",
|
||||||
"",
|
"",
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputText, _ := obj["output_text"].(string)
|
outputText, _ := obj["output_text"].(string)
|
||||||
@@ -75,6 +80,7 @@ func TestBuildResponseObjectPromotesToolCallFromThinkingWhenTextEmpty(t *testing
|
|||||||
`<tool_calls><invoke name="search"><parameter name="q">from-thinking</parameter></invoke></tool_calls>`,
|
`<tool_calls><invoke name="search"><parameter name="q">from-thinking</parameter></invoke></tool_calls>`,
|
||||||
"",
|
"",
|
||||||
[]string{"search"},
|
[]string{"search"},
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
output, _ := obj["output"].([]any)
|
output, _ := obj["output"].([]any)
|
||||||
@@ -86,3 +92,88 @@ func TestBuildResponseObjectPromotesToolCallFromThinkingWhenTextEmpty(t *testing
|
|||||||
t.Fatalf("expected function_call output, got %#v", first["type"])
|
t.Fatalf("expected function_call output, got %#v", first["type"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildChatCompletionWithToolCallsCoercesSchemaDeclaredStringArguments(t *testing.T) {
|
||||||
|
toolsRaw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "Write",
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"content": map[string]any{"type": "string"},
|
||||||
|
"taskId": map[string]any{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
obj := BuildChatCompletionWithToolCalls(
|
||||||
|
"chat_test",
|
||||||
|
"gpt-4o",
|
||||||
|
"prompt",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
[]toolcall.ParsedToolCall{{
|
||||||
|
Name: "Write",
|
||||||
|
Input: map[string]any{
|
||||||
|
"content": map[string]any{"message": "hi"},
|
||||||
|
"taskId": 1,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
toolsRaw,
|
||||||
|
)
|
||||||
|
choices, _ := obj["choices"].([]map[string]any)
|
||||||
|
message, _ := choices[0]["message"].(map[string]any)
|
||||||
|
toolCalls, _ := message["tool_calls"].([]map[string]any)
|
||||||
|
fn, _ := toolCalls[0]["function"].(map[string]any)
|
||||||
|
args := map[string]any{}
|
||||||
|
if err := json.Unmarshal([]byte(fn["arguments"].(string)), &args); err != nil {
|
||||||
|
t.Fatalf("decode arguments failed: %v", err)
|
||||||
|
}
|
||||||
|
if args["content"] != `{"message":"hi"}` {
|
||||||
|
t.Fatalf("expected content stringified by schema, got %#v", args["content"])
|
||||||
|
}
|
||||||
|
if args["taskId"] != "1" {
|
||||||
|
t.Fatalf("expected taskId stringified by schema, got %#v", args["taskId"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildResponseObjectWithToolCallsCoercesSchemaDeclaredStringArguments(t *testing.T) {
|
||||||
|
toolsRaw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "Write",
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"content": map[string]any{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
obj := BuildResponseObjectWithToolCalls(
|
||||||
|
"resp_test",
|
||||||
|
"gpt-4o",
|
||||||
|
"prompt",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
[]toolcall.ParsedToolCall{{
|
||||||
|
Name: "Write",
|
||||||
|
Input: map[string]any{"content": []any{"a", 1}},
|
||||||
|
}},
|
||||||
|
toolsRaw,
|
||||||
|
)
|
||||||
|
output, _ := obj["output"].([]any)
|
||||||
|
first, _ := output[0].(map[string]any)
|
||||||
|
args := map[string]any{}
|
||||||
|
if err := json.Unmarshal([]byte(first["arguments"].(string)), &args); err != nil {
|
||||||
|
t.Fatalf("decode response arguments failed: %v", err)
|
||||||
|
}
|
||||||
|
if args["content"] != `["a",1]` {
|
||||||
|
t.Fatalf("expected response content stringified by schema, got %#v", args["content"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ func TestHandleStreamContextCancelledMarksHistoryStopped(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
resp := makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, `data: [DONE]`)
|
resp := makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, `data: [DONE]`)
|
||||||
|
|
||||||
h.handleStream(rec, req, resp, "cid-stop", "deepseek-v4-flash", "prompt", false, false, nil, session)
|
h.handleStream(rec, req, resp, "cid-stop", "deepseek-v4-flash", "prompt", false, false, nil, nil, session)
|
||||||
|
|
||||||
snapshot, err := historyStore.Snapshot()
|
snapshot, err := historyStore.Snapshot()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type chatStreamRuntime struct {
|
|||||||
model string
|
model string
|
||||||
finalPrompt string
|
finalPrompt string
|
||||||
toolNames []string
|
toolNames []string
|
||||||
|
toolsRaw any
|
||||||
|
|
||||||
thinkingEnabled bool
|
thinkingEnabled bool
|
||||||
searchEnabled bool
|
searchEnabled bool
|
||||||
@@ -61,6 +62,7 @@ func newChatStreamRuntime(
|
|||||||
searchEnabled bool,
|
searchEnabled bool,
|
||||||
stripReferenceMarkers bool,
|
stripReferenceMarkers bool,
|
||||||
toolNames []string,
|
toolNames []string,
|
||||||
|
toolsRaw any,
|
||||||
bufferToolContent bool,
|
bufferToolContent bool,
|
||||||
emitEarlyToolDeltas bool,
|
emitEarlyToolDeltas bool,
|
||||||
) *chatStreamRuntime {
|
) *chatStreamRuntime {
|
||||||
@@ -73,6 +75,7 @@ func newChatStreamRuntime(
|
|||||||
model: model,
|
model: model,
|
||||||
finalPrompt: finalPrompt,
|
finalPrompt: finalPrompt,
|
||||||
toolNames: toolNames,
|
toolNames: toolNames,
|
||||||
|
toolsRaw: toolsRaw,
|
||||||
thinkingEnabled: thinkingEnabled,
|
thinkingEnabled: thinkingEnabled,
|
||||||
searchEnabled: searchEnabled,
|
searchEnabled: searchEnabled,
|
||||||
stripReferenceMarkers: stripReferenceMarkers,
|
stripReferenceMarkers: stripReferenceMarkers,
|
||||||
@@ -142,7 +145,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
|
|||||||
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
|
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
|
||||||
finishReason = "tool_calls"
|
finishReason = "tool_calls"
|
||||||
delta := map[string]any{
|
delta := map[string]any{
|
||||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs),
|
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs, s.toolsRaw),
|
||||||
}
|
}
|
||||||
if !s.firstChunkSent {
|
if !s.firstChunkSent {
|
||||||
delta["role"] = "assistant"
|
delta["role"] = "assistant"
|
||||||
@@ -164,7 +167,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
|
|||||||
s.toolCallsEmitted = true
|
s.toolCallsEmitted = true
|
||||||
s.toolCallsDoneEmitted = true
|
s.toolCallsDoneEmitted = true
|
||||||
tcDelta := map[string]any{
|
tcDelta := map[string]any{
|
||||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
|
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
|
||||||
}
|
}
|
||||||
if !s.firstChunkSent {
|
if !s.firstChunkSent {
|
||||||
tcDelta["role"] = "assistant"
|
tcDelta["role"] = "assistant"
|
||||||
@@ -320,7 +323,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
|||||||
s.toolCallsEmitted = true
|
s.toolCallsEmitted = true
|
||||||
s.toolCallsDoneEmitted = true
|
s.toolCallsDoneEmitted = true
|
||||||
tcDelta := map[string]any{
|
tcDelta := map[string]any{
|
||||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
|
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
|
||||||
}
|
}
|
||||||
if !s.firstChunkSent {
|
if !s.firstChunkSent {
|
||||||
tcDelta["role"] = "assistant"
|
tcDelta["role"] = "assistant"
|
||||||
|
|||||||
@@ -26,14 +26,14 @@ type chatNonStreamResult struct {
|
|||||||
responseMessageID int
|
responseMessageID int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||||
attempts := 0
|
attempts := 0
|
||||||
currentResp := resp
|
currentResp := resp
|
||||||
usagePrompt := finalPrompt
|
usagePrompt := finalPrompt
|
||||||
accumulatedThinking := ""
|
accumulatedThinking := ""
|
||||||
accumulatedToolDetectionThinking := ""
|
accumulatedToolDetectionThinking := ""
|
||||||
for {
|
for {
|
||||||
result, ok := h.collectChatNonStreamAttempt(w, currentResp, completionID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames)
|
result, ok := h.collectChatNonStreamAttempt(w, currentResp, completionID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -43,7 +43,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
|
|||||||
result.toolDetectionThinking = accumulatedToolDetectionThinking
|
result.toolDetectionThinking = accumulatedToolDetectionThinking
|
||||||
detected := detectAssistantToolCalls(result.text, result.thinking, result.toolDetectionThinking, toolNames)
|
detected := detectAssistantToolCalls(result.text, result.thinking, result.toolDetectionThinking, toolNames)
|
||||||
result.detectedCalls = len(detected.Calls)
|
result.detectedCalls = len(detected.Calls)
|
||||||
result.body = openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, result.thinking, result.text, detected.Calls)
|
result.body = openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, result.thinking, result.text, detected.Calls, toolsRaw)
|
||||||
result.finishReason = chatFinishReason(result.body)
|
result.finishReason = chatFinishReason(result.body)
|
||||||
if !shouldRetryChatNonStream(result, attempts) {
|
if !shouldRetryChatNonStream(result, attempts) {
|
||||||
h.finishChatNonStreamResult(w, result, attempts, usagePrompt, historySession)
|
h.finishChatNonStreamResult(w, result, attempts, usagePrompt, historySession)
|
||||||
@@ -72,7 +72,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.Response, completionID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) (chatNonStreamResult, bool) {
|
func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.Response, completionID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (chatNonStreamResult, bool) {
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
@@ -88,7 +88,7 @@ func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.
|
|||||||
finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks)
|
finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks)
|
||||||
}
|
}
|
||||||
detected := detectAssistantToolCalls(finalText, finalThinking, finalToolDetectionThinking, toolNames)
|
detected := detectAssistantToolCalls(finalText, finalThinking, finalToolDetectionThinking, toolNames)
|
||||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls)
|
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||||
return chatNonStreamResult{
|
return chatNonStreamResult{
|
||||||
thinking: finalThinking,
|
thinking: finalThinking,
|
||||||
toolDetectionThinking: finalToolDetectionThinking,
|
toolDetectionThinking: finalToolDetectionThinking,
|
||||||
@@ -139,8 +139,8 @@ func shouldRetryChatNonStream(result chatNonStreamResult, attempts int) bool {
|
|||||||
strings.TrimSpace(result.text) == ""
|
strings.TrimSpace(result.text) == ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||||
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, historySession)
|
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw, historySession)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -182,7 +182,7 @@ func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
|
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
@@ -207,7 +207,7 @@ func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Res
|
|||||||
}
|
}
|
||||||
streamRuntime := newChatStreamRuntime(
|
streamRuntime := newChatStreamRuntime(
|
||||||
w, rc, canFlush, completionID, time.Now().Unix(), model, finalPrompt,
|
w, rc, canFlush, completionID, time.Now().Unix(), model, finalPrompt,
|
||||||
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames,
|
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw,
|
||||||
len(toolNames) > 0, h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
|
len(toolNames) > 0, h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
|
||||||
)
|
)
|
||||||
return streamRuntime, initialType, true
|
return streamRuntime, initialType, true
|
||||||
|
|||||||
@@ -144,8 +144,8 @@ func filterIncrementalToolCallDeltasByAllowed(deltas []toolstream.ToolCallDelta,
|
|||||||
return shared.FilterIncrementalToolCallDeltasByAllowed(deltas, seenNames)
|
return shared.FilterIncrementalToolCallDeltasByAllowed(deltas, seenNames)
|
||||||
}
|
}
|
||||||
|
|
||||||
func formatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string) []map[string]any {
|
func formatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string, toolsRaw any) []map[string]any {
|
||||||
return shared.FormatFinalStreamToolCallsWithStableIDs(calls, ids)
|
return shared.FormatFinalStreamToolCallsWithStableIDs(calls, ids, toolsRaw)
|
||||||
}
|
}
|
||||||
|
|
||||||
func detectAssistantToolCalls(text, exposedThinking, detectionThinking string, toolNames []string) toolcall.ToolCallParseResult {
|
func detectAssistantToolCalls(text, exposedThinking, detectionThinking string, toolNames []string) toolcall.ToolCallParseResult {
|
||||||
|
|||||||
@@ -109,10 +109,10 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if stdReq.Stream {
|
if stdReq.Stream {
|
||||||
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
|
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
|
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
|
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
|
||||||
@@ -148,7 +148,7 @@ func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
@@ -176,7 +176,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
|
|||||||
writeUpstreamEmptyOutputError(w, finalText, finalThinking, result.ContentFilter)
|
writeUpstreamEmptyOutputError(w, finalText, finalThinking, result.ContentFilter)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls)
|
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||||
finishReason := "stop"
|
finishReason := "stop"
|
||||||
if choices, ok := respBody["choices"].([]map[string]any); ok && len(choices) > 0 {
|
if choices, ok := respBody["choices"].([]map[string]any); ok && len(choices) > 0 {
|
||||||
if fr, _ := choices[0]["finish_reason"].(string); strings.TrimSpace(fr) != "" {
|
if fr, _ := choices[0]["finish_reason"].(string); strings.TrimSpace(fr) != "" {
|
||||||
@@ -189,7 +189,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
|
|||||||
writeJSON(w, http.StatusOK, respBody)
|
writeJSON(w, http.StatusOK, respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
@@ -230,6 +230,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
|||||||
searchEnabled,
|
searchEnabled,
|
||||||
stripReferenceMarkers,
|
stripReferenceMarkers,
|
||||||
toolNames,
|
toolNames,
|
||||||
|
toolsRaw,
|
||||||
bufferToolContent,
|
bufferToolContent,
|
||||||
emitEarlyToolDeltas,
|
emitEarlyToolDeltas,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
|||||||
)
|
)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
h.handleNonStream(rec, resp, "cid-empty", "deepseek-v4-flash", "prompt", false, false, nil, nil)
|
h.handleNonStream(rec, resp, "cid-empty", "deepseek-v4-flash", "prompt", false, false, nil, nil, nil)
|
||||||
if rec.Code != http.StatusTooManyRequests {
|
if rec.Code != http.StatusTooManyRequests {
|
||||||
t.Fatalf("expected status 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
t.Fatalf("expected status 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -112,7 +112,7 @@ func TestHandleNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutp
|
|||||||
)
|
)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
h.handleNonStream(rec, resp, "cid-empty-filtered", "deepseek-v4-flash", "prompt", false, false, nil, nil)
|
h.handleNonStream(rec, resp, "cid-empty-filtered", "deepseek-v4-flash", "prompt", false, false, nil, nil, nil)
|
||||||
if rec.Code != http.StatusBadRequest {
|
if rec.Code != http.StatusBadRequest {
|
||||||
t.Fatalf("expected status 400 for filtered upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
t.Fatalf("expected status 400 for filtered upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -131,7 +131,7 @@ func TestHandleNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testing.T) {
|
|||||||
)
|
)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
h.handleNonStream(rec, resp, "cid-thinking-only", "deepseek-v4-pro", "prompt", true, false, nil, nil)
|
h.handleNonStream(rec, resp, "cid-thinking-only", "deepseek-v4-pro", "prompt", true, false, nil, nil, nil)
|
||||||
if rec.Code != http.StatusTooManyRequests {
|
if rec.Code != http.StatusTooManyRequests {
|
||||||
t.Fatalf("expected status 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
t.Fatalf("expected status 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -150,7 +150,7 @@ func TestHandleNonStreamPromotesThinkingToolCallsWhenTextEmpty(t *testing.T) {
|
|||||||
)
|
)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
h.handleNonStream(rec, resp, "cid-thinking-tool", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil)
|
h.handleNonStream(rec, resp, "cid-thinking-tool", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil, nil)
|
||||||
if rec.Code != http.StatusOK {
|
if rec.Code != http.StatusOK {
|
||||||
t.Fatalf("expected 200 for thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
t.Fatalf("expected 200 for thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -181,7 +181,7 @@ func TestHandleNonStreamPromotesHiddenThinkingDSMLToolCallsWhenTextEmpty(t *test
|
|||||||
)
|
)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
h.handleNonStream(rec, resp, "cid-hidden-thinking-tool", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil)
|
h.handleNonStream(rec, resp, "cid-hidden-thinking-tool", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil, nil)
|
||||||
if rec.Code != http.StatusOK {
|
if rec.Code != http.StatusOK {
|
||||||
t.Fatalf("expected 200 for hidden thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
t.Fatalf("expected 200 for hidden thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -211,7 +211,7 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
h.handleStream(rec, req, resp, "cid6", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil)
|
h.handleStream(rec, req, resp, "cid6", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil, nil)
|
||||||
|
|
||||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||||
if !done {
|
if !done {
|
||||||
@@ -248,7 +248,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
h.handleStream(rec, req, resp, "cid10", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil)
|
h.handleStream(rec, req, resp, "cid10", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil, nil)
|
||||||
|
|
||||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||||
if !done {
|
if !done {
|
||||||
@@ -282,7 +282,7 @@ func TestHandleStreamPromotesThinkingToolCallsOnFinalizeWithoutMidstreamIntercep
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
h.handleStream(rec, req, resp, "cid-thinking-stream", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil)
|
h.handleStream(rec, req, resp, "cid-thinking-stream", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil, nil)
|
||||||
|
|
||||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||||
if !done {
|
if !done {
|
||||||
@@ -319,7 +319,7 @@ func TestHandleStreamPromotesHiddenThinkingDSMLToolCallsOnFinalize(t *testing.T)
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
h.handleStream(rec, req, resp, "cid-hidden-thinking-stream", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil)
|
h.handleStream(rec, req, resp, "cid-hidden-thinking-stream", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil, nil)
|
||||||
|
|
||||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||||
if !done {
|
if !done {
|
||||||
@@ -353,7 +353,7 @@ func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
h.handleStream(rec, req, resp, "cid-multi", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, nil)
|
h.handleStream(rec, req, resp, "cid-multi", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, nil, nil)
|
||||||
|
|
||||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||||
if !done {
|
if !done {
|
||||||
@@ -390,3 +390,64 @@ func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing
|
|||||||
t.Fatalf("expected distinct tool call ids across blocks, got %#v body=%s", ids, rec.Body.String())
|
t.Fatalf("expected distinct tool call ids across blocks, got %#v body=%s", ids, rec.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleStreamCoercesSchemaDeclaredStringArgumentsOnFinalize(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
line := func(v string) string {
|
||||||
|
b, _ := json.Marshal(map[string]any{"p": "response/content", "v": v})
|
||||||
|
return "data: " + string(b)
|
||||||
|
}
|
||||||
|
resp := makeSSEHTTPResponse(
|
||||||
|
line(`<tool_calls><invoke name="Write">{"input":{"content":{"message":"hi"},"taskId":1}}</invoke></tool_calls>`),
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
toolsRaw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "Write",
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"content": map[string]any{"type": "string"},
|
||||||
|
"taskId": map[string]any{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
h.handleStream(rec, req, resp, "cid-string-protect", "deepseek-v4-flash", "prompt", false, false, []string{"Write"}, toolsRaw, nil)
|
||||||
|
|
||||||
|
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||||
|
if !done {
|
||||||
|
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
for _, frame := range frames {
|
||||||
|
choices, _ := frame["choices"].([]any)
|
||||||
|
for _, item := range choices {
|
||||||
|
choice, _ := item.(map[string]any)
|
||||||
|
delta, _ := choice["delta"].(map[string]any)
|
||||||
|
toolCalls, _ := delta["tool_calls"].([]any)
|
||||||
|
if len(toolCalls) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
call, _ := toolCalls[0].(map[string]any)
|
||||||
|
fn, _ := call["function"].(map[string]any)
|
||||||
|
args := map[string]any{}
|
||||||
|
if err := json.Unmarshal([]byte(asString(fn["arguments"])), &args); err != nil {
|
||||||
|
t.Fatalf("decode streamed tool arguments failed: %v", err)
|
||||||
|
}
|
||||||
|
if args["content"] != `{"message":"hi"}` {
|
||||||
|
t.Fatalf("expected streamed content stringified by schema, got %#v", args["content"])
|
||||||
|
}
|
||||||
|
if args["taskId"] != "1" {
|
||||||
|
t.Fatalf("expected streamed taskId stringified by schema, got %#v", args["taskId"])
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Fatalf("expected at least one streamed tool call delta, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
|||||||
@@ -27,14 +27,14 @@ type responsesNonStreamResult struct {
|
|||||||
responseMessageID int
|
responseMessageID int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||||
attempts := 0
|
attempts := 0
|
||||||
currentResp := resp
|
currentResp := resp
|
||||||
usagePrompt := finalPrompt
|
usagePrompt := finalPrompt
|
||||||
accumulatedThinking := ""
|
accumulatedThinking := ""
|
||||||
accumulatedToolDetectionThinking := ""
|
accumulatedToolDetectionThinking := ""
|
||||||
for {
|
for {
|
||||||
result, ok := h.collectResponsesNonStreamAttempt(w, currentResp, responseID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames)
|
result, ok := h.collectResponsesNonStreamAttempt(w, currentResp, responseID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -43,7 +43,7 @@ func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx c
|
|||||||
result.thinking = accumulatedThinking
|
result.thinking = accumulatedThinking
|
||||||
result.toolDetectionThinking = accumulatedToolDetectionThinking
|
result.toolDetectionThinking = accumulatedToolDetectionThinking
|
||||||
result.parsed = detectAssistantToolCalls(result.text, result.thinking, result.toolDetectionThinking, toolNames)
|
result.parsed = detectAssistantToolCalls(result.text, result.thinking, result.toolDetectionThinking, toolNames)
|
||||||
result.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, result.thinking, result.text, result.parsed.Calls)
|
result.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, result.thinking, result.text, result.parsed.Calls, toolsRaw)
|
||||||
|
|
||||||
if !shouldRetryResponsesNonStream(result, attempts) {
|
if !shouldRetryResponsesNonStream(result, attempts) {
|
||||||
h.finishResponsesNonStreamResult(w, result, attempts, owner, responseID, toolChoice, traceID)
|
h.finishResponsesNonStreamResult(w, result, attempts, owner, responseID, toolChoice, traceID)
|
||||||
@@ -68,7 +68,7 @@ func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx c
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *http.Response, responseID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) (responsesNonStreamResult, bool) {
|
func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *http.Response, responseID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (responsesNonStreamResult, bool) {
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
@@ -84,7 +84,7 @@ func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *
|
|||||||
sanitizedText = replaceCitationMarkersWithLinks(sanitizedText, result.CitationLinks)
|
sanitizedText = replaceCitationMarkersWithLinks(sanitizedText, result.CitationLinks)
|
||||||
}
|
}
|
||||||
textParsed := detectAssistantToolCalls(sanitizedText, sanitizedThinking, toolDetectionThinking, toolNames)
|
textParsed := detectAssistantToolCalls(sanitizedText, sanitizedThinking, toolDetectionThinking, toolNames)
|
||||||
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, sanitizedThinking, sanitizedText, textParsed.Calls)
|
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, sanitizedThinking, sanitizedText, textParsed.Calls, toolsRaw)
|
||||||
return responsesNonStreamResult{
|
return responsesNonStreamResult{
|
||||||
thinking: sanitizedThinking,
|
thinking: sanitizedThinking,
|
||||||
toolDetectionThinking: toolDetectionThinking,
|
toolDetectionThinking: toolDetectionThinking,
|
||||||
@@ -123,8 +123,8 @@ func shouldRetryResponsesNonStream(result responsesNonStreamResult, attempts int
|
|||||||
strings.TrimSpace(result.text) == ""
|
strings.TrimSpace(result.text) == ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||||
streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, resp, owner, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolChoice, traceID)
|
streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, resp, owner, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, traceID)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -165,7 +165,7 @@ func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) (*responsesStreamRuntime, string, bool) {
|
func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) (*responsesStreamRuntime, string, bool) {
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
@@ -184,7 +184,7 @@ func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, resp *htt
|
|||||||
}
|
}
|
||||||
streamRuntime := newResponsesStreamRuntime(
|
streamRuntime := newResponsesStreamRuntime(
|
||||||
w, rc, canFlush, responseID, model, finalPrompt, thinkingEnabled, searchEnabled,
|
w, rc, canFlush, responseID, model, finalPrompt, thinkingEnabled, searchEnabled,
|
||||||
h.compatStripReferenceMarkers(), toolNames, len(toolNames) > 0,
|
h.compatStripReferenceMarkers(), toolNames, toolsRaw, len(toolNames) > 0,
|
||||||
h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
|
h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
|
||||||
toolChoice, traceID, func(obj map[string]any) {
|
toolChoice, traceID, func(obj map[string]any) {
|
||||||
h.getResponseStore().put(owner, responseID, obj)
|
h.getResponseStore().put(owner, responseID, obj)
|
||||||
|
|||||||
@@ -115,13 +115,13 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||||
if stdReq.Stream {
|
if stdReq.Stream {
|
||||||
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
@@ -148,12 +148,12 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, sanitizedThinking, sanitizedText, textParsed.Calls)
|
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, sanitizedThinking, sanitizedText, textParsed.Calls, toolsRaw)
|
||||||
h.getResponseStore().put(owner, responseID, responseObj)
|
h.getResponseStore().put(owner, responseID, responseObj)
|
||||||
writeJSON(w, http.StatusOK, responseObj)
|
writeJSON(w, http.StatusOK, responseObj)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
@@ -186,6 +186,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
|||||||
searchEnabled,
|
searchEnabled,
|
||||||
stripReferenceMarkers,
|
stripReferenceMarkers,
|
||||||
toolNames,
|
toolNames,
|
||||||
|
toolsRaw,
|
||||||
bufferToolContent,
|
bufferToolContent,
|
||||||
emitEarlyToolDeltas,
|
emitEarlyToolDeltas,
|
||||||
toolChoice,
|
toolChoice,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ type responsesStreamRuntime struct {
|
|||||||
model string
|
model string
|
||||||
finalPrompt string
|
finalPrompt string
|
||||||
toolNames []string
|
toolNames []string
|
||||||
|
toolsRaw any
|
||||||
traceID string
|
traceID string
|
||||||
toolChoice promptcompat.ToolChoicePolicy
|
toolChoice promptcompat.ToolChoicePolicy
|
||||||
|
|
||||||
@@ -72,6 +73,7 @@ func newResponsesStreamRuntime(
|
|||||||
searchEnabled bool,
|
searchEnabled bool,
|
||||||
stripReferenceMarkers bool,
|
stripReferenceMarkers bool,
|
||||||
toolNames []string,
|
toolNames []string,
|
||||||
|
toolsRaw any,
|
||||||
bufferToolContent bool,
|
bufferToolContent bool,
|
||||||
emitEarlyToolDeltas bool,
|
emitEarlyToolDeltas bool,
|
||||||
toolChoice promptcompat.ToolChoicePolicy,
|
toolChoice promptcompat.ToolChoicePolicy,
|
||||||
@@ -89,6 +91,7 @@ func newResponsesStreamRuntime(
|
|||||||
searchEnabled: searchEnabled,
|
searchEnabled: searchEnabled,
|
||||||
stripReferenceMarkers: stripReferenceMarkers,
|
stripReferenceMarkers: stripReferenceMarkers,
|
||||||
toolNames: toolNames,
|
toolNames: toolNames,
|
||||||
|
toolsRaw: toolsRaw,
|
||||||
bufferToolContent: bufferToolContent,
|
bufferToolContent: bufferToolContent,
|
||||||
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
||||||
streamToolCallIDs: map[int]string{},
|
streamToolCallIDs: map[int]string{},
|
||||||
|
|||||||
@@ -220,7 +220,8 @@ func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolstream
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []toolcall.ParsedToolCall) {
|
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []toolcall.ParsedToolCall) {
|
||||||
for idx, tc := range calls {
|
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, s.toolsRaw)
|
||||||
|
for idx, tc := range normalizedCalls {
|
||||||
if strings.TrimSpace(tc.Name) == "" {
|
if strings.TrimSpace(tc.Name) == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -109,7 +109,8 @@ func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, fin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, tc := range calls {
|
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, s.toolsRaw)
|
||||||
|
for idx, tc := range normalizedCalls {
|
||||||
if strings.TrimSpace(tc.Name) == "" {
|
if strings.TrimSpace(tc.Name) == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T)
|
|||||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||||
|
|
||||||
body := rec.Body.String()
|
body := rec.Body.String()
|
||||||
if !strings.Contains(body, "event: response.reasoning.delta") {
|
if !strings.Contains(body, "event: response.reasoning.delta") {
|
||||||
@@ -57,7 +57,7 @@ func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testin
|
|||||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||||
body := rec.Body.String()
|
body := rec.Body.String()
|
||||||
if !strings.Contains(body, "event: response.output_text.done") {
|
if !strings.Contains(body, "event: response.output_text.done") {
|
||||||
t.Fatalf("expected response.output_text.done payload, body=%s", body)
|
t.Fatalf("expected response.output_text.done payload, body=%s", body)
|
||||||
@@ -91,7 +91,7 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
|||||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||||
body := rec.Body.String()
|
body := rec.Body.String()
|
||||||
|
|
||||||
deltaPayload, ok := extractSSEEventPayload(body, "response.output_text.delta")
|
deltaPayload, ok := extractSSEEventPayload(body, "response.output_text.delta")
|
||||||
@@ -130,7 +130,7 @@ func TestHandleResponsesStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t
|
|||||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, promptcompat.DefaultToolChoicePolicy(), "")
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||||
|
|
||||||
body := rec.Body.String()
|
body := rec.Body.String()
|
||||||
doneEvents := extractSSEEventPayloads(body, "response.function_call_arguments.done")
|
doneEvents := extractSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||||
@@ -183,7 +183,7 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
|||||||
Mode: promptcompat.ToolChoiceRequired,
|
Mode: promptcompat.ToolChoiceRequired,
|
||||||
Allowed: map[string]struct{}{"read_file": {}},
|
Allowed: map[string]struct{}{"read_file": {}},
|
||||||
}
|
}
|
||||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, policy, "")
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, nil, policy, "")
|
||||||
|
|
||||||
body := rec.Body.String()
|
body := rec.Body.String()
|
||||||
if !strings.Contains(body, "event: response.failed") {
|
if !strings.Contains(body, "event: response.failed") {
|
||||||
@@ -213,7 +213,7 @@ func TestHandleResponsesStreamFailsWhenUpstreamHasOnlyThinking(t *testing.T) {
|
|||||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||||
|
|
||||||
body := rec.Body.String()
|
body := rec.Body.String()
|
||||||
if !strings.Contains(body, "event: response.failed") {
|
if !strings.Contains(body, "event: response.failed") {
|
||||||
@@ -251,7 +251,7 @@ func TestHandleResponsesStreamPromotesThinkingToolCallsOnFinalizeWithoutMidstrea
|
|||||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, promptcompat.DefaultToolChoicePolicy(), "")
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||||
|
|
||||||
body := rec.Body.String()
|
body := rec.Body.String()
|
||||||
if !strings.Contains(body, "event: response.reasoning.delta") {
|
if !strings.Contains(body, "event: response.reasoning.delta") {
|
||||||
@@ -288,7 +288,7 @@ func TestHandleResponsesStreamPromotesHiddenThinkingDSMLToolCallsOnFinalize(t *t
|
|||||||
Mode: promptcompat.ToolChoiceRequired,
|
Mode: promptcompat.ToolChoiceRequired,
|
||||||
Allowed: map[string]struct{}{"read_file": {}},
|
Allowed: map[string]struct{}{"read_file": {}},
|
||||||
}
|
}
|
||||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, policy, "")
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, nil, policy, "")
|
||||||
|
|
||||||
body := rec.Body.String()
|
body := rec.Body.String()
|
||||||
if strings.Contains(body, "event: response.reasoning.delta") {
|
if strings.Contains(body, "event: response.reasoning.delta") {
|
||||||
@@ -317,7 +317,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
|
|||||||
Allowed: map[string]struct{}{"read_file": {}},
|
Allowed: map[string]struct{}{"read_file": {}},
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, policy, "")
|
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, nil, policy, "")
|
||||||
if rec.Code != http.StatusUnprocessableEntity {
|
if rec.Code != http.StatusUnprocessableEntity {
|
||||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -344,7 +344,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayloadWhe
|
|||||||
Allowed: map[string]struct{}{"read_file": {}},
|
Allowed: map[string]struct{}{"read_file": {}},
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", true, false, []string{"read_file"}, policy, "")
|
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", true, false, []string{"read_file"}, nil, policy, "")
|
||||||
if rec.Code != http.StatusUnprocessableEntity {
|
if rec.Code != http.StatusUnprocessableEntity {
|
||||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -366,7 +366,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T)
|
|||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||||
if rec.Code != http.StatusTooManyRequests {
|
if rec.Code != http.StatusTooManyRequests {
|
||||||
t.Fatalf("expected 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
t.Fatalf("expected 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -388,7 +388,7 @@ func TestHandleResponsesNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWi
|
|||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||||
if rec.Code != http.StatusBadRequest {
|
if rec.Code != http.StatusBadRequest {
|
||||||
t.Fatalf("expected 400 for filtered empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
t.Fatalf("expected 400 for filtered empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -410,7 +410,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testin
|
|||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||||
if rec.Code != http.StatusTooManyRequests {
|
if rec.Code != http.StatusTooManyRequests {
|
||||||
t.Fatalf("expected 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
t.Fatalf("expected 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -432,7 +432,7 @@ func TestHandleResponsesNonStreamPromotesThinkingToolCallsWhenTextEmpty(t *testi
|
|||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, promptcompat.DefaultToolChoicePolicy(), "")
|
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||||
if rec.Code != http.StatusOK {
|
if rec.Code != http.StatusOK {
|
||||||
t.Fatalf("expected 200 for thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
t.Fatalf("expected 200 for thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -462,7 +462,7 @@ func TestHandleResponsesNonStreamPromotesHiddenThinkingDSMLToolCallsWhenTextEmpt
|
|||||||
Mode: promptcompat.ToolChoiceRequired,
|
Mode: promptcompat.ToolChoiceRequired,
|
||||||
Allowed: map[string]struct{}{"read_file": {}},
|
Allowed: map[string]struct{}{"read_file": {}},
|
||||||
}
|
}
|
||||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, policy, "")
|
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, nil, policy, "")
|
||||||
if rec.Code != http.StatusOK {
|
if rec.Code != http.StatusOK {
|
||||||
t.Fatalf("expected 200 for hidden thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
t.Fatalf("expected 200 for hidden thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -480,6 +480,53 @@ func TestHandleResponsesNonStreamPromotesHiddenThinkingDSMLToolCallsWhenTextEmpt
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleResponsesStreamCoercesSchemaDeclaredStringArguments(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
toolsRaw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "Write",
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"content": map[string]any{"type": "string"},
|
||||||
|
"taskId": map[string]any{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sseLine := func(v string) string {
|
||||||
|
b, _ := json.Marshal(map[string]any{"p": "response/content", "v": v})
|
||||||
|
return "data: " + string(b) + "\n"
|
||||||
|
}
|
||||||
|
streamBody := sseLine(`<tool_calls><invoke name="Write">{"input":{"content":{"message":"hi"},"taskId":1}}</invoke></tool_calls>`) + "data: [DONE]\n"
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_string_protect", "deepseek-v4-flash", "prompt", false, false, []string{"Write"}, toolsRaw, promptcompat.DefaultToolChoicePolicy(), "")
|
||||||
|
|
||||||
|
payload, ok := extractSSEEventPayload(rec.Body.String(), "response.function_call_arguments.done")
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected response.function_call_arguments.done payload, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
args := map[string]any{}
|
||||||
|
if err := json.Unmarshal([]byte(asString(payload["arguments"])), &args); err != nil {
|
||||||
|
t.Fatalf("decode streamed response arguments failed: %v", err)
|
||||||
|
}
|
||||||
|
if args["content"] != `{"message":"hi"}` {
|
||||||
|
t.Fatalf("expected response content stringified by schema, got %#v", args["content"])
|
||||||
|
}
|
||||||
|
if args["taskId"] != "1" {
|
||||||
|
t.Fatalf("expected response taskId stringified by schema, got %#v", args["taskId"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||||
matched := false
|
matched := false
|
||||||
|
|||||||
@@ -70,12 +70,13 @@ func FilterIncrementalToolCallDeltasByAllowed(deltas []toolstream.ToolCallDelta,
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func FormatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string) []map[string]any {
|
func FormatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string, toolsRaw any) []map[string]any {
|
||||||
if len(calls) == 0 {
|
if len(calls) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||||
out := make([]map[string]any, 0, len(calls))
|
out := make([]map[string]any, 0, len(calls))
|
||||||
for i, c := range calls {
|
for i, c := range normalizedCalls {
|
||||||
callID := ""
|
callID := ""
|
||||||
if ids != nil {
|
if ids != nil {
|
||||||
callID = strings.TrimSpace(ids[i])
|
callID = strings.TrimSpace(ids[i])
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
func TestFormatOpenAIStreamToolCalls(t *testing.T) {
|
func TestFormatOpenAIStreamToolCalls(t *testing.T) {
|
||||||
formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{
|
formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{
|
||||||
{Name: "search", Input: map[string]any{"q": "test"}},
|
{Name: "search", Input: map[string]any{"q": "test"}},
|
||||||
})
|
}, nil)
|
||||||
if len(formatted) != 1 {
|
if len(formatted) != 1 {
|
||||||
t.Fatalf("expected 1, got %d", len(formatted))
|
t.Fatalf("expected 1, got %d", len(formatted))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,9 +7,10 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
|
func FormatOpenAIToolCalls(calls []ParsedToolCall, toolsRaw any) []map[string]any {
|
||||||
|
normalized := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||||
out := make([]map[string]any, 0, len(calls))
|
out := make([]map[string]any, 0, len(calls))
|
||||||
for _, c := range calls {
|
for _, c := range normalized {
|
||||||
args, _ := json.Marshal(c.Input)
|
args, _ := json.Marshal(c.Input)
|
||||||
out = append(out, map[string]any{
|
out = append(out, map[string]any{
|
||||||
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||||
@@ -23,9 +24,10 @@ func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any {
|
func FormatOpenAIStreamToolCalls(calls []ParsedToolCall, toolsRaw any) []map[string]any {
|
||||||
|
normalized := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||||
out := make([]map[string]any, 0, len(calls))
|
out := make([]map[string]any, 0, len(calls))
|
||||||
for i, c := range calls {
|
for i, c := range normalized {
|
||||||
args, _ := json.Marshal(c.Input)
|
args, _ := json.Marshal(c.Input)
|
||||||
out = append(out, map[string]any{
|
out = append(out, map[string]any{
|
||||||
"index": i,
|
"index": i,
|
||||||
|
|||||||
266
internal/toolcall/toolcalls_schema_normalize.go
Normal file
266
internal/toolcall/toolcalls_schema_normalize.go
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
package toolcall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NormalizeParsedToolCallsForSchemas(calls []ParsedToolCall, toolsRaw any) []ParsedToolCall {
|
||||||
|
if len(calls) == 0 {
|
||||||
|
return calls
|
||||||
|
}
|
||||||
|
schemas := buildToolSchemaIndex(toolsRaw)
|
||||||
|
if len(schemas) == 0 {
|
||||||
|
return calls
|
||||||
|
}
|
||||||
|
|
||||||
|
var changedAny bool
|
||||||
|
out := make([]ParsedToolCall, len(calls))
|
||||||
|
for i, call := range calls {
|
||||||
|
out[i] = call
|
||||||
|
schema, ok := schemas[strings.ToLower(strings.TrimSpace(call.Name))]
|
||||||
|
if !ok || call.Input == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
normalized, changed := normalizeToolValueWithSchema(call.Input, schema)
|
||||||
|
if !changed {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
changedAny = true
|
||||||
|
if input, ok := normalized.(map[string]any); ok {
|
||||||
|
out[i].Input = input
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !changedAny {
|
||||||
|
return calls
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildToolSchemaIndex(toolsRaw any) map[string]any {
|
||||||
|
tools, ok := toolsRaw.([]any)
|
||||||
|
if !ok || len(tools) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[string]any, len(tools))
|
||||||
|
for _, item := range tools {
|
||||||
|
tool, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name, schema := extractToolNameAndSchema(tool)
|
||||||
|
if name == "" || schema == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[strings.ToLower(name)] = schema
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractToolNameAndSchema pulls the tool name and parameter schema out of a
// single tool declaration.
//
// Both the flat shape ({"name", "parameters" | "input_schema"}) and the nested
// shape ({"function": {"name", "parameters" | "input_schema"}}) are accepted.
// Top-level fields win over the nested ones, and "parameters" wins over
// "input_schema" at each level. The returned name is whitespace-trimmed and is
// "" when no string name is present; the schema is nil when none is present.
func extractToolNameAndSchema(tool map[string]any) (string, any) {
	topLevelName, _ := tool["name"].(string)
	name := strings.TrimSpace(topLevelName)

	// Schema candidates in priority order; the first non-nil one is used.
	candidates := []any{tool["parameters"], tool["input_schema"]}

	if fn, nested := tool["function"].(map[string]any); nested {
		if name == "" {
			nestedName, _ := fn["name"].(string)
			name = strings.TrimSpace(nestedName)
		}
		candidates = append(candidates, fn["parameters"], fn["input_schema"])
	}

	for _, candidate := range candidates {
		if candidate != nil {
			return name, candidate
		}
	}
	return name, nil
}
|
||||||
|
|
||||||
|
func normalizeToolValueWithSchema(value any, schema any) (any, bool) {
|
||||||
|
if value == nil || schema == nil {
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
schemaMap, ok := schema.(map[string]any)
|
||||||
|
if !ok || len(schemaMap) == 0 {
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
if shouldCoerceSchemaToString(schemaMap) {
|
||||||
|
return stringifySchemaValue(value)
|
||||||
|
}
|
||||||
|
if looksLikeObjectSchema(schemaMap) {
|
||||||
|
obj, ok := value.(map[string]any)
|
||||||
|
if !ok || len(obj) == 0 {
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
properties, _ := schemaMap["properties"].(map[string]any)
|
||||||
|
additional := schemaMap["additionalProperties"]
|
||||||
|
changed := false
|
||||||
|
out := make(map[string]any, len(obj))
|
||||||
|
for key, current := range obj {
|
||||||
|
next := current
|
||||||
|
var fieldChanged bool
|
||||||
|
if propSchema, ok := properties[key]; ok {
|
||||||
|
next, fieldChanged = normalizeToolValueWithSchema(current, propSchema)
|
||||||
|
} else if additional != nil {
|
||||||
|
next, fieldChanged = normalizeToolValueWithSchema(current, additional)
|
||||||
|
}
|
||||||
|
out[key] = next
|
||||||
|
changed = changed || fieldChanged
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
return out, true
|
||||||
|
}
|
||||||
|
if looksLikeArraySchema(schemaMap) {
|
||||||
|
arr, ok := value.([]any)
|
||||||
|
if !ok || len(arr) == 0 {
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
itemsSchema := schemaMap["items"]
|
||||||
|
if itemsSchema == nil {
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
changed := false
|
||||||
|
out := make([]any, len(arr))
|
||||||
|
switch itemSchemas := itemsSchema.(type) {
|
||||||
|
case []any:
|
||||||
|
for i, item := range arr {
|
||||||
|
if i >= len(itemSchemas) {
|
||||||
|
out[i] = item
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
next, itemChanged := normalizeToolValueWithSchema(item, itemSchemas[i])
|
||||||
|
out[i] = next
|
||||||
|
changed = changed || itemChanged
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
for i, item := range arr {
|
||||||
|
next, itemChanged := normalizeToolValueWithSchema(item, itemsSchema)
|
||||||
|
out[i] = next
|
||||||
|
changed = changed || itemChanged
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
return out, true
|
||||||
|
}
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldCoerceSchemaToString(schema map[string]any) bool {
|
||||||
|
if schema == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if isStringConst(schema["const"]) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if isStringEnum(schema["enum"]) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch v := schema["type"].(type) {
|
||||||
|
case string:
|
||||||
|
return strings.EqualFold(strings.TrimSpace(v), "string")
|
||||||
|
case []any:
|
||||||
|
return isOnlyStringLikeTypes(v)
|
||||||
|
case []string:
|
||||||
|
items := make([]any, 0, len(v))
|
||||||
|
for _, item := range v {
|
||||||
|
items = append(items, item)
|
||||||
|
}
|
||||||
|
return isOnlyStringLikeTypes(items)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// looksLikeObjectSchema reports whether schema describes an object: its "type"
// is the string "object" (case-insensitive, whitespace ignored), or it has a
// map-valued "properties" entry, or it has an "additionalProperties" key with
// any value. A nil schema yields false.
func looksLikeObjectSchema(schema map[string]any) bool {
	declared, isString := schema["type"].(string)
	if isString && strings.EqualFold(strings.TrimSpace(declared), "object") {
		return true
	}
	if _, hasProperties := schema["properties"].(map[string]any); hasProperties {
		return true
	}
	// Presence of the key is enough; its value is not inspected here.
	_, hasAdditional := schema["additionalProperties"]
	return hasAdditional
}
|
||||||
|
|
||||||
|
// looksLikeArraySchema reports whether schema describes an array: its "type"
// is the string "array" (case-insensitive, whitespace ignored), or it has an
// "items" key with any value. A nil schema yields false.
func looksLikeArraySchema(schema map[string]any) bool {
	declared, isString := schema["type"].(string)
	if isString && strings.EqualFold(strings.TrimSpace(declared), "array") {
		return true
	}
	// Presence of the key is enough; its value is not inspected here.
	_, hasItems := schema["items"]
	return hasItems
}
|
||||||
|
|
||||||
|
// isOnlyStringLikeTypes reports whether a JSON-schema "type" list allows
// nothing but strings: every entry must be the string "string" or "null"
// (case-insensitive, whitespace ignored) and "string" must appear at least
// once. An empty list, a non-string entry, or any other type name yields
// false.
func isOnlyStringLikeTypes(values []any) bool {
	sawString := false
	for _, entry := range values {
		typeName, isString := entry.(string)
		if !isString {
			return false
		}
		normalized := strings.ToLower(strings.TrimSpace(typeName))
		if normalized == "string" {
			sawString = true
			continue
		}
		if normalized != "null" {
			return false
		}
	}
	return sawString
}
|
||||||
|
|
||||||
|
// isStringConst reports whether v, a schema's "const" value, is a Go string.
func isStringConst(v any) bool {
	switch v.(type) {
	case string:
		return true
	default:
		return false
	}
}
|
||||||
|
|
||||||
|
// isStringEnum reports whether v, a schema's "enum" value, is a non-empty
// []any whose elements are all Go strings.
func isStringEnum(v any) bool {
	options, isList := v.([]any)
	if !isList {
		return false
	}
	for _, option := range options {
		if _, isString := option.(string); !isString {
			return false
		}
	}
	// An empty enum carries no type information.
	return len(options) > 0
}
|
||||||
|
|
||||||
|
// stringifySchemaValue converts value into the string form used for
// parameters whose schema is string-only.
//
// Strings are returned unchanged and nil stays nil; both report false because
// nothing was rewritten. Every other value is encoded as compact JSON and
// reported as changed. If the value cannot be encoded it is returned
// untouched with false.
//
// HTML escaping is disabled on purpose. The default encoding/json behaviour
// rewrites '<', '>' and '&' as \u003c, \u003e and \u0026, which would corrupt
// markup or source code carried inside structured values that end up in
// free-text parameters.
func stringifySchemaValue(value any) (any, bool) {
	if value == nil {
		return value, false
	}
	if s, ok := value.(string); ok {
		return s, false
	}

	var sb strings.Builder
	enc := json.NewEncoder(&sb)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(value); err != nil {
		return value, false
	}
	// Encode always terminates the document with a newline; drop it so the
	// result matches what json.Marshal would produce.
	return strings.TrimSuffix(sb.String(), "\n"), true
}
|
||||||
|
|
||||||
|
// asStringValue returns v when it holds a Go string and "" for anything else,
// including nil.
func asStringValue(v any) string {
	s, _ := v.(string)
	return s
}
|
||||||
112
internal/toolcall/toolcalls_schema_normalize_test.go
Normal file
112
internal/toolcall/toolcalls_schema_normalize_test.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package toolcall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeParsedToolCallsForSchemasCoercesDeclaredStringFieldsRecursively(t *testing.T) {
|
||||||
|
toolsRaw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "TaskUpdate",
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"taskId": map[string]any{"type": "string"},
|
||||||
|
"payload": map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"content": map[string]any{"type": "string"},
|
||||||
|
"tags": map[string]any{
|
||||||
|
"type": "array",
|
||||||
|
"items": map[string]any{"type": "string"},
|
||||||
|
},
|
||||||
|
"count": map[string]any{"type": "number"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
calls := []ParsedToolCall{{
|
||||||
|
Name: "TaskUpdate",
|
||||||
|
Input: map[string]any{
|
||||||
|
"taskId": 1,
|
||||||
|
"payload": map[string]any{
|
||||||
|
"content": map[string]any{"text": "hello"},
|
||||||
|
"tags": []any{1, true, map[string]any{"k": "v"}},
|
||||||
|
"count": 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
|
||||||
|
got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Fatalf("expected one normalized call, got %#v", got)
|
||||||
|
}
|
||||||
|
if got[0].Input["taskId"] != "1" {
|
||||||
|
t.Fatalf("expected taskId coerced to string, got %#v", got[0].Input["taskId"])
|
||||||
|
}
|
||||||
|
payload, ok := got[0].Input["payload"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected payload object, got %#v", got[0].Input["payload"])
|
||||||
|
}
|
||||||
|
if payload["content"] != `{"text":"hello"}` {
|
||||||
|
t.Fatalf("expected nested content coerced to json string, got %#v", payload["content"])
|
||||||
|
}
|
||||||
|
if payload["count"] != 2 {
|
||||||
|
t.Fatalf("expected non-string count unchanged, got %#v", payload["count"])
|
||||||
|
}
|
||||||
|
tags, ok := payload["tags"].([]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected tags slice, got %#v", payload["tags"])
|
||||||
|
}
|
||||||
|
wantTags := []any{"1", "true", `{"k":"v"}`}
|
||||||
|
if !reflect.DeepEqual(tags, wantTags) {
|
||||||
|
t.Fatalf("unexpected normalized tags: got %#v want %#v", tags, wantTags)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeParsedToolCallsForSchemasSupportsDirectToolSchemaShape(t *testing.T) {
|
||||||
|
toolsRaw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"name": "Write",
|
||||||
|
"input_schema": map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"content": map[string]any{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
calls := []ParsedToolCall{{Name: "Write", Input: map[string]any{"content": []any{"a", 1}}}}
|
||||||
|
got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||||
|
if got[0].Input["content"] != `["a",1]` {
|
||||||
|
t.Fatalf("expected direct-schema content coerced to string, got %#v", got[0].Input["content"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeParsedToolCallsForSchemasLeavesAmbiguousUnionUnchanged(t *testing.T) {
|
||||||
|
toolsRaw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "TaskUpdate",
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"taskId": map[string]any{"type": []any{"string", "integer"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
calls := []ParsedToolCall{{Name: "TaskUpdate", Input: map[string]any{"taskId": 1}}}
|
||||||
|
got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||||
|
if got[0].Input["taskId"] != 1 {
|
||||||
|
t.Fatalf("expected ambiguous union to stay unchanged, got %#v", got[0].Input["taskId"])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestFormatOpenAIToolCalls(t *testing.T) {
|
func TestFormatOpenAIToolCalls(t *testing.T) {
|
||||||
formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}})
|
formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}}, nil)
|
||||||
if len(formatted) != 1 {
|
if len(formatted) != 1 {
|
||||||
t.Fatalf("expected 1, got %d", len(formatted))
|
t.Fatalf("expected 1, got %d", len(formatted))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking,
|
|||||||
}
|
}
|
||||||
if len(detected) > 0 {
|
if len(detected) > 0 {
|
||||||
finishReason = "tool_calls"
|
finishReason = "tool_calls"
|
||||||
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected)
|
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, nil)
|
||||||
messageObj["content"] = nil
|
messageObj["content"] = nil
|
||||||
}
|
}
|
||||||
promptTokens := EstimateTokens(finalPrompt)
|
promptTokens := EstimateTokens(finalPrompt)
|
||||||
|
|||||||
Reference in New Issue
Block a user