Merge pull request #342 from shern-point/fix/tool-string-schema-protection

Fix/tool string schema protection
This commit is contained in:
CJACK.
2026-04-28 16:37:44 +08:00
committed by GitHub
23 changed files with 672 additions and 80 deletions

View File

@@ -153,6 +153,7 @@ OpenAI Chat / Responses 在标准化后、current input file 之前,会默认
工具调用正例现在优先示范官方 DSML 风格:`<|DSML|tool_calls>``<|DSML|invoke name="...">``<|DSML|parameter name="...">`
兼容层仍接受旧式纯 `<tool_calls>` wrapper但提示词会优先要求模型输出官方 DSML 标签,并强调不能只输出 closing wrapper 而漏掉 opening tag。需要注意这是“兼容 DSML 外壳,内部仍以 XML 解析语义为准”,不是原生 DSML 全链路实现DSML 标签会在解析入口归一化回现有 XML 标签后继续走同一套 parser。
数组参数使用 `<item>...</item>` 子节点表示;当某个参数体只包含 item 子节点时Go / Node 解析器会把它还原成数组,避免 `questions` / `options` 这类 schema 中要求 array 的参数被误解析成 `{ "item": ... }` 对象。若模型把完整结构化 XML fragment 误包进 CDATA兼容层会在保护 `content` / `command` 等原文字段的前提下,尝试把非原文字段中的 CDATA XML fragment 还原成 object / array。不过如果 CDATA 只是单个平面的 XML/HTML 标签,例如 `<b>urgent</b>` 这种行内标记,兼容层会保留原始字符串,不会强行升成 object / array只有明显表示结构的 CDATA 片段,例如多兄弟节点、嵌套子节点或 `item` 列表,才会触发结构化恢复。
在 assistant 最终回包阶段,如果某个 tool 参数在声明 schema 中明确是 `string`,兼容层会在把解析后的 `tool_calls` / `function_call` 重新序列化成 OpenAI / Responses / Claude 可见参数前,递归把该路径上的 number / bool / object / array 统一转成字符串;其中 object / array 会压成紧凑 JSON 字符串。这个保护只对 schema 明确声明为 string 的路径生效,不会改写本来就是 `number` / `boolean` / `object` / `array` 的参数。这样可以兼容 DeepSeek 输出了结构化片段、但上游客户端工具 schema 又严格要求字符串参数的场景(例如 `content``prompt``path``taskId` 等)。
正例中的工具名只会来自当前请求实际声明的工具;如果当前请求没有足够的已知工具形态,就省略对应的单工具、多工具或嵌套示例,避免把不可用工具名写进 prompt。
对执行类工具,脚本内容必须进入执行参数本身:`Bash` / `execute_command` 使用 `command``exec_command` 使用 `cmd`;不要把脚本示范成 `path` / `content` 文件写入参数。

View File

@@ -6,12 +6,12 @@ import (
"time"
)
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string, toolsRaw any) map[string]any {
detected := toolcall.ParseAssistantToolCallsDetailed(finalText, finalThinking, toolNames)
return BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls)
return BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
}
func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall) map[string]any {
func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall, toolsRaw any) map[string]any {
finishReason := "stop"
messageObj := map[string]any{"role": "assistant", "content": finalText}
if strings.TrimSpace(finalThinking) != "" {
@@ -19,7 +19,7 @@ func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThi
}
if len(detected) > 0 {
finishReason = "tool_calls"
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected)
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, toolsRaw)
messageObj["content"] = nil
}

View File

