mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-04 00:15:28 +08:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
27eb73d48b | ||
|
|
685b5011e4 | ||
|
|
15e9eb3639 | ||
|
|
f18e6b9b11 | ||
|
|
40ebc8e942 | ||
|
|
fa3e6d040d | ||
|
|
458e4469e5 | ||
|
|
72c8e7e9f9 | ||
|
|
b9c8e90d98 | ||
|
|
36fcba1280 | ||
|
|
801b5abce3 | ||
|
|
206c3d5479 | ||
|
|
b2903c35ed |
@@ -153,6 +153,7 @@ OpenAI Chat / Responses 在标准化后、current input file 之前,会默认
|
||||
工具调用正例现在优先示范官方 DSML 风格:`<|DSML|tool_calls>` → `<|DSML|invoke name="...">` → `<|DSML|parameter name="...">`。
|
||||
兼容层仍接受旧式纯 `<tool_calls>` wrapper,但提示词会优先要求模型输出官方 DSML 标签,并强调不能只输出 closing wrapper 而漏掉 opening tag。需要注意:这是“兼容 DSML 外壳,内部仍以 XML 解析语义为准”,不是原生 DSML 全链路实现;DSML 标签会在解析入口归一化回现有 XML 标签后继续走同一套 parser。
|
||||
数组参数使用 `<item>...</item>` 子节点表示;当某个参数体只包含 item 子节点时,Go / Node 解析器会把它还原成数组,避免 `questions` / `options` 这类 schema 中要求 array 的参数被误解析成 `{ "item": ... }` 对象。若模型把完整结构化 XML fragment 误包进 CDATA,兼容层会在保护 `content` / `command` 等原文字段的前提下,尝试把非原文字段中的 CDATA XML fragment 还原成 object / array。不过,如果 CDATA 只是单个平面的 XML/HTML 标签,例如 `<b>urgent</b>` 这种行内标记,兼容层会保留原始字符串,不会强行升成 object / array;只有明显表示结构的 CDATA 片段,例如多兄弟节点、嵌套子节点或 `item` 列表,才会触发结构化恢复。
|
||||
在 assistant 最终回包阶段,如果某个 tool 参数在声明 schema 中明确是 `string`,兼容层会在把解析后的 `tool_calls` / `function_call` 重新序列化成 OpenAI / Responses / Claude 可见参数前,递归把该路径上的 number / bool / object / array 统一转成字符串;其中 object / array 会压成紧凑 JSON 字符串。这个保护只对 schema 明确声明为 string 的路径生效,不会改写本来就是 `number` / `boolean` / `object` / `array` 的参数。这样可以兼容 DeepSeek 输出了结构化片段、但上游客户端工具 schema 又严格要求字符串参数的场景(例如 `content`、`prompt`、`path`、`taskId` 等)。
|
||||
正例中的工具名只会来自当前请求实际声明的工具;如果当前请求没有足够的已知工具形态,就省略对应的单工具、多工具或嵌套示例,避免把不可用工具名写进 prompt。
|
||||
对执行类工具,脚本内容必须进入执行参数本身:`Bash` / `execute_command` 使用 `command`,`exec_command` 使用 `cmd`;不要把脚本示范成 `path` / `content` 文件写入参数。
|
||||
|
||||
|
||||
@@ -6,12 +6,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string, toolsRaw any) map[string]any {
|
||||
detected := toolcall.ParseAssistantToolCallsDetailed(finalText, finalThinking, toolNames)
|
||||
return BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls)
|
||||
return BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||
}
|
||||
|
||||
func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall) map[string]any {
|
||||
func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall, toolsRaw any) map[string]any {
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
@@ -19,7 +19,7 @@ func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThi
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected)
|
||||
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, toolsRaw)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,19 +9,19 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string, toolsRaw any) map[string]any {
|
||||
// Strict mode: only standalone, structured tool-call payloads are treated
|
||||
// as executable tool calls.
|
||||
detected := toolcall.ParseAssistantToolCallsDetailed(finalText, finalThinking, toolNames)
|
||||
return BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText, detected.Calls)
|
||||
return BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||
}
|
||||
|
||||
func BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall) map[string]any {
|
||||
func BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall, toolsRaw any) map[string]any {
|
||||
exposedOutputText := finalText
|
||||
output := make([]any, 0, 2)
|
||||
if len(detected) > 0 {
|
||||
exposedOutputText = ""
|
||||
output = append(output, toResponsesFunctionCallItems(detected)...)
|
||||
output = append(output, toResponsesFunctionCallItems(detected, toolsRaw)...)
|
||||
} else {
|
||||
content := make([]any, 0, 2)
|
||||
if finalThinking != "" {
|
||||
@@ -74,12 +74,13 @@ func BuildResponseObjectFromItems(responseID, model, finalPrompt, finalThinking,
|
||||
}
|
||||
}
|
||||
|
||||
func toResponsesFunctionCallItems(toolCalls []toolcall.ParsedToolCall) []any {
|
||||
func toResponsesFunctionCallItems(toolCalls []toolcall.ParsedToolCall, toolsRaw any) []any {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(toolCalls, toolsRaw)
|
||||
out := make([]any, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
for _, tc := range normalizedCalls {
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/toolcall"
|
||||
)
|
||||
|
||||
func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
|
||||
@@ -13,6 +16,7 @@ func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
|
||||
"",
|
||||
"```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```",
|
||||
[]string{"search"},
|
||||
nil,
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
@@ -42,6 +46,7 @@ func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
|
||||
"internal thinking content",
|
||||
"",
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
@@ -75,6 +80,7 @@ func TestBuildResponseObjectPromotesToolCallFromThinkingWhenTextEmpty(t *testing
|
||||
`<tool_calls><invoke name="search"><parameter name="q">from-thinking</parameter></invoke></tool_calls>`,
|
||||
"",
|
||||
[]string{"search"},
|
||||
nil,
|
||||
)
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
@@ -86,3 +92,88 @@ func TestBuildResponseObjectPromotesToolCallFromThinkingWhenTextEmpty(t *testing
|
||||
t.Fatalf("expected function_call output, got %#v", first["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildChatCompletionWithToolCallsCoercesSchemaDeclaredStringArguments(t *testing.T) {
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "Write",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
"taskId": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
obj := BuildChatCompletionWithToolCalls(
|
||||
"chat_test",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"",
|
||||
"",
|
||||
[]toolcall.ParsedToolCall{{
|
||||
Name: "Write",
|
||||
Input: map[string]any{
|
||||
"content": map[string]any{"message": "hi"},
|
||||
"taskId": 1,
|
||||
},
|
||||
}},
|
||||
toolsRaw,
|
||||
)
|
||||
choices, _ := obj["choices"].([]map[string]any)
|
||||
message, _ := choices[0]["message"].(map[string]any)
|
||||
toolCalls, _ := message["tool_calls"].([]map[string]any)
|
||||
fn, _ := toolCalls[0]["function"].(map[string]any)
|
||||
args := map[string]any{}
|
||||
if err := json.Unmarshal([]byte(fn["arguments"].(string)), &args); err != nil {
|
||||
t.Fatalf("decode arguments failed: %v", err)
|
||||
}
|
||||
if args["content"] != `{"message":"hi"}` {
|
||||
t.Fatalf("expected content stringified by schema, got %#v", args["content"])
|
||||
}
|
||||
if args["taskId"] != "1" {
|
||||
t.Fatalf("expected taskId stringified by schema, got %#v", args["taskId"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectWithToolCallsCoercesSchemaDeclaredStringArguments(t *testing.T) {
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "Write",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
obj := BuildResponseObjectWithToolCalls(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"",
|
||||
"",
|
||||
[]toolcall.ParsedToolCall{{
|
||||
Name: "Write",
|
||||
Input: map[string]any{"content": []any{"a", 1}},
|
||||
}},
|
||||
toolsRaw,
|
||||
)
|
||||
output, _ := obj["output"].([]any)
|
||||
first, _ := output[0].(map[string]any)
|
||||
args := map[string]any{}
|
||||
if err := json.Unmarshal([]byte(first["arguments"].(string)), &args); err != nil {
|
||||
t.Fatalf("decode response arguments failed: %v", err)
|
||||
}
|
||||
if args["content"] != `["a",1]` {
|
||||
t.Fatalf("expected response content stringified by schema, got %#v", args["content"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ func TestHandleStreamContextCancelledMarksHistoryStopped(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
resp := makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, `data: [DONE]`)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-stop", "deepseek-v4-flash", "prompt", false, false, nil, session)
|
||||
h.handleStream(rec, req, resp, "cid-stop", "deepseek-v4-flash", "prompt", false, false, nil, nil, session)
|
||||
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
|
||||
@@ -21,6 +21,7 @@ type chatStreamRuntime struct {
|
||||
model string
|
||||
finalPrompt string
|
||||
toolNames []string
|
||||
toolsRaw any
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
@@ -61,6 +62,7 @@ func newChatStreamRuntime(
|
||||
searchEnabled bool,
|
||||
stripReferenceMarkers bool,
|
||||
toolNames []string,
|
||||
toolsRaw any,
|
||||
bufferToolContent bool,
|
||||
emitEarlyToolDeltas bool,
|
||||
) *chatStreamRuntime {
|
||||
@@ -73,6 +75,7 @@ func newChatStreamRuntime(
|
||||
model: model,
|
||||
finalPrompt: finalPrompt,
|
||||
toolNames: toolNames,
|
||||
toolsRaw: toolsRaw,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
stripReferenceMarkers: stripReferenceMarkers,
|
||||
@@ -142,7 +145,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
|
||||
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
|
||||
finishReason = "tool_calls"
|
||||
delta := map[string]any{
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs),
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs, s.toolsRaw),
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
@@ -164,7 +167,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
tcDelta["role"] = "assistant"
|
||||
@@ -320,7 +323,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
tcDelta["role"] = "assistant"
|
||||
|
||||
@@ -26,14 +26,14 @@ type chatNonStreamResult struct {
|
||||
responseMessageID int
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||
attempts := 0
|
||||
currentResp := resp
|
||||
usagePrompt := finalPrompt
|
||||
accumulatedThinking := ""
|
||||
accumulatedToolDetectionThinking := ""
|
||||
for {
|
||||
result, ok := h.collectChatNonStreamAttempt(w, currentResp, completionID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
result, ok := h.collectChatNonStreamAttempt(w, currentResp, completionID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
|
||||
result.toolDetectionThinking = accumulatedToolDetectionThinking
|
||||
detected := detectAssistantToolCalls(result.text, result.thinking, result.toolDetectionThinking, toolNames)
|
||||
result.detectedCalls = len(detected.Calls)
|
||||
result.body = openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, result.thinking, result.text, detected.Calls)
|
||||
result.body = openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, result.thinking, result.text, detected.Calls, toolsRaw)
|
||||
result.finishReason = chatFinishReason(result.body)
|
||||
if !shouldRetryChatNonStream(result, attempts) {
|
||||
h.finishChatNonStreamResult(w, result, attempts, usagePrompt, historySession)
|
||||
@@ -72,7 +72,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.Response, completionID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) (chatNonStreamResult, bool) {
|
||||
func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.Response, completionID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (chatNonStreamResult, bool) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -88,7 +88,7 @@ func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.
|
||||
finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks)
|
||||
}
|
||||
detected := detectAssistantToolCalls(finalText, finalThinking, finalToolDetectionThinking, toolNames)
|
||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls)
|
||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||
return chatNonStreamResult{
|
||||
thinking: finalThinking,
|
||||
toolDetectionThinking: finalToolDetectionThinking,
|
||||
@@ -139,8 +139,8 @@ func shouldRetryChatNonStream(result chatNonStreamResult, attempts int) bool {
|
||||
strings.TrimSpace(result.text) == ""
|
||||
}
|
||||
|
||||
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, historySession)
|
||||
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw, historySession)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -182,7 +182,7 @@ func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
|
||||
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -207,7 +207,7 @@ func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Res
|
||||
}
|
||||
streamRuntime := newChatStreamRuntime(
|
||||
w, rc, canFlush, completionID, time.Now().Unix(), model, finalPrompt,
|
||||
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames,
|
||||
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw,
|
||||
len(toolNames) > 0, h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
|
||||
)
|
||||
return streamRuntime, initialType, true
|
||||
|
||||
@@ -144,8 +144,8 @@ func filterIncrementalToolCallDeltasByAllowed(deltas []toolstream.ToolCallDelta,
|
||||
return shared.FilterIncrementalToolCallDeltasByAllowed(deltas, seenNames)
|
||||
}
|
||||
|
||||
func formatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string) []map[string]any {
|
||||
return shared.FormatFinalStreamToolCallsWithStableIDs(calls, ids)
|
||||
func formatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string, toolsRaw any) []map[string]any {
|
||||
return shared.FormatFinalStreamToolCallsWithStableIDs(calls, ids, toolsRaw)
|
||||
}
|
||||
|
||||
func detectAssistantToolCalls(text, exposedThinking, detectionThinking string, toolNames []string) toolcall.ToolCallParseResult {
|
||||
|
||||
@@ -109,10 +109,10 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if stdReq.Stream {
|
||||
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
|
||||
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
|
||||
return
|
||||
}
|
||||
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
|
||||
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
|
||||
}
|
||||
|
||||
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
|
||||
@@ -148,7 +148,7 @@ func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAu
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -176,7 +176,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
|
||||
writeUpstreamEmptyOutputError(w, finalText, finalThinking, result.ContentFilter)
|
||||
return
|
||||
}
|
||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls)
|
||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||
finishReason := "stop"
|
||||
if choices, ok := respBody["choices"].([]map[string]any); ok && len(choices) > 0 {
|
||||
if fr, _ := choices[0]["finish_reason"].(string); strings.TrimSpace(fr) != "" {
|
||||
@@ -189,7 +189,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -230,6 +230,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
searchEnabled,
|
||||
stripReferenceMarkers,
|
||||
toolNames,
|
||||
toolsRaw,
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
)
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, resp, "cid-empty", "deepseek-v4-flash", "prompt", false, false, nil, nil)
|
||||
h.handleNonStream(rec, resp, "cid-empty", "deepseek-v4-flash", "prompt", false, false, nil, nil, nil)
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected status 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -112,7 +112,7 @@ func TestHandleNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutp
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, resp, "cid-empty-filtered", "deepseek-v4-flash", "prompt", false, false, nil, nil)
|
||||
h.handleNonStream(rec, resp, "cid-empty-filtered", "deepseek-v4-flash", "prompt", false, false, nil, nil, nil)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400 for filtered upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -131,7 +131,7 @@ func TestHandleNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, resp, "cid-thinking-only", "deepseek-v4-pro", "prompt", true, false, nil, nil)
|
||||
h.handleNonStream(rec, resp, "cid-thinking-only", "deepseek-v4-pro", "prompt", true, false, nil, nil, nil)
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected status 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -150,7 +150,7 @@ func TestHandleNonStreamPromotesThinkingToolCallsWhenTextEmpty(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, resp, "cid-thinking-tool", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil)
|
||||
h.handleNonStream(rec, resp, "cid-thinking-tool", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil, nil)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -181,7 +181,7 @@ func TestHandleNonStreamPromotesHiddenThinkingDSMLToolCallsWhenTextEmpty(t *test
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, resp, "cid-hidden-thinking-tool", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil)
|
||||
h.handleNonStream(rec, resp, "cid-hidden-thinking-tool", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil, nil)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for hidden thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -211,7 +211,7 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid6", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil)
|
||||
h.handleStream(rec, req, resp, "cid6", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -248,7 +248,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid10", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil)
|
||||
h.handleStream(rec, req, resp, "cid10", "deepseek-v4-flash", "prompt", false, false, []string{"search"}, nil, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -282,7 +282,7 @@ func TestHandleStreamPromotesThinkingToolCallsOnFinalizeWithoutMidstreamIntercep
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-thinking-stream", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil)
|
||||
h.handleStream(rec, req, resp, "cid-thinking-stream", "deepseek-v4-pro", "prompt", true, false, []string{"search"}, nil, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -319,7 +319,7 @@ func TestHandleStreamPromotesHiddenThinkingDSMLToolCallsOnFinalize(t *testing.T)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-hidden-thinking-stream", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil)
|
||||
h.handleStream(rec, req, resp, "cid-hidden-thinking-stream", "deepseek-v4-pro", "prompt", false, false, []string{"search"}, nil, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -353,7 +353,7 @@ func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-multi", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, nil)
|
||||
h.handleStream(rec, req, resp, "cid-multi", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, nil, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -390,3 +390,64 @@ func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing
|
||||
t.Fatalf("expected distinct tool call ids across blocks, got %#v body=%s", ids, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamCoercesSchemaDeclaredStringArgumentsOnFinalize(t *testing.T) {
|
||||
h := &Handler{}
|
||||
line := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{"p": "response/content", "v": v})
|
||||
return "data: " + string(b)
|
||||
}
|
||||
resp := makeSSEHTTPResponse(
|
||||
line(`<tool_calls><invoke name="Write">{"input":{"content":{"message":"hi"},"taskId":1}}</invoke></tool_calls>`),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "Write",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
"taskId": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-string-protect", "deepseek-v4-flash", "prompt", false, false, []string{"Write"}, toolsRaw, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
if len(toolCalls) == 0 {
|
||||
continue
|
||||
}
|
||||
call, _ := toolCalls[0].(map[string]any)
|
||||
fn, _ := call["function"].(map[string]any)
|
||||
args := map[string]any{}
|
||||
if err := json.Unmarshal([]byte(asString(fn["arguments"])), &args); err != nil {
|
||||
t.Fatalf("decode streamed tool arguments failed: %v", err)
|
||||
}
|
||||
if args["content"] != `{"message":"hi"}` {
|
||||
t.Fatalf("expected streamed content stringified by schema, got %#v", args["content"])
|
||||
}
|
||||
if args["taskId"] != "1" {
|
||||
t.Fatalf("expected streamed taskId stringified by schema, got %#v", args["taskId"])
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected at least one streamed tool call delta, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
@@ -26,3 +26,31 @@ func TestReplaceCitationMarkersWithLinksKeepsUnknownIndex(t *testing.T) {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceCitationMarkersWithLinksSupportsReferenceMarker(t *testing.T) {
|
||||
raw := "新闻摘要[reference:1],详情[reference:2]。"
|
||||
links := map[int]string{
|
||||
1: "https://example.com/r1",
|
||||
2: "https://example.com/r2",
|
||||
}
|
||||
|
||||
got := replaceCitationMarkersWithLinks(raw, links)
|
||||
want := "新闻摘要[1](https://example.com/r1),详情[2](https://example.com/r2)。"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceCitationMarkersWithLinksSupportsReferenceZeroBased(t *testing.T) {
|
||||
raw := "来源[reference:0] 与 [reference:1]。"
|
||||
links := map[int]string{
|
||||
1: "https://example.com/first",
|
||||
2: "https://example.com/second",
|
||||
}
|
||||
|
||||
got := replaceCitationMarkersWithLinks(raw, links)
|
||||
want := "来源[0](https://example.com/first) 与 [1](https://example.com/second)。"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,14 +27,14 @@ type responsesNonStreamResult struct {
|
||||
responseMessageID int
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
attempts := 0
|
||||
currentResp := resp
|
||||
usagePrompt := finalPrompt
|
||||
accumulatedThinking := ""
|
||||
accumulatedToolDetectionThinking := ""
|
||||
for {
|
||||
result, ok := h.collectResponsesNonStreamAttempt(w, currentResp, responseID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
result, ok := h.collectResponsesNonStreamAttempt(w, currentResp, responseID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx c
|
||||
result.thinking = accumulatedThinking
|
||||
result.toolDetectionThinking = accumulatedToolDetectionThinking
|
||||
result.parsed = detectAssistantToolCalls(result.text, result.thinking, result.toolDetectionThinking, toolNames)
|
||||
result.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, result.thinking, result.text, result.parsed.Calls)
|
||||
result.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, result.thinking, result.text, result.parsed.Calls, toolsRaw)
|
||||
|
||||
if !shouldRetryResponsesNonStream(result, attempts) {
|
||||
h.finishResponsesNonStreamResult(w, result, attempts, owner, responseID, toolChoice, traceID)
|
||||
@@ -68,7 +68,7 @@ func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx c
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *http.Response, responseID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) (responsesNonStreamResult, bool) {
|
||||
func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *http.Response, responseID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (responsesNonStreamResult, bool) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -84,7 +84,7 @@ func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *
|
||||
sanitizedText = replaceCitationMarkersWithLinks(sanitizedText, result.CitationLinks)
|
||||
}
|
||||
textParsed := detectAssistantToolCalls(sanitizedText, sanitizedThinking, toolDetectionThinking, toolNames)
|
||||
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, sanitizedThinking, sanitizedText, textParsed.Calls)
|
||||
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, sanitizedThinking, sanitizedText, textParsed.Calls, toolsRaw)
|
||||
return responsesNonStreamResult{
|
||||
thinking: sanitizedThinking,
|
||||
toolDetectionThinking: toolDetectionThinking,
|
||||
@@ -123,8 +123,8 @@ func shouldRetryResponsesNonStream(result responsesNonStreamResult, attempts int
|
||||
strings.TrimSpace(result.text) == ""
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, resp, owner, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolChoice, traceID)
|
||||
func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, resp, owner, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, traceID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -165,7 +165,7 @@ func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) (*responsesStreamRuntime, string, bool) {
|
||||
func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) (*responsesStreamRuntime, string, bool) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -184,7 +184,7 @@ func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, resp *htt
|
||||
}
|
||||
streamRuntime := newResponsesStreamRuntime(
|
||||
w, rc, canFlush, responseID, model, finalPrompt, thinkingEnabled, searchEnabled,
|
||||
h.compatStripReferenceMarkers(), toolNames, len(toolNames) > 0,
|
||||
h.compatStripReferenceMarkers(), toolNames, toolsRaw, len(toolNames) > 0,
|
||||
h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
|
||||
toolChoice, traceID, func(obj map[string]any) {
|
||||
h.getResponseStore().put(owner, responseID, obj)
|
||||
|
||||
@@ -115,13 +115,13 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
if stdReq.Stream {
|
||||
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||
return
|
||||
}
|
||||
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -148,12 +148,12 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
return
|
||||
}
|
||||
|
||||
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, sanitizedThinking, sanitizedText, textParsed.Calls)
|
||||
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, sanitizedThinking, sanitizedText, textParsed.Calls, toolsRaw)
|
||||
h.getResponseStore().put(owner, responseID, responseObj)
|
||||
writeJSON(w, http.StatusOK, responseObj)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -186,6 +186,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
searchEnabled,
|
||||
stripReferenceMarkers,
|
||||
toolNames,
|
||||
toolsRaw,
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
toolChoice,
|
||||
|
||||
@@ -22,6 +22,7 @@ type responsesStreamRuntime struct {
|
||||
model string
|
||||
finalPrompt string
|
||||
toolNames []string
|
||||
toolsRaw any
|
||||
traceID string
|
||||
toolChoice promptcompat.ToolChoicePolicy
|
||||
|
||||
@@ -72,6 +73,7 @@ func newResponsesStreamRuntime(
|
||||
searchEnabled bool,
|
||||
stripReferenceMarkers bool,
|
||||
toolNames []string,
|
||||
toolsRaw any,
|
||||
bufferToolContent bool,
|
||||
emitEarlyToolDeltas bool,
|
||||
toolChoice promptcompat.ToolChoicePolicy,
|
||||
@@ -89,6 +91,7 @@ func newResponsesStreamRuntime(
|
||||
searchEnabled: searchEnabled,
|
||||
stripReferenceMarkers: stripReferenceMarkers,
|
||||
toolNames: toolNames,
|
||||
toolsRaw: toolsRaw,
|
||||
bufferToolContent: bufferToolContent,
|
||||
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
||||
streamToolCallIDs: map[int]string{},
|
||||
|
||||
@@ -220,7 +220,8 @@ func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolstream
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []toolcall.ParsedToolCall) {
|
||||
for idx, tc := range calls {
|
||||
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, s.toolsRaw)
|
||||
for idx, tc := range normalizedCalls {
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -109,7 +109,8 @@ func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, fin
|
||||
}
|
||||
}
|
||||
|
||||
for idx, tc := range calls {
|
||||
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, s.toolsRaw)
|
||||
for idx, tc := range normalizedCalls {
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T)
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.reasoning.delta") {
|
||||
@@ -57,7 +57,7 @@ func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testin
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.output_text.done") {
|
||||
t.Fatalf("expected response.output_text.done payload, body=%s", body)
|
||||
@@ -91,7 +91,7 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
|
||||
deltaPayload, ok := extractSSEEventPayload(body, "response.output_text.delta")
|
||||
@@ -130,7 +130,7 @@ func TestHandleResponsesStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file", "search"}, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
doneEvents := extractSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
@@ -183,7 +183,7 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
||||
Mode: promptcompat.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, nil, policy, "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
@@ -213,7 +213,7 @@ func TestHandleResponsesStreamFailsWhenUpstreamHasOnlyThinking(t *testing.T) {
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
@@ -251,7 +251,7 @@ func TestHandleResponsesStreamPromotesThinkingToolCallsOnFinalizeWithoutMidstrea
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.reasoning.delta") {
|
||||
@@ -288,7 +288,7 @@ func TestHandleResponsesStreamPromotesHiddenThinkingDSMLToolCallsOnFinalize(t *t
|
||||
Mode: promptcompat.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, nil, policy, "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "event: response.reasoning.delta") {
|
||||
@@ -317,7 +317,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, []string{"read_file"}, nil, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -344,7 +344,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayloadWhe
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", true, false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", true, false, []string{"read_file"}, nil, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -366,7 +366,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T)
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -388,7 +388,7 @@ func TestHandleResponsesNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWi
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-flash", "prompt", false, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for filtered empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -410,7 +410,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testin
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, nil, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -432,7 +432,7 @@ func TestHandleResponsesNonStreamPromotesThinkingToolCallsWhenTextEmpty(t *testi
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-v4-pro", "prompt", true, false, []string{"read_file"}, nil, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -462,7 +462,7 @@ func TestHandleResponsesNonStreamPromotesHiddenThinkingDSMLToolCallsWhenTextEmpt
|
||||
Mode: promptcompat.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_hidden", "deepseek-v4-pro", "prompt", false, false, []string{"read_file"}, nil, policy, "")
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for hidden thinking tool calls, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -480,6 +480,53 @@ func TestHandleResponsesNonStreamPromotesHiddenThinkingDSMLToolCallsWhenTextEmpt
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamCoercesSchemaDeclaredStringArguments(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "Write",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
"taskId": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{"p": "response/content", "v": v})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
streamBody := sseLine(`<tool_calls><invoke name="Write">{"input":{"content":{"message":"hi"},"taskId":1}}</invoke></tool_calls>`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_string_protect", "deepseek-v4-flash", "prompt", false, false, []string{"Write"}, toolsRaw, promptcompat.DefaultToolChoicePolicy(), "")
|
||||
|
||||
payload, ok := extractSSEEventPayload(rec.Body.String(), "response.function_call_arguments.done")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.function_call_arguments.done payload, body=%s", rec.Body.String())
|
||||
}
|
||||
args := map[string]any{}
|
||||
if err := json.Unmarshal([]byte(asString(payload["arguments"])), &args); err != nil {
|
||||
t.Fatalf("decode streamed response arguments failed: %v", err)
|
||||
}
|
||||
if args["content"] != `{"message":"hi"}` {
|
||||
t.Fatalf("expected response content stringified by schema, got %#v", args["content"])
|
||||
}
|
||||
if args["taskId"] != "1" {
|
||||
t.Fatalf("expected response taskId stringified by schema, got %#v", args["taskId"])
|
||||
}
|
||||
}
|
||||
|
||||
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
matched := false
|
||||
|
||||
@@ -7,22 +7,27 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var citationMarkerPattern = regexp.MustCompile(`(?i)\[citation:\s*(\d+)\]`)
|
||||
var citationMarkerPattern = regexp.MustCompile(`(?i)\[(citation|reference):\s*(\d+)\]`)
|
||||
|
||||
func ReplaceCitationMarkersWithLinks(text string, links map[int]string) string {
|
||||
if strings.TrimSpace(text) == "" || len(links) == 0 {
|
||||
return text
|
||||
}
|
||||
zeroBased := strings.Contains(strings.ToLower(text), "[reference:0]")
|
||||
return citationMarkerPattern.ReplaceAllStringFunc(text, func(match string) string {
|
||||
sub := citationMarkerPattern.FindStringSubmatch(match)
|
||||
if len(sub) < 2 {
|
||||
if len(sub) < 3 {
|
||||
return match
|
||||
}
|
||||
idx, err := strconv.Atoi(strings.TrimSpace(sub[1]))
|
||||
if err != nil || idx <= 0 {
|
||||
idx, err := strconv.Atoi(strings.TrimSpace(sub[2]))
|
||||
if err != nil || idx < 0 {
|
||||
return match
|
||||
}
|
||||
url := strings.TrimSpace(links[idx])
|
||||
lookupIdx := idx
|
||||
if zeroBased {
|
||||
lookupIdx = idx + 1
|
||||
}
|
||||
url := strings.TrimSpace(links[lookupIdx])
|
||||
if url == "" {
|
||||
return match
|
||||
}
|
||||
|
||||
@@ -70,12 +70,13 @@ func FilterIncrementalToolCallDeltasByAllowed(deltas []toolstream.ToolCallDelta,
|
||||
return out
|
||||
}
|
||||
|
||||
func FormatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string) []map[string]any {
|
||||
func FormatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string, toolsRaw any) []map[string]any {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for i, c := range calls {
|
||||
for i, c := range normalizedCalls {
|
||||
callID := ""
|
||||
if ids != nil {
|
||||
callID = strings.TrimSpace(ids[i])
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
func TestFormatOpenAIStreamToolCalls(t *testing.T) {
|
||||
formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{
|
||||
{Name: "search", Input: map[string]any{"q": "test"}},
|
||||
})
|
||||
}, nil)
|
||||
if len(formatted) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(formatted))
|
||||
}
|
||||
|
||||
@@ -7,9 +7,10 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
func FormatOpenAIToolCalls(calls []ParsedToolCall, toolsRaw any) []map[string]any {
|
||||
normalized := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for _, c := range calls {
|
||||
for _, c := range normalized {
|
||||
args, _ := json.Marshal(c.Input)
|
||||
out = append(out, map[string]any{
|
||||
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
@@ -23,9 +24,10 @@ func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
return out
|
||||
}
|
||||
|
||||
func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
func FormatOpenAIStreamToolCalls(calls []ParsedToolCall, toolsRaw any) []map[string]any {
|
||||
normalized := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for i, c := range calls {
|
||||
for i, c := range normalized {
|
||||
args, _ := json.Marshal(c.Input)
|
||||
out = append(out, map[string]any{
|
||||
"index": i,
|
||||
|
||||
266
internal/toolcall/toolcalls_schema_normalize.go
Normal file
266
internal/toolcall/toolcalls_schema_normalize.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package toolcall
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func NormalizeParsedToolCallsForSchemas(calls []ParsedToolCall, toolsRaw any) []ParsedToolCall {
|
||||
if len(calls) == 0 {
|
||||
return calls
|
||||
}
|
||||
schemas := buildToolSchemaIndex(toolsRaw)
|
||||
if len(schemas) == 0 {
|
||||
return calls
|
||||
}
|
||||
|
||||
var changedAny bool
|
||||
out := make([]ParsedToolCall, len(calls))
|
||||
for i, call := range calls {
|
||||
out[i] = call
|
||||
schema, ok := schemas[strings.ToLower(strings.TrimSpace(call.Name))]
|
||||
if !ok || call.Input == nil {
|
||||
continue
|
||||
}
|
||||
normalized, changed := normalizeToolValueWithSchema(call.Input, schema)
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
changedAny = true
|
||||
if input, ok := normalized.(map[string]any); ok {
|
||||
out[i].Input = input
|
||||
}
|
||||
}
|
||||
if !changedAny {
|
||||
return calls
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildToolSchemaIndex(toolsRaw any) map[string]any {
|
||||
tools, ok := toolsRaw.([]any)
|
||||
if !ok || len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(tools))
|
||||
for _, item := range tools {
|
||||
tool, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, schema := extractToolNameAndSchema(tool)
|
||||
if name == "" || schema == nil {
|
||||
continue
|
||||
}
|
||||
out[strings.ToLower(name)] = schema
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func extractToolNameAndSchema(tool map[string]any) (string, any) {
|
||||
name := strings.TrimSpace(asStringValue(tool["name"]))
|
||||
schema := tool["parameters"]
|
||||
if schema == nil {
|
||||
schema = tool["input_schema"]
|
||||
}
|
||||
if fn, ok := tool["function"].(map[string]any); ok {
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(asStringValue(fn["name"]))
|
||||
}
|
||||
if schema == nil {
|
||||
schema = fn["parameters"]
|
||||
}
|
||||
if schema == nil {
|
||||
schema = fn["input_schema"]
|
||||
}
|
||||
}
|
||||
return name, schema
|
||||
}
|
||||
|
||||
func normalizeToolValueWithSchema(value any, schema any) (any, bool) {
|
||||
if value == nil || schema == nil {
|
||||
return value, false
|
||||
}
|
||||
schemaMap, ok := schema.(map[string]any)
|
||||
if !ok || len(schemaMap) == 0 {
|
||||
return value, false
|
||||
}
|
||||
if shouldCoerceSchemaToString(schemaMap) {
|
||||
return stringifySchemaValue(value)
|
||||
}
|
||||
if looksLikeObjectSchema(schemaMap) {
|
||||
obj, ok := value.(map[string]any)
|
||||
if !ok || len(obj) == 0 {
|
||||
return value, false
|
||||
}
|
||||
properties, _ := schemaMap["properties"].(map[string]any)
|
||||
additional := schemaMap["additionalProperties"]
|
||||
changed := false
|
||||
out := make(map[string]any, len(obj))
|
||||
for key, current := range obj {
|
||||
next := current
|
||||
var fieldChanged bool
|
||||
if propSchema, ok := properties[key]; ok {
|
||||
next, fieldChanged = normalizeToolValueWithSchema(current, propSchema)
|
||||
} else if additional != nil {
|
||||
next, fieldChanged = normalizeToolValueWithSchema(current, additional)
|
||||
}
|
||||
out[key] = next
|
||||
changed = changed || fieldChanged
|
||||
}
|
||||
if !changed {
|
||||
return value, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
if looksLikeArraySchema(schemaMap) {
|
||||
arr, ok := value.([]any)
|
||||
if !ok || len(arr) == 0 {
|
||||
return value, false
|
||||
}
|
||||
itemsSchema := schemaMap["items"]
|
||||
if itemsSchema == nil {
|
||||
return value, false
|
||||
}
|
||||
changed := false
|
||||
out := make([]any, len(arr))
|
||||
switch itemSchemas := itemsSchema.(type) {
|
||||
case []any:
|
||||
for i, item := range arr {
|
||||
if i >= len(itemSchemas) {
|
||||
out[i] = item
|
||||
continue
|
||||
}
|
||||
next, itemChanged := normalizeToolValueWithSchema(item, itemSchemas[i])
|
||||
out[i] = next
|
||||
changed = changed || itemChanged
|
||||
}
|
||||
default:
|
||||
for i, item := range arr {
|
||||
next, itemChanged := normalizeToolValueWithSchema(item, itemsSchema)
|
||||
out[i] = next
|
||||
changed = changed || itemChanged
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return value, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
return value, false
|
||||
}
|
||||
|
||||
func shouldCoerceSchemaToString(schema map[string]any) bool {
|
||||
if schema == nil {
|
||||
return false
|
||||
}
|
||||
if isStringConst(schema["const"]) {
|
||||
return true
|
||||
}
|
||||
if isStringEnum(schema["enum"]) {
|
||||
return true
|
||||
}
|
||||
switch v := schema["type"].(type) {
|
||||
case string:
|
||||
return strings.EqualFold(strings.TrimSpace(v), "string")
|
||||
case []any:
|
||||
return isOnlyStringLikeTypes(v)
|
||||
case []string:
|
||||
items := make([]any, 0, len(v))
|
||||
for _, item := range v {
|
||||
items = append(items, item)
|
||||
}
|
||||
return isOnlyStringLikeTypes(items)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func looksLikeObjectSchema(schema map[string]any) bool {
|
||||
if schema == nil {
|
||||
return false
|
||||
}
|
||||
if typ, ok := schema["type"].(string); ok && strings.EqualFold(strings.TrimSpace(typ), "object") {
|
||||
return true
|
||||
}
|
||||
if _, ok := schema["properties"].(map[string]any); ok {
|
||||
return true
|
||||
}
|
||||
_, hasAdditional := schema["additionalProperties"]
|
||||
return hasAdditional
|
||||
}
|
||||
|
||||
func looksLikeArraySchema(schema map[string]any) bool {
|
||||
if schema == nil {
|
||||
return false
|
||||
}
|
||||
if typ, ok := schema["type"].(string); ok && strings.EqualFold(strings.TrimSpace(typ), "array") {
|
||||
return true
|
||||
}
|
||||
_, hasItems := schema["items"]
|
||||
return hasItems
|
||||
}
|
||||
|
||||
func isOnlyStringLikeTypes(values []any) bool {
|
||||
if len(values) == 0 {
|
||||
return false
|
||||
}
|
||||
hasString := false
|
||||
for _, item := range values {
|
||||
typ, ok := item.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(typ)) {
|
||||
case "string":
|
||||
hasString = true
|
||||
case "null":
|
||||
continue
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return hasString
|
||||
}
|
||||
|
||||
func isStringConst(v any) bool {
|
||||
_, ok := v.(string)
|
||||
return ok
|
||||
}
|
||||
|
||||
func isStringEnum(v any) bool {
|
||||
values, ok := v.([]any)
|
||||
if !ok || len(values) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, item := range values {
|
||||
if _, ok := item.(string); !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func stringifySchemaValue(value any) (any, bool) {
|
||||
if value == nil {
|
||||
return value, false
|
||||
}
|
||||
if s, ok := value.(string); ok {
|
||||
return s, false
|
||||
}
|
||||
b, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return value, false
|
||||
}
|
||||
return string(b), true
|
||||
}
|
||||
|
||||
func asStringValue(v any) string {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
return ""
|
||||
}
|
||||
112
internal/toolcall/toolcalls_schema_normalize_test.go
Normal file
112
internal/toolcall/toolcalls_schema_normalize_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package toolcall
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeParsedToolCallsForSchemasCoercesDeclaredStringFieldsRecursively(t *testing.T) {
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "TaskUpdate",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"taskId": map[string]any{"type": "string"},
|
||||
"payload": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
"tags": map[string]any{
|
||||
"type": "array",
|
||||
"items": map[string]any{"type": "string"},
|
||||
},
|
||||
"count": map[string]any{"type": "number"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
calls := []ParsedToolCall{{
|
||||
Name: "TaskUpdate",
|
||||
Input: map[string]any{
|
||||
"taskId": 1,
|
||||
"payload": map[string]any{
|
||||
"content": map[string]any{"text": "hello"},
|
||||
"tags": []any{1, true, map[string]any{"k": "v"}},
|
||||
"count": 2,
|
||||
},
|
||||
},
|
||||
}}
|
||||
|
||||
got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized call, got %#v", got)
|
||||
}
|
||||
if got[0].Input["taskId"] != "1" {
|
||||
t.Fatalf("expected taskId coerced to string, got %#v", got[0].Input["taskId"])
|
||||
}
|
||||
payload, ok := got[0].Input["payload"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected payload object, got %#v", got[0].Input["payload"])
|
||||
}
|
||||
if payload["content"] != `{"text":"hello"}` {
|
||||
t.Fatalf("expected nested content coerced to json string, got %#v", payload["content"])
|
||||
}
|
||||
if payload["count"] != 2 {
|
||||
t.Fatalf("expected non-string count unchanged, got %#v", payload["count"])
|
||||
}
|
||||
tags, ok := payload["tags"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected tags slice, got %#v", payload["tags"])
|
||||
}
|
||||
wantTags := []any{"1", "true", `{"k":"v"}`}
|
||||
if !reflect.DeepEqual(tags, wantTags) {
|
||||
t.Fatalf("unexpected normalized tags: got %#v want %#v", tags, wantTags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeParsedToolCallsForSchemasSupportsDirectToolSchemaShape(t *testing.T) {
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"name": "Write",
|
||||
"input_schema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
calls := []ParsedToolCall{{Name: "Write", Input: map[string]any{"content": []any{"a", 1}}}}
|
||||
got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||
if got[0].Input["content"] != `["a",1]` {
|
||||
t.Fatalf("expected direct-schema content coerced to string, got %#v", got[0].Input["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeParsedToolCallsForSchemasLeavesAmbiguousUnionUnchanged(t *testing.T) {
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "TaskUpdate",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"taskId": map[string]any{"type": []any{"string", "integer"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
calls := []ParsedToolCall{{Name: "TaskUpdate", Input: map[string]any{"taskId": 1}}}
|
||||
got := NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
|
||||
if got[0].Input["taskId"] != 1 {
|
||||
t.Fatalf("expected ambiguous union to stay unchanged, got %#v", got[0].Input["taskId"])
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
func TestFormatOpenAIToolCalls(t *testing.T) {
|
||||
formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}})
|
||||
formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}}, nil)
|
||||
if len(formatted) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(formatted))
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking,
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected)
|
||||
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, nil)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
promptTokens := EstimateTokens(finalPrompt)
|
||||
|
||||
Reference in New Issue
Block a user