@@ -9,19 +9,19 @@ import (
"github.com/google/uuid"
)
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string, toolsRaw any) map[string]any {
// Strict mode: only standalone, structured tool-call payloads are treated
// as executable tool calls.
detected := toolcall.ParseAssistantToolCallsDetailed(finalText, finalThinking, toolNames)
return BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText, detected.Calls)
return BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
}
func BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall) map[string]any {
func BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall, toolsRaw any) map[string]any {
exposedOutputText := finalText
output := make([]any, 0, 2)
if len(detected) > 0 {
exposedOutputText = ""
output = append(output, toResponsesFunctionCallItems(detected)...)
output = append(output, toResponsesFunctionCallItems(detected, toolsRaw)...)
} else {
content := make([]any, 0, 2)
if finalThinking != "" {
@@ -74,12 +74,13 @@ func BuildResponseObjectFromItems(responseID, model, finalPrompt, finalThinking,
}
}
func toResponsesFunctionCallItems(toolCalls []toolcall.ParsedToolCall) []any {
func toResponsesFunctionCallItems(toolCalls []toolcall.ParsedToolCall, toolsRaw any) []any {
if len(toolCalls) == 0 {
return nil
}
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(toolCalls, toolsRaw)
out := make([]any, 0, len(toolCalls))
for _, tc := range toolCalls {
for _, tc := range normalizedCalls {
if strings.TrimSpace(tc.Name) == "" {
continue
}

View File

@@ -1,8 +1,11 @@
package openai
import (
"encoding/json"
"strings"
"testing"
"ds2api/internal/toolcall"
)
func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
@@ -13,6 +16,7 @@ func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
"",
"```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```",
[]string{"search"},
nil,
)
outputText, _ := obj["output_text"].(string)
@@ -42,6 +46,7 @@ func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
"internal thinking content",
"",
nil,
nil,
)
outputText, _ := obj["output_text"].(string)
@@ -75,6 +80,7 @@ func TestBuildResponseObjectPromotesToolCallFromThinkingWhenTextEmpty(t *testing
`<tool_calls><invoke name="search"><parameter name="q">from-thinking</parameter></invoke></tool_calls>`,
"",
[]string{"search"},
nil,
)
output, _ := obj["output"].([]any)
@@ -86,3 +92,88 @@ func TestBuildResponseObjectPromotesToolCallFromThinkingWhenTextEmpty(t *testing
t.Fatalf("expected function_call output, got %#v", first["type"])
}
}
func TestBuildChatCompletionWithToolCallsCoercesSchemaDeclaredStringArguments(t *testing.T) {
toolsRaw := []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "Write",
"parameters": map[string]any{
"type": "object",
"properties": map[string]any{
"content": map[string]any{"type": "string"},
"taskId": map[string]any{"type": "string"},
},
},
},
},
}
obj := BuildChatCompletionWithToolCalls(
"chat_test",
"gpt-4o",
"prompt",
"",
"",
[]toolcall.ParsedToolCall{{
Name: "Write",
Input: map[string]any{
"content": map[string]any{"message": "hi"},
"taskId": 1,
},
}},
toolsRaw,
)
choices, _ := obj["choices"].([]map[string]any)
message, _ := choices[0]["message"].(map[string]any)
toolCalls, _ := message["tool_calls"].([]map[string]any)
fn, _ := toolCalls[0]["function"].(map[string]any)
args := map[string]any{}
if err := json.Unmarshal([]byte(fn["arguments"].(string)), &args); err != nil {
t.Fatalf("decode arguments failed: %v", err)
}
if args["content"] != `{"message":"hi"}` {
t.Fatalf("expected content stringified by schema, got %#v", args["content"])
}
if args["taskId"] != "1" {
t.Fatalf("expected taskId stringified by schema, got %#v", args["taskId"])
}
}
func TestBuildResponseObjectWithToolCallsCoercesSchemaDeclaredStringArguments(t *testing.T) {
toolsRaw := []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "Write",
"parameters": map[string]any{
"type": "object",
"properties": map[string]any{
"content": map[string]any{"type": "string"},
},
},
},
},
}
obj := BuildResponseObjectWithToolCalls(
"resp_test",
"gpt-4o",
"prompt",
"",
"",
[]toolcall.ParsedToolCall{{
Name: "Write",
Input: map[string]any{"content": []any{"a", 1}},
}},
toolsRaw,
)
output, _ := obj["output"].([]any)
first, _ := output[0].(map[string]any)
args := map[string]any{}
if err := json.Unmarshal([]byte(first["arguments"].(string)), &args); err != nil {
t.Fatalf("decode response arguments failed: %v", err)
}
if args["content"] != `["a",1]` {
t.Fatalf("expected response content stringified by schema, got %#v", args["content"])
}
}

View File

@@ -194,7 +194,7 @@ func TestHandleStreamContextCancelledMarksHistoryStopped(t *testing.T) {
rec := httptest.NewRecorder()
resp := makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, `data: [DONE]`)
h.handleStream(rec, req, resp, "cid-stop", "deepseek-v4-flash", "prompt", false, false, nil, session)
h.handleStream(rec, req, resp, "cid-stop", "deepseek-v4-flash", "prompt", false, false, nil, nil, session)
snapshot, err := historyStore.Snapshot()
if err != nil {

View File

@@ -21,6 +21,7 @@ type chatStreamRuntime struct {
model string
finalPrompt string
toolNames []string
toolsRaw any
thinkingEnabled bool
searchEnabled bool
@@ -61,6 +62,7 @@ func newChatStreamRuntime(
searchEnabled bool,
stripReferenceMarkers bool,
toolNames []string,
toolsRaw any,
bufferToolContent bool,
emitEarlyToolDeltas bool,
) *chatStreamRuntime {
@@ -73,6 +75,7 @@ func newChatStreamRuntime(
model: model,
finalPrompt: finalPrompt,
toolNames: toolNames,
toolsRaw: toolsRaw,
thinkingEnabled: thinkingEnabled,
searchEnabled: searchEnabled,
stripReferenceMarkers: stripReferenceMarkers,
@@ -142,7 +145,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
finishReason = "tool_calls"
delta := map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs),
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs, s.toolsRaw),
}
if !s.firstChunkSent {
delta["role"] = "assistant"
@@ -164,7 +167,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
tcDelta := map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
}
if !s.firstChunkSent {
tcDelta["role"] = "assistant"
@@ -320,7 +323,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
tcDelta := map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
}
if !s.firstChunkSent {
tcDelta["role"] = "assistant"

View File

@@ -26,14 +26,14 @@ type chatNonStreamResult struct {
responseMessageID int
}
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
attempts := 0
currentResp := resp
usagePrompt := finalPrompt
accumulatedThinking := ""
accumulatedToolDetectionThinking := ""
for {
result, ok := h.collectChatNonStreamAttempt(w, currentResp, completionID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames)
result, ok := h.collectChatNonStreamAttempt(w, currentResp, completionID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw)
if !ok {
return
}
@@ -43,7 +43,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
result.toolDetectionThinking = accumulatedToolDetectionThinking
detected := detectAssistantToolCalls(result.text, result.thinking, result.toolDetectionThinking, toolNames)
result.detectedCalls = len(detected.Calls)
result.body = openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, result.thinking, result.text, detected.Calls)
result.body = openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, result.thinking, result.text, detected.Calls, toolsRaw)
result.finishReason = chatFinishReason(result.body)
if !shouldRetryChatNonStream(result, attempts) {
h.finishChatNonStreamResult(w, result, attempts, usagePrompt, historySession)
@@ -72,7 +72,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
}
}
func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.Response, completionID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) (chatNonStreamResult, bool) {
func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.Response, completionID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (chatNonStreamResult, bool) {
if resp.StatusCode != http.StatusOK {
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
@@ -88,7 +88,7 @@ func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.
finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks)
}
detected := detectAssistantToolCalls(finalText, finalThinking, finalToolDetectionThinking, toolNames)
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls)
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls, toolsRaw)
return chatNonStreamResult{
thinking: finalThinking,
toolDetectionThinking: finalToolDetectionThinking,
@@ -139,8 +139,8 @@ func shouldRetryChatNonStream(result chatNonStreamResult, attempts int) bool {
strings.TrimSpace(result.text) == ""
}
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, historySession)
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw, historySession)
if !ok {
return
}
@@ -182,7 +182,7 @@ func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request,
}
}
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
if resp.StatusCode != http.StatusOK {
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
@@ -207,7 +207,7 @@ func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Res
}
streamRuntime := newChatStreamRuntime(
w, rc, canFlush, completionID, time.Now().Unix(), model, finalPrompt,
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames,
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw,
len(toolNames) > 0, h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
)
return streamRuntime, initialType, true

View File

@@ -144,8 +144,8 @@ func filterIncrementalToolCallDeltasByAllowed(deltas []toolstream.ToolCallDelta,
return shared.FilterIncrementalToolCallDeltasByAllowed(deltas, seenNames)
}
func formatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string) []map[string]any {
return shared.FormatFinalStreamToolCallsWithStableIDs(calls, ids)
func formatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string, toolsRaw any) []map[string]any {
return shared.FormatFinalStreamToolCallsWithStableIDs(calls, ids, toolsRaw)
}
func detectAssistantToolCalls(text, exposedThinking, detectionThinking string, toolNames []string) toolcall.ToolCallParseResult {

View File

@@ -109,10 +109,10 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
return
}
if stdReq.Stream {
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
return
}
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
}
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
@@ -148,7 +148,7 @@ func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAu
}
}
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
if resp.StatusCode != http.StatusOK {
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
@@ -176,7 +176,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
writeUpstreamEmptyOutputError(w, finalText, finalThinking, result.ContentFilter)
return
}
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls)
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
finishReason := "stop"
if choices, ok := respBody["choices"].([]map[string]any); ok && len(choices) > 0 {
if fr, _ := choices[0]["finish_reason"].(string); strings.TrimSpace(fr) != "" {
@@ -189,7 +189,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
writeJSON(w, http.StatusOK, respBody)
}
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -230,6 +230,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
searchEnabled,
stripReferenceMarkers,
toolNames,
toolsRaw,
bufferToolContent,
emitEarlyToolDeltas,
)

View File

@@ -93,7 +93,7 @@ func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, resp, "cid-empty", "deepseek-v4-flash", "prompt", false, false, nil, nil)
h.handleNonStream(rec, resp, "cid-empty", "deepseek-v4-flash", "prompt", false, false, nil, nil, nil)
if rec.Code != http.StatusTooManyRequests {
t.Fatalf("expected status 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
}
@@ -112,7 +112,7 @@ func TestHandleNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutp
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, resp, "cid-empty-filtered", "deepseek-v4-flash", "prompt", false, false, nil, nil)
h.handleNonStream(rec, resp, "cid-empty-filtered", "deepseek-v4-flash", "prompt", false, false, nil, nil, nil)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected status 400 for filtered upstream output, got %d body=%s", rec.Code, rec.Body.String())
}
@@ -131,7 +131,7 @@ func TestHandleNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testing.T) {
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, resp, "cid-thinking-only", "deepseek-v4-pro", "prompt", true, false, nil, nil)
h.handleNonStream(rec, resp, "cid-thinking-only", "deepseek-v4-pro", "prompt", true, false, nil, nil, nil)
if rec.Code != http.StatusTooManyRequests {
t.Fatalf("expected status 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
}
@@ -150,7 +150,7 @@ func TestHandleNonStreamPromotesThinkingToolCallsWhenTextEmpty(t *testing.T) {
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, resp, "cid-thinking-tool", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil)
h.handleNonStream(rec, resp, "cid-thinking-tool", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil, nil)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200 for thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
}
@@ -181,7 +181,7 @@ func TestHandleNonStreamPromotesHiddenThinkingDSMLToolCallsWhenTextEmpty(t *test
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, resp, "cid-hidden-thinking-tool", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil)
h.handleNonStream(rec, resp, "cid-hidden-thinking-tool", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil, nil)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200 for hidden thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
}
@@ -211,7 +211,7 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid6", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil)
h.handleStream(rec, req, resp, "cid6", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil, nil)
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
@@ -248,7 +248,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid10", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil)
h.handleStream(rec, req, resp, "cid10", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil, nil)
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
@@ -282,7 +282,7 @@ func TestHandleStreamPromotesThinkingToolCallsOnFinalizeWithoutMidstreamIntercep
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid-thinking-stream", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil)
h.handleStream(rec, req, resp, "cid-thinking-stream", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil, nil)
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
@@ -319,7 +319,7 @@ func TestHandleStreamPromotesHiddenThinkingDSMLToolCallsOnFinalize(t *testing.T)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid-hidden-thinking-stream", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil)
h.handleStream(rec, req, resp, "cid-hidden-thinking-stream", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil, nil)
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
@@ -353,7 +353,7 @@ func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid-multi", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, nil)
h.handleStream(rec, req, resp, "cid-multi", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, nil, nil)
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
@@ -390,3 +390,64 @@ func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing
t.Fatalf("expected distinct tool call ids across blocks, got %#v body=%s", ids, rec.Body.String())
}
}
func TestHandleStreamCoercesSchemaDeclaredStringArgumentsOnFinalize(t *testing.T) {
h := &Handler{}
line := func(v string) string {
b, _ := json.Marshal(map[string]any{"p": "response/content", "v": v})
return "data: " + string(b)
}
resp := makeSSEHTTPResponse(
line(`<tool_calls><invoke name="Write">{"input":{"content":{"message":"hi"},"taskId":1}}</invoke></tool_calls>`),
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
toolsRaw := []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "Write",
"parameters": map[string]any{
"type": "object",
"properties": map[string]any{
"content": map[string]any{"type": "string"},
"taskId": map[string]any{"type": "string"},
},
},
},
},
}
h.handleStream(rec, req, resp, "cid-string-protect", "deepseek-v4-flash", "prompt", false, false, []string{"Write"}, toolsRaw, nil)
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
toolCalls, _ := delta["tool_calls"].([]any)
if len(toolCalls) == 0 {
continue
}
call, _ := toolCalls[0].(map[string]any)
fn, _ := call["function"].(map[string]any)
args := map[string]any{}
if err := json.Unmarshal([]byte(asString(fn["arguments"])), &args); err != nil {
t.Fatalf("decode streamed tool arguments failed: %v", err)
}
if args["content"] != `{"message":"hi"}` {
t.Fatalf("expected streamed content stringified by schema, got %#v", args["content"])
}
if args["taskId"] != "1" {
t.Fatalf("expected streamed taskId stringified by schema, got %#v", args["taskId"])
}
return
}
}
t.Fatalf("expected at least one streamed tool call delta, body=%s", rec.Body.String())
}

View File

@@ -27,14 +27,14 @@ type responsesNonStreamResult struct {
responseMessageID int
}
func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
attempts := 0
currentResp := resp
usagePrompt := finalPrompt
accumulatedThinking := ""
accumulatedToolDetectionThinking := ""
for {
result, ok := h.collectResponsesNonStreamAttempt(w, currentResp, responseID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames)
result, ok := h.collectResponsesNonStreamAttempt(w, currentResp, responseID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw)
if !ok {
return
}
@@ -43,7 +43,7 @@ func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx c
result.thinking = accumulatedThinking
result.toolDetectionThinking = accumulatedToolDetectionThinking
result.parsed = detectAssistantToolCalls(result.text, result.thinking, result.toolDetectionThinking, toolNames)
result.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, result.thinking, result.text, result.parsed.Calls)
result.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, result.thinking, result.text, result.parsed.Calls, toolsRaw)
if !shouldRetryResponsesNonStream(result, attempts) {
h.finishResponsesNonStreamResult(w, result, attempts, owner, responseID, toolChoice, traceID)
@@ -68,7 +68,7 @@ func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx c
}
}
func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *http.Response, responseID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) (responsesNonStreamResult, bool) {
func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *http.Response, responseID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (responsesNonStreamResult, bool) {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -84,7 +84,7 @@ func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *
sanitizedText = replaceCitationMarkersWithLinks(sanitizedText, result.CitationLinks)
}
textParsed := detectAssistantToolCalls(sanitizedText, sanitizedThinking, toolDetectionThinking, toolNames)
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, sanitizedThinking, sanitizedText, textParsed.Calls)
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, sanitizedThinking, sanitizedText, textParsed.Calls, toolsRaw)
return responsesNonStreamResult{
thinking: sanitizedThinking,
toolDetectionThinking: toolDetectionThinking,
@@ -123,8 +123,8 @@ func shouldRetryResponsesNonStream(result responsesNonStreamResult, attempts int
strings.TrimSpace(result.text) == ""
}
func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, resp, owner, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolChoice, traceID)
func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, resp, owner, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, traceID)
if !ok {
return
}
@@ -165,7 +165,7 @@ func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.
}
}
func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) (*responsesStreamRuntime, string, bool) {
func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) (*responsesStreamRuntime, string, bool) {
if resp.StatusCode != http.StatusOK {
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
@@ -184,7 +184,7 @@ func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, resp *htt
}
streamRuntime := newResponsesStreamRuntime(
w, rc, canFlush, responseID, model, finalPrompt, thinkingEnabled, searchEnabled,
h.compatStripReferenceMarkers(), toolNames, len(toolNames) > 0,
h.compatStripReferenceMarkers(), toolNames, toolsRaw, len(toolNames) > 0,
h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
toolChoice, traceID, func(obj map[string]any) {
h.getResponseStore().put(owner, responseID, obj)

View File

@@ -115,13 +115,13 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if stdReq.Stream {
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
return
}
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
}
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -148,12 +148,12 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
return
}
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, sanitizedThinking, sanitizedText, textParsed.Calls)
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, sanitizedThinking, sanitizedText, textParsed.Calls, toolsRaw)
h.getResponseStore().put(owner, responseID, responseObj)
writeJSON(w, http.StatusOK, responseObj)
}
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -186,6 +186,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
searchEnabled,
stripReferenceMarkers,
toolNames,
toolsRaw,
bufferToolContent,
emitEarlyToolDeltas,
toolChoice,

View File

@@ -22,6 +22,7 @@ type responsesStreamRuntime struct {
model string
finalPrompt string
toolNames []string
toolsRaw any
traceID string
toolChoice promptcompat.ToolChoicePolicy
@@ -72,6 +73,7 @@ func newResponsesStreamRuntime(
searchEnabled bool,
stripReferenceMarkers bool,
toolNames []string,
toolsRaw any,
bufferToolContent bool,
emitEarlyToolDeltas bool,
toolChoice promptcompat.ToolChoicePolicy,
@@ -89,6 +91,7 @@ func newResponsesStreamRuntime(
searchEnabled: searchEnabled,
stripReferenceMarkers: stripReferenceMarkers,
toolNames: toolNames,
toolsRaw: toolsRaw,
bufferToolContent: bufferToolContent,
emitEarlyToolDeltas: emitEarlyToolDeltas,
streamToolCallIDs: map[int]string{},

View File

@@ -220,7 +220,8 @@ func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolstream
}
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []toolcall.ParsedToolCall) {
for idx, tc := range calls {
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, s.toolsRaw)
for idx, tc := range normalizedCalls {
if strings.TrimSpace(tc.Name) == "" {
continue
}

View File

@@ -109,7 +109,8 @@ func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, fin
}
}
for idx, tc := range calls {
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, s.toolsRaw)
for idx, tc := range normalizedCalls {
if strings.TrimSpace(tc.Name) == "" {
continue
}

View File

@@ -27,7 +27,7 @@ func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T)
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.reasoning.delta") {
@@ -57,7 +57,7 @@ func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testin
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.output_text.done") {
t.Fatalf("expected response.output_text.done payload, body=%s", body)
@@ -91,7 +91,7 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
deltaPayload, ok := extractSSEEventPayload(body, "response.output_text.delta")
@@ -130,7 +130,7 @@ func TestHandleResponsesStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, promptcompat.DefaultToolChoicePolicy(), "")
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, nil, promptcompat.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
doneEvents := extractSSEEventPayloads(body, "response.function_call_arguments.done")
@@ -183,7 +183,7 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
Mode: promptcompat.ToolChoiceRequired,
Allowed: map[string]struct{}{"read_file": {}},
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, policy, "")
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, nil, policy, "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.failed") {
@@ -213,7 +213,7 @@ func TestHandleResponsesStreamFailsWhenUpstreamHasOnlyThinking(t *testing.T) {
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.failed") {
@@ -251,7 +251,7 @@ func TestHandleResponsesStreamPromotesThinkingToolCallsOnFinalizeWithoutMidstrea
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, promptcompat.DefaultToolChoicePolicy(), "")
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, nil, promptcompat.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.reasoning.delta") {
@@ -288,7 +288,7 @@ func TestHandleResponsesStreamPromotesHiddenThinkingDSMLToolCallsOnFinalize(t *t
Mode: promptcompat.ToolChoiceRequired,
Allowed: map[string]struct{}{"read_file": {}},
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, policy, "")
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, nil, policy, "")
body := rec.Body.String()
if strings.Contains(body, "event: response.reasoning.delta") {
@@ -317,7 +317,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
Allowed: map[string]struct{}{"read_file": {}},
}
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, policy, "")
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, nil, policy, "")
if rec.Code != http.StatusUnprocessableEntity {
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
}
@@ -344,7 +344,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayloadWhe
Allowed: map[string]struct{}{"read_file": {}},
}
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", true, false, []string{"read_file"}, policy, "")
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", true, false, []string{"read_file"}, nil, policy, "")
if rec.Code != http.StatusUnprocessableEntity {
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
}
@@ -366,7 +366,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T)
)),
}
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
if rec.Code != http.StatusTooManyRequests {
t.Fatalf("expected 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
}
@@ -388,7 +388,7 @@ func TestHandleResponsesNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWi
)),
}
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for filtered empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
}
@@ -410,7 +410,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testin
)),
}
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
if rec.Code != http.StatusTooManyRequests {
t.Fatalf("expected 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
}
@@ -432,7 +432,7 @@ func TestHandleResponsesNonStreamPromotesThinkingToolCallsWhenTextEmpty(t *testi
)),
}
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, promptcompat.DefaultToolChoicePolicy(), "")
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, nil, promptcompat.DefaultToolChoicePolicy(), "")
if rec.Code != http.StatusOK {
t.Fatalf("expected 200 for thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
}
@@ -462,7 +462,7 @@ func TestHandleResponsesNonStreamPromotesHiddenThinkingDSMLToolCallsWhenTextEmpt
Mode: promptcompat.ToolChoiceRequired,
Allowed: map[string]struct{}{"read_file": {}},
}
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, policy, "")
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, nil, policy, "")
if rec.Code != http.StatusOK {
t.Fatalf("expected 200 for hidden thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
}
@@ -480,6 +480,53 @@ func TestHandleResponsesNonStreamPromotesHiddenThinkingDSMLToolCallsWhenTextEmpt
}
}
func TestHandleResponsesStreamCoercesSchemaDeclaredStringArguments(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
toolsRaw := []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "Write",
"parameters": map[string]any{
"type": "object",
"properties": map[string]any{
"content": map[string]any{"type": "string"},
"taskId": map[string]any{"type": "string"},
},
},
},
},
}
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{"p": "response/content", "v": v})
return "data: " + string(b) + "\n"
}
streamBody := sseLine(`<tool_calls><invoke name="Write">{"input":{"content":{"message":"hi"},"taskId":1}}</invoke></tool_calls>`) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_string_protect", "deepseek-v4-flash", "prompt", false, false, []string{"Write"}, toolsRaw, promptcompat.DefaultToolChoicePolicy(), "")
payload, ok := extractSSEEventPayload(rec.Body.String(), "response.function_call_arguments.done")
if !ok {
t.Fatalf("expected response.function_call_arguments.done payload, body=%s", rec.Body.String())
}
args := map[string]any{}
if err := json.Unmarshal([]byte(asString(payload["arguments"])), &args); err != nil {
t.Fatalf("decode streamed response arguments failed: %v", err)
}
if args["content"] != `{"message":"hi"}` {
t.Fatalf("expected response content stringified by schema, got %#v", args["content"])
}
if args["taskId"] != "1" {
t.Fatalf("expected response taskId stringified by schema, got %#v", args["taskId"])
}
}
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
scanner := bufio.NewScanner(strings.NewReader(body))
matched := false

View File

@@ -70,12 +70,13 @@ func FilterIncrementalToolCallDeltasByAllowed(deltas []toolstream.ToolCallDelta,
return out
}
func FormatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string) []map[string]any {
func FormatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string, toolsRaw any) []map[string]any {
if len(calls) == 0 {
return nil
}
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
out := make([]map[string]any, 0, len(calls))
for i, c := range calls {
for i, c := range normalizedCalls {
callID := ""
if ids != nil {
callID = strings.TrimSpace(ids[i])

View File

@@ -9,7 +9,7 @@ import (
func TestFormatOpenAIStreamToolCalls(t *testing.T) {
formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{
{Name: "search", Input: map[string]any{"q": "test"}},
})
}, nil)
if len(formatted) != 1 {
t.Fatalf("expected 1, got %d", len(formatted))
}

View File

@@ -7,9 +7,10 @@ import (
"github.com/google/uuid"
)
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
func FormatOpenAIToolCalls(calls []ParsedToolCall, toolsRaw any) []map[string]any {
normalized := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
out := make([]map[string]any, 0, len(calls))
for _, c := range calls {
for _, c := range normalized {
args, _ := json.Marshal(c.Input)
out = append(out, map[string]any{
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
@@ -23,9 +24,10 @@ func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
return out
}
func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any {
func FormatOpenAIStreamToolCalls(calls []ParsedToolCall, toolsRaw any) []map[string]any {
normalized := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
out := make([]map[string]any, 0, len(calls))
for i, c := range calls {
for i, c := range normalized {
args, _ := json.Marshal(c.Input)
out = append(out, map[string]any{
"index": i,

View File

@@ -0,0 +1,266 @@
package toolcall
import (
"encoding/json"
"strings"
)
func NormalizeParsedToolCallsForSchemas(calls []ParsedToolCall, toolsRaw any) []ParsedToolCall {
if len(calls) == 0 {
return calls
}
schemas := buildToolSchemaIndex(toolsRaw)
if len(schemas) == 0 {
return calls
}
var changedAny bool
out := make([]ParsedToolCall, len(calls))
for i, call := range calls {
out[i] = call
schema, ok := schemas[strings.ToLower(strings.TrimSpace(call.Name))]
if !ok || call.Input == nil {
continue
}
normalized, changed := normalizeToolValueWithSchema(call.Input, schema)
if !changed {
continue
}
changedAny = true
if input, ok := normalized.(map[string]any); ok {
out[i].Input = input
}
}
if !changedAny {
return calls
}
return out
}
func buildToolSchemaIndex(toolsRaw any) map[string]any {
tools, ok := toolsRaw.([]any)
if !ok || len(tools) == 0 {
return nil
}
out := make(map[string]any, len(tools))
for _, item := range tools {
tool, ok := item.(map[string]any)
if !ok {
continue
}
name, schema := extractToolNameAndSchema(tool)
if name == "" || schema == nil {
continue
}
out[strings.ToLower(name)] = schema
}
if len(out) == 0 {
return nil
}
return out
}
func extractToolNameAndSchema(tool map[string]any) (string, any) {
name := strings.TrimSpace(asStringValue(tool["name"]))
schema := tool["parameters"]
if schema == nil {
schema = tool["input_schema"]
}
if fn, ok := tool["function"].(map[string]any); ok {
if name == "" {
name = strings.TrimSpace(asStringValue(fn["name"]))
}
if schema == nil {
schema = fn["parameters"]
}
if schema == nil {
schema = fn["input_schema"]
}
}
return name, schema
}
func normalizeToolValueWithSchema(value any, schema any) (any, bool) {
if value == nil || schema == nil {
return value, false
}
schemaMap, ok := schema.(map[string]any)
if !ok || len(schemaMap) == 0 {
return value, false
}
if shouldCoerceSchemaToString(schemaMap) {
return stringifySchemaValue(value)
}
if looksLikeObjectSchema(schemaMap) {
obj, ok := value.(map[string]any)
if !ok || len(obj) == 0 {
return value, false
}
properties, _ := schemaMap["properties"].(map[string]any)
additional := schemaMap["additionalProperties"]
changed := false
out := make(map[string]any, len(obj))
for key, current := range obj {
next := current
var fieldChanged bool
if propSchema, ok := properties[key]; ok {
next, fieldChanged = normalizeToolValueWithSchema(current, propSchema)
} else if additional != nil {
next, fieldChanged = normalizeToolValueWithSchema(current, additional)
}
out[key] = next
changed = changed || fieldChanged
}
if !changed {
return value, false
}
return out, true
}
if looksLikeArraySchema(schemaMap) {
arr, ok := value.([]any)
if !ok || len(arr) == 0 {
return value, false
}
itemsSchema := schemaMap["items"]
if itemsSchema == nil {
return value, false
}
changed := false
out := make([]any, len(arr))
switch itemSchemas := itemsSchema.(type) {
case []any:
for i, item := range arr {
if i >= len(itemSchemas) {
out[i] = item
continue
}
next, itemChanged := normalizeToolValueWithSchema(item, itemSchemas[i])
out[i] = next
changed = changed || itemChanged
}
default:
for i, item := range arr {
next, itemChanged := normalizeToolValueWithSchema(item, itemsSchema)
out[i] = next
changed = changed || itemChanged
}
}
if !changed {
return value, false
}
return out, true
}
return value, false
}
func shouldCoerceSchemaToString(schema map[string]any) bool {
if schema == nil {
return false
}
if isStringConst(schema["const"]) {
return true
}
if isStringEnum(schema["enum"]) {
return true
}
switch v := schema["type"].(type) {
case string:
return strings.EqualFold(strings.TrimSpace(v), "string")
case []any:
return isOnlyStringLikeTypes(v)
case []string:
items := make([]any, 0, len(v))
for _, item := range v {
items = append(items, item)
}
return isOnlyStringLikeTypes(items)
default:
return false
}
}
func looksLikeObjectSchema(schema map[string]any) bool {
if schema == nil {
return false
}
if typ, ok := schema["type"].(string); ok && strings.EqualFold(strings.TrimSpace(typ), "object") {
return true
}
if _, ok := schema["properties"].(map[string]any); ok {
return true
}
_, hasAdditional := schema["additionalProperties"]
return hasAdditional
}
func looksLikeArraySchema(schema map[string]any) bool {
if schema == nil {
return false
}
if typ, ok := schema["type"].(string); ok && strings.EqualFold(strings.TrimSpace(typ), "array") {
return true
}
_, hasItems := schema["items"]
return hasItems
}
func isOnlyStringLikeTypes(values []any) bool {
if len(values) == 0 {
return false
}
hasString := false
for _, item := range values {
typ, ok := item.(string)
if !ok {
return false
}
switch strings.ToLower(strings.TrimSpace(typ)) {
case "string":
hasString = true
case "null":
continue
default:
return false
}
}
return hasString
}
func isStringConst(v any) bool {
_, ok := v.(string)
return ok
}
func isStringEnum(v any) bool {
values, ok := v.([]any)
if !ok || len(values) == 0 {
return false
}
for _, item := range values {
if _, ok := item.(string); !ok {
return false
}
}
return true
}
func stringifySchemaValue(value any) (any, bool) {
if value == nil {
return value, false
}
if s, ok := value.(string); ok {
return s, false
}
b, err := json.Marshal(value)
if err != nil {
return value, false
}
return string(b), true
}
func asStringValue(v any) string {
if s, ok := v.(string); ok {
return s
}
return ""
}

View File

@@ -0,0 +1,112 @@
package toolcall
import (
"reflect"
"testing"
)
func TestNormalizeParsedToolCallsForSchemasCoercesDeclaredStringFieldsRecursively(t *testing.T) {
toolsRaw := []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "TaskUpdate",
"parameters": map[string]any{
"type": "object",
"properties": map[string]any{
"taskId": map[string]any{"type": "string"},
"payload": map[string]any{
"type": "object",
"properties": map[string]any{
"content": map[string]any{"type": "string"},
"tags": map[string]any{
"type": "array",
"items": map[string]any{"type": "string"},
},
"count": map[string]any{"type": "number"},
},
},
},
},
},
},
}
calls := []ParsedToolCall{{
Name: "TaskUpdate",
Input: map[string]any{
"taskId": 1,
"payload": map[string]any{
"content": map[string]any{"text": "hello"},
"tags": []any{1, true, map[string]any{"k": "v"}},
"count": 2,
},
},
}}
got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
if len(got) != 1 {
t.Fatalf("expected one normalized call, got %#v", got)
}
if got[0].Input["taskId"] != "1" {
t.Fatalf("expected taskId coerced to string, got %#v", got[0].Input["taskId"])
}
payload, ok := got[0].Input["payload"].(map[string]any)
if !ok {
t.Fatalf("expected payload object, got %#v", got[0].Input["payload"])
}
if payload["content"] != `{"text":"hello"}` {
t.Fatalf("expected nested content coerced to json string, got %#v", payload["content"])
}
if payload["count"] != 2 {
t.Fatalf("expected non-string count unchanged, got %#v", payload["count"])
}
tags, ok := payload["tags"].([]any)
if !ok {
t.Fatalf("expected tags slice, got %#v", payload["tags"])
}
wantTags := []any{"1", "true", `{"k":"v"}`}
if !reflect.DeepEqual(tags, wantTags) {
t.Fatalf("unexpected normalized tags: got %#v want %#v", tags, wantTags)
}
}
func TestNormalizeParsedToolCallsForSchemasSupportsDirectToolSchemaShape(t *testing.T) {
toolsRaw := []any{
map[string]any{
"name": "Write",
"input_schema": map[string]any{
"type": "object",
"properties": map[string]any{
"content": map[string]any{"type": "string"},
},
},
},
}
calls := []ParsedToolCall{{Name: "Write", Input: map[string]any{"content": []any{"a", 1}}}}
got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
if got[0].Input["content"] != `["a",1]` {
t.Fatalf("expected direct-schema content coerced to string, got %#v", got[0].Input["content"])
}
}
func TestNormalizeParsedToolCallsForSchemasLeavesAmbiguousUnionUnchanged(t *testing.T) {
toolsRaw := []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "TaskUpdate",
"parameters": map[string]any{
"type": "object",
"properties": map[string]any{
"taskId": map[string]any{"type": []any{"string", "integer"}},
},
},
},
},
}
calls := []ParsedToolCall{{Name: "TaskUpdate", Input: map[string]any{"taskId": 1}}}
got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
if got[0].Input["taskId"] != 1 {
t.Fatalf("expected ambiguous union to stay unchanged, got %#v", got[0].Input["taskId"])
}
}

View File

@@ -6,7 +6,7 @@ import (
)
func TestFormatOpenAIToolCalls(t *testing.T) {
formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}})
formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}}, nil)
if len(formatted) != 1 {
t.Fatalf("expected 1, got %d", len(formatted))
}

View File

@@ -20,7 +20,7 @@ func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking,
}
if len(detected) > 0 {
finishReason = "tool_calls"
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected)
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, nil)
messageObj["content"] = nil
}
promptTokens := EstimateTokens(finalPrompt)