mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-11 03:37:40 +08:00
feat: Improve OpenAI tool call handling by passing unknown tool calls as content and filtering streamed tool calls by schema.
This commit is contained in:
@@ -33,6 +33,7 @@ type chatStreamRuntime struct {
|
||||
|
||||
toolSieve toolStreamSieveState
|
||||
streamToolCallIDs map[int]string
|
||||
streamToolNames map[int]string
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
}
|
||||
@@ -65,6 +66,7 @@ func newChatStreamRuntime(
|
||||
bufferToolContent: bufferToolContent,
|
||||
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
||||
streamToolCallIDs: map[int]string{},
|
||||
streamToolNames: map[int]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,7 +213,11 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
if !s.emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
|
||||
filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.streamToolNames)
|
||||
if len(filtered) == 0 {
|
||||
continue
|
||||
}
|
||||
formatted := formatIncrementalStreamToolCallDeltas(filtered, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -3,11 +3,18 @@ package openai
|
||||
import "net/http"
|
||||
|
||||
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
|
||||
writeOpenAIErrorWithCode(w, status, message, "")
|
||||
}
|
||||
|
||||
func writeOpenAIErrorWithCode(w http.ResponseWriter, status int, message, code string) {
|
||||
if code == "" {
|
||||
code = openAIErrorCode(status)
|
||||
}
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": openAIErrorType(status),
|
||||
"code": openAIErrorCode(status),
|
||||
"code": code,
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
|
||||
@@ -10,9 +10,23 @@ import (
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
|
||||
func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolChoicePolicy) ([]map[string]any, []string) {
|
||||
if policy.IsNone() {
|
||||
return messages, nil
|
||||
}
|
||||
toolSchemas := make([]string, 0, len(tools))
|
||||
names := make([]string, 0, len(tools))
|
||||
isAllowed := func(name string) bool {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return false
|
||||
}
|
||||
if len(policy.Allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
_, ok := policy.Allowed[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
for _, t := range tools {
|
||||
tool, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
@@ -25,8 +39,9 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
schema, _ := fn["parameters"].(map[string]any)
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
name = strings.TrimSpace(name)
|
||||
if !isAllowed(name) {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
if desc == "" {
|
||||
@@ -39,6 +54,13 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
|
||||
return messages, names
|
||||
}
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
|
||||
if policy.Mode == util.ToolChoiceRequired {
|
||||
toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list."
|
||||
}
|
||||
if policy.Mode == util.ToolChoiceForced && strings.TrimSpace(policy.ForcedName) != "" {
|
||||
toolPrompt += "\n5) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName)
|
||||
toolPrompt += "\n6) Do not call any other tool."
|
||||
}
|
||||
|
||||
for i := range messages {
|
||||
if messages[i]["role"] == "system" {
|
||||
@@ -85,6 +107,33 @@ func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]s
|
||||
return out
|
||||
}
|
||||
|
||||
func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNames []string, seenNames map[int]string) []toolCallDelta {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
allowed := namesToSet(allowedNames)
|
||||
out := make([]toolCallDelta, 0, len(deltas))
|
||||
for _, d := range deltas {
|
||||
if d.Name != "" {
|
||||
if len(allowed) > 0 {
|
||||
if _, ok := allowed[d.Name]; !ok {
|
||||
seenNames[d.Index] = "__blocked__"
|
||||
continue
|
||||
}
|
||||
}
|
||||
seenNames[d.Index] = d.Name
|
||||
out = append(out, d)
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(seenNames[d.Index])
|
||||
if name == "" || name == "__blocked__" {
|
||||
continue
|
||||
}
|
||||
out = append(out, d)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -181,7 +181,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
@@ -197,16 +197,16 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if msg["content"] != nil {
|
||||
t.Fatalf("expected content nil, got %#v", msg["content"])
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls for unknown schema name, got %#v", msg["tool_calls"])
|
||||
}
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %#v", msg["tool_calls"])
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected unknown tool json to pass through as text, got %#v", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -375,7 +375,7 @@ func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
func TestHandleStreamUnknownToolNotIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
@@ -390,29 +390,14 @@ func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta for unknown schema name, body=%s", rec.Body.String())
|
||||
}
|
||||
foundToolIndex := false
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
if _, ok := tcm["index"].(float64); ok {
|
||||
foundToolIndex = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("expected raw tool_calls json to remain in content for unknown schema name: %s", rec.Body.String())
|
||||
}
|
||||
if !foundToolIndex {
|
||||
t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,13 +2,18 @@ package openai
|
||||
|
||||
import (
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
|
||||
return buildOpenAIFinalPromptWithPolicy(messagesRaw, toolsRaw, traceID, util.DefaultToolChoicePolicy())
|
||||
}
|
||||
|
||||
func buildOpenAIFinalPromptWithPolicy(messagesRaw []any, toolsRaw any, traceID string, toolPolicy util.ToolChoicePolicy) (string, []string) {
|
||||
messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID)
|
||||
toolNames := []string{}
|
||||
if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
|
||||
messages, toolNames = injectToolPrompt(messages, tools)
|
||||
messages, toolNames = injectToolPrompt(messages, tools, toolPolicy)
|
||||
}
|
||||
return deepseek.MessagesPrepare(messages), toolNames
|
||||
}
|
||||
|
||||
@@ -73,6 +73,32 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesInputAsMessagesBackfillsToolResultNameFromCallID(t *testing.T) {
|
||||
msgs := normalizeResponsesInputAsMessages([]any{
|
||||
map[string]any{
|
||||
"type": "function_call",
|
||||
"call_id": "call_999",
|
||||
"name": "search",
|
||||
"arguments": `{"q":"golang"}`,
|
||||
},
|
||||
map[string]any{
|
||||
"type": "function_call_output",
|
||||
"call_id": "call_999",
|
||||
"output": map[string]any{"ok": true},
|
||||
},
|
||||
})
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("expected two messages, got %d", len(msgs))
|
||||
}
|
||||
toolMsg, _ := msgs[1].(map[string]any)
|
||||
if toolMsg["role"] != "tool" {
|
||||
t.Fatalf("expected tool role, got %#v", toolMsg)
|
||||
}
|
||||
if toolMsg["name"] != "search" {
|
||||
t.Fatalf("expected tool name backfilled from call_id, got %#v", toolMsg["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
|
||||
msgs := normalizeResponsesInputAsMessages([]any{
|
||||
map[string]any{
|
||||
|
||||
@@ -11,10 +11,12 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -67,7 +69,8 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, requestTraceID(r))
|
||||
traceID := requestTraceID(r)
|
||||
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, traceID)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
@@ -96,13 +99,13 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
if stdReq.Stream {
|
||||
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
return
|
||||
}
|
||||
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -110,12 +113,26 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
return
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
textParsed := util.ParseToolCallsDetailed(result.Text, toolNames)
|
||||
thinkingParsed := util.ParseToolCallsDetailed(result.Thinking, toolNames)
|
||||
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
||||
logResponsesToolPolicyRejection(traceID, toolChoice, thinkingParsed, "thinking")
|
||||
|
||||
callCount := len(textParsed.Calls)
|
||||
if callCount == 0 {
|
||||
callCount = len(thinkingParsed.Calls)
|
||||
}
|
||||
if toolChoice.IsRequired() && callCount == 0 {
|
||||
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
|
||||
return
|
||||
}
|
||||
|
||||
responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames)
|
||||
h.getResponseStore().put(owner, responseID, responseObj)
|
||||
writeJSON(w, http.StatusOK, responseObj)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -148,6 +165,8 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
toolNames,
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
toolChoice,
|
||||
traceID,
|
||||
func(obj map[string]any) {
|
||||
h.getResponseStore().put(owner, responseID, obj)
|
||||
},
|
||||
@@ -169,3 +188,16 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func logResponsesToolPolicyRejection(traceID string, policy util.ToolChoicePolicy, parsed util.ToolCallParseResult, channel string) {
|
||||
if !parsed.RejectedByPolicy || len(parsed.RejectedToolNames) == 0 {
|
||||
return
|
||||
}
|
||||
config.Logger.Warn(
|
||||
"[responses] rejected tool calls by policy",
|
||||
"trace_id", strings.TrimSpace(traceID),
|
||||
"channel", channel,
|
||||
"tool_choice_mode", policy.Mode,
|
||||
"rejected_tool_names", strings.Join(parsed.RejectedToolNames, ","),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -4,9 +4,15 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func normalizeResponsesInputItem(m map[string]any) map[string]any {
|
||||
return normalizeResponsesInputItemWithState(m, nil)
|
||||
}
|
||||
|
||||
func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[string]string) map[string]any {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -69,6 +75,15 @@ func normalizeResponsesInputItem(m map[string]any) map[string]any {
|
||||
out["name"] = name
|
||||
} else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
|
||||
out["name"] = name
|
||||
} else if callID := strings.TrimSpace(asString(out["tool_call_id"])); callID != "" {
|
||||
if inferred := strings.TrimSpace(callNameByID[callID]); inferred != "" {
|
||||
out["name"] = inferred
|
||||
} else {
|
||||
config.Logger.Warn(
|
||||
"[responses] unable to backfill tool result name from call_id",
|
||||
"call_id", callID,
|
||||
)
|
||||
}
|
||||
}
|
||||
return out
|
||||
case "function_call", "tool_call":
|
||||
@@ -111,6 +126,9 @@ func normalizeResponsesInputItem(m map[string]any) map[string]any {
|
||||
} else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
|
||||
call["id"] = callID
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(call["id"])); callID != "" && callNameByID != nil {
|
||||
callNameByID[callID] = name
|
||||
}
|
||||
return map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{call},
|
||||
|
||||
@@ -59,6 +59,7 @@ func normalizeResponsesInputArray(items []any) []any {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(items))
|
||||
callNameByID := map[string]string{}
|
||||
fallbackParts := make([]string, 0, len(items))
|
||||
flushFallback := func() {
|
||||
if len(fallbackParts) == 0 {
|
||||
@@ -71,7 +72,7 @@ func normalizeResponsesInputArray(items []any) []any {
|
||||
for _, item := range items {
|
||||
switch x := item.(type) {
|
||||
case map[string]any:
|
||||
if msg := normalizeResponsesInputItem(x); msg != nil {
|
||||
if msg := normalizeResponsesInputItemWithState(x, callNameByID); msg != nil {
|
||||
flushFallback()
|
||||
out = append(out, msg)
|
||||
continue
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
@@ -19,6 +20,8 @@ type responsesStreamRuntime struct {
|
||||
model string
|
||||
finalPrompt string
|
||||
toolNames []string
|
||||
traceID string
|
||||
toolChoice util.ToolChoicePolicy
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
@@ -32,11 +35,19 @@ type responsesStreamRuntime struct {
|
||||
thinkingSieve toolStreamSieveState
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
visibleText strings.Builder
|
||||
streamToolCallIDs map[int]string
|
||||
streamFunctionIDs map[int]string
|
||||
functionDone map[int]bool
|
||||
functionAdded map[int]bool
|
||||
functionNames map[int]string
|
||||
toolCallsDoneSigs map[string]bool
|
||||
reasoningItemID string
|
||||
messageItemID string
|
||||
messageAdded bool
|
||||
messagePartAdded bool
|
||||
sequence int
|
||||
failed bool
|
||||
|
||||
persistResponse func(obj map[string]any)
|
||||
}
|
||||
@@ -53,6 +64,8 @@ func newResponsesStreamRuntime(
|
||||
toolNames []string,
|
||||
bufferToolContent bool,
|
||||
emitEarlyToolDeltas bool,
|
||||
toolChoice util.ToolChoicePolicy,
|
||||
traceID string,
|
||||
persistResponse func(obj map[string]any),
|
||||
) *responsesStreamRuntime {
|
||||
return &responsesStreamRuntime{
|
||||
@@ -70,7 +83,11 @@ func newResponsesStreamRuntime(
|
||||
streamToolCallIDs: map[int]string{},
|
||||
streamFunctionIDs: map[int]string{},
|
||||
functionDone: map[int]bool{},
|
||||
functionAdded: map[int]bool{},
|
||||
functionNames: map[int]string{},
|
||||
toolCallsDoneSigs: map[string]bool{},
|
||||
toolChoice: toolChoice,
|
||||
traceID: traceID,
|
||||
persistResponse: persistResponse,
|
||||
}
|
||||
}
|
||||
@@ -78,36 +95,59 @@ func newResponsesStreamRuntime(
|
||||
func (s *responsesStreamRuntime) finalize() {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
|
||||
}
|
||||
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
|
||||
}
|
||||
// Compatibility fallback: some streams only emit incremental tool deltas.
|
||||
// Ensure final function_call_arguments.done is emitted at least once.
|
||||
if s.toolCallsEmitted {
|
||||
detected := util.ParseToolCalls(finalText, s.toolNames)
|
||||
if len(detected) == 0 {
|
||||
detected = util.ParseToolCalls(finalThinking, s.toolNames)
|
||||
|
||||
textParsed := util.ParseToolCallsDetailed(finalText, s.toolNames)
|
||||
thinkingParsed := util.ParseToolCallsDetailed(finalThinking, s.toolNames)
|
||||
detected := textParsed.Calls
|
||||
if len(detected) == 0 {
|
||||
detected = thinkingParsed.Calls
|
||||
}
|
||||
s.logToolPolicyRejections(textParsed, thinkingParsed)
|
||||
|
||||
if len(detected) > 0 {
|
||||
s.toolCallsEmitted = true
|
||||
if !s.toolCallsDoneEmitted {
|
||||
s.emitFunctionCallDoneEvents(detected)
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
if !s.toolCallsDoneEmitted {
|
||||
s.emitToolCallsDone(detected)
|
||||
} else {
|
||||
s.emitFunctionCallDoneEvents(detected)
|
||||
}
|
||||
}
|
||||
|
||||
s.closeMessageItem()
|
||||
|
||||
if s.toolChoice.IsRequired() && !s.hasFunctionCallDone() {
|
||||
s.failed = true
|
||||
message := "tool_choice requires at least one valid tool call."
|
||||
failedResp := map[string]any{
|
||||
"id": s.responseID,
|
||||
"type": "response",
|
||||
"object": "response",
|
||||
"model": s.model,
|
||||
"status": "failed",
|
||||
"output": []any{},
|
||||
"output_text": "",
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": "invalid_request_error",
|
||||
"code": "tool_choice_violation",
|
||||
"param": nil,
|
||||
},
|
||||
}
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(failedResp)
|
||||
}
|
||||
s.sendEvent("response.failed", openaifmt.BuildResponsesFailedPayload(s.responseID, s.model, message, "tool_choice_violation"))
|
||||
s.sendDone()
|
||||
return
|
||||
}
|
||||
|
||||
obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
|
||||
if s.toolCallsEmitted {
|
||||
s.alignCompletedOutputCallIDs(obj)
|
||||
}
|
||||
if s.toolCallsEmitted {
|
||||
obj["status"] = "completed"
|
||||
}
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(obj)
|
||||
}
|
||||
@@ -115,6 +155,32 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
s.sendDone()
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) {
|
||||
logRejected := func(parsed util.ToolCallParseResult, channel string) {
|
||||
if !parsed.RejectedByPolicy || len(parsed.RejectedToolNames) == 0 {
|
||||
return
|
||||
}
|
||||
config.Logger.Warn(
|
||||
"[responses] rejected tool calls by policy",
|
||||
"trace_id", strings.TrimSpace(s.traceID),
|
||||
"channel", channel,
|
||||
"tool_choice_mode", s.toolChoice.Mode,
|
||||
"rejected_tool_names", strings.Join(parsed.RejectedToolNames, ","),
|
||||
)
|
||||
}
|
||||
logRejected(textParsed, "text")
|
||||
logRejected(thinkingParsed, "thinking")
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) hasFunctionCallDone() bool {
|
||||
for _, done := range s.functionDone {
|
||||
if done {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
@@ -138,7 +204,6 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
}
|
||||
s.thinking.WriteString(p.Text)
|
||||
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
|
||||
s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
|
||||
}
|
||||
@@ -147,7 +212,7 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
|
||||
s.text.WriteString(p.Text)
|
||||
if !s.bufferToolContent {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
|
||||
s.emitTextDelta(p.Text)
|
||||
continue
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
|
||||
|
||||
@@ -6,7 +6,18 @@ import (
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
)
|
||||
|
||||
func (s *responsesStreamRuntime) nextSequence() int {
|
||||
s.sequence++
|
||||
return s.sequence
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) {
|
||||
if payload == nil {
|
||||
payload = map[string]any{}
|
||||
}
|
||||
if _, ok := payload["sequence_number"]; !ok {
|
||||
payload["sequence_number"] = s.nextSequence()
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
_, _ = s.w.Write([]byte("event: " + event + "\n"))
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
@@ -31,22 +42,20 @@ func (s *responsesStreamRuntime) sendDone() {
|
||||
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
|
||||
for _, evt := range events {
|
||||
if emitContent && evt.Content != "" {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
|
||||
s.emitTextDelta(evt.Content)
|
||||
}
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !s.emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.functionNames)
|
||||
if len(filtered) == 0 {
|
||||
continue
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
|
||||
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
|
||||
s.emitFunctionCallDeltaEvents(filtered)
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.emitToolCallsDone(evt.ToolCalls)
|
||||
s.emitFunctionCallDoneEvents(evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,25 +11,101 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
|
||||
if len(calls) == 0 {
|
||||
func (s *responsesStreamRuntime) ensureMessageItemID() string {
|
||||
if strings.TrimSpace(s.messageItemID) != "" {
|
||||
return s.messageItemID
|
||||
}
|
||||
s.messageItemID = "msg_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
return s.messageItemID
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) messageOutputIndex() int {
|
||||
if strings.TrimSpace(s.thinking.String()) != "" {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureMessageItemAdded() {
|
||||
if s.messageAdded {
|
||||
return
|
||||
}
|
||||
sig := toolCallListSignature(calls)
|
||||
if sig != "" && s.toolCallsDoneSigs[sig] {
|
||||
itemID := s.ensureMessageItemID()
|
||||
item := map[string]any{
|
||||
"id": itemID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "in_progress",
|
||||
}
|
||||
s.sendEvent(
|
||||
"response.output_item.added",
|
||||
openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, s.messageOutputIndex(), item),
|
||||
)
|
||||
s.messageAdded = true
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureMessageContentPartAdded() {
|
||||
if s.messagePartAdded {
|
||||
return
|
||||
}
|
||||
if sig != "" {
|
||||
s.toolCallsDoneSigs[sig] = true
|
||||
}
|
||||
formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
s.ensureMessageItemAdded()
|
||||
s.sendEvent(
|
||||
"response.content_part.added",
|
||||
openaifmt.BuildResponsesContentPartAddedPayload(
|
||||
s.responseID,
|
||||
s.ensureMessageItemID(),
|
||||
s.messageOutputIndex(),
|
||||
0,
|
||||
map[string]any{"type": "output_text", "text": ""},
|
||||
),
|
||||
)
|
||||
s.messagePartAdded = true
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitTextDelta(content string) {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
|
||||
s.emitFunctionCallDoneEvents(calls)
|
||||
s.ensureMessageContentPartAdded()
|
||||
s.visibleText.WriteString(content)
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, content))
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) closeMessageItem() {
|
||||
if !s.messageAdded {
|
||||
return
|
||||
}
|
||||
itemID := s.ensureMessageItemID()
|
||||
text := s.visibleText.String()
|
||||
if s.messagePartAdded {
|
||||
s.sendEvent(
|
||||
"response.content_part.done",
|
||||
openaifmt.BuildResponsesContentPartDonePayload(
|
||||
s.responseID,
|
||||
itemID,
|
||||
s.messageOutputIndex(),
|
||||
0,
|
||||
map[string]any{"type": "output_text", "text": text},
|
||||
),
|
||||
)
|
||||
s.messagePartAdded = false
|
||||
}
|
||||
item := map[string]any{
|
||||
"id": itemID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": text,
|
||||
},
|
||||
},
|
||||
}
|
||||
s.sendEvent(
|
||||
"response.output_item.done",
|
||||
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, s.messageOutputIndex(), item),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureReasoningItemID() string {
|
||||
@@ -65,12 +141,47 @@ func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) functionOutputIndex(callIndex int) int {
|
||||
return s.functionOutputBaseIndex() + callIndex
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureFunctionItemAdded(callIndex int, name string) {
|
||||
if strings.TrimSpace(name) != "" {
|
||||
s.functionNames[callIndex] = strings.TrimSpace(name)
|
||||
}
|
||||
if s.functionAdded[callIndex] {
|
||||
return
|
||||
}
|
||||
fnName := strings.TrimSpace(s.functionNames[callIndex])
|
||||
if fnName == "" {
|
||||
return
|
||||
}
|
||||
outputIndex := s.functionOutputIndex(callIndex)
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
callID := s.ensureToolCallID(callIndex)
|
||||
item := map[string]any{
|
||||
"id": itemID,
|
||||
"type": "function_call",
|
||||
"call_id": callID,
|
||||
"name": fnName,
|
||||
"arguments": "{}",
|
||||
"status": "in_progress",
|
||||
}
|
||||
s.sendEvent(
|
||||
"response.output_item.added",
|
||||
openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, outputIndex, item),
|
||||
)
|
||||
s.functionAdded[callIndex] = true
|
||||
s.toolCallsEmitted = true
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
|
||||
for _, d := range deltas {
|
||||
s.ensureFunctionItemAdded(d.Index, d.Name)
|
||||
if strings.TrimSpace(d.Arguments) == "" {
|
||||
continue
|
||||
}
|
||||
outputIndex := s.functionOutputBaseIndex() + d.Index
|
||||
outputIndex := s.functionOutputIndex(d.Index)
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
callID := s.ensureToolCallID(d.Index)
|
||||
s.sendEvent(
|
||||
@@ -86,6 +197,8 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
s.ensureFunctionItemAdded(idx, tc.Name)
|
||||
|
||||
outputIndex := base + idx
|
||||
if s.functionDone[outputIndex] {
|
||||
continue
|
||||
@@ -93,11 +206,25 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
callID := s.ensureToolCallID(idx)
|
||||
argsBytes, _ := json.Marshal(tc.Input)
|
||||
args := string(argsBytes)
|
||||
s.sendEvent(
|
||||
"response.function_call_arguments.done",
|
||||
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)),
|
||||
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, args),
|
||||
)
|
||||
item := map[string]any{
|
||||
"id": itemID,
|
||||
"type": "function_call",
|
||||
"call_id": callID,
|
||||
"name": tc.Name,
|
||||
"arguments": args,
|
||||
"status": "completed",
|
||||
}
|
||||
s.sendEvent(
|
||||
"response.output_item.done",
|
||||
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
|
||||
)
|
||||
s.functionDone[outputIndex] = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,41 +259,12 @@ func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
typ, _ := m["type"].(string)
|
||||
switch typ {
|
||||
case "function_call":
|
||||
if functionIdx < len(ordered) {
|
||||
m["call_id"] = ordered[functionIdx]
|
||||
functionIdx++
|
||||
}
|
||||
case "tool_calls":
|
||||
tcArr, _ := m["tool_calls"].([]any)
|
||||
for i, raw := range tcArr {
|
||||
tc, _ := raw.(map[string]any)
|
||||
if tc == nil {
|
||||
continue
|
||||
}
|
||||
if i < len(ordered) {
|
||||
tc["id"] = ordered[i]
|
||||
}
|
||||
}
|
||||
if m["type"] != "function_call" {
|
||||
continue
|
||||
}
|
||||
if functionIdx < len(ordered) {
|
||||
m["call_id"] = ordered[functionIdx]
|
||||
functionIdx++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toolCallListSignature(calls []util.ParsedToolCall) string {
|
||||
if len(calls) == 0 {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for i, tc := range calls {
|
||||
if i > 0 {
|
||||
b.WriteString("|")
|
||||
}
|
||||
b.WriteString(strings.TrimSpace(tc.Name))
|
||||
b.WriteString(":")
|
||||
args, _ := json.Marshal(tc.Input)
|
||||
b.Write(args)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) {
|
||||
@@ -30,7 +32,7 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
@@ -45,8 +47,8 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
|
||||
}
|
||||
var firstToolWrapper map[string]any
|
||||
hasFunctionCall := false
|
||||
hasLegacyWrapper := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
@@ -55,96 +57,22 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
|
||||
if m["type"] == "function_call" {
|
||||
hasFunctionCall = true
|
||||
}
|
||||
if m["type"] == "tool_calls" && firstToolWrapper == nil {
|
||||
firstToolWrapper = m
|
||||
if m["type"] == "tool_calls" {
|
||||
hasLegacyWrapper = true
|
||||
}
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatalf("expected at least one function_call item for responses compatibility, got %#v", responseObj["output"])
|
||||
t.Fatalf("expected function_call item, got %#v", responseObj["output"])
|
||||
}
|
||||
if firstToolWrapper == nil {
|
||||
t.Fatalf("expected a tool_calls wrapper item, got %#v", responseObj["output"])
|
||||
}
|
||||
toolCalls, _ := firstToolWrapper["tool_calls"].([]any)
|
||||
if len(toolCalls) == 0 {
|
||||
t.Fatalf("expected at least one tool_call in output, got %#v", firstToolWrapper["tool_calls"])
|
||||
}
|
||||
call0, _ := toolCalls[0].(map[string]any)
|
||||
if call0["type"] != "function" {
|
||||
t.Fatalf("unexpected tool call type: %#v", call0["type"])
|
||||
}
|
||||
fn, _ := call0["function"].(map[string]any)
|
||||
if fn["name"] != "read_file" {
|
||||
t.Fatalf("unexpected tool call name: %#v", fn["name"])
|
||||
if hasLegacyWrapper {
|
||||
t.Fatalf("did not expect legacy tool_calls wrapper, got %#v", responseObj["output"])
|
||||
}
|
||||
if strings.Contains(outputText, `"tool_calls"`) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
tail := `{"tool_calls":[{"name":"read_file","input":`
|
||||
streamBody := sseLine("Before ") + sseLine(tail) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
|
||||
|
||||
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
outputText, _ := responseObj["output_text"].(string)
|
||||
if strings.Count(outputText, tail) > 1 {
|
||||
t.Fatalf("expected incomplete tail not to be duplicated, got output_text=%q", outputText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamEmitsReasoningCompatEvents(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/thinking_content",
|
||||
"v": "thought",
|
||||
})
|
||||
streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil)
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.reasoning.delta") {
|
||||
t.Fatalf("expected response.reasoning.delta event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.reasoning_text.delta") {
|
||||
t.Fatalf("expected response.reasoning_text.delta compatibility event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.reasoning_text.done") {
|
||||
t.Fatalf("expected response.reasoning_text.done compatibility event, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
|
||||
func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -163,24 +91,28 @@ func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.output_item.added") {
|
||||
t.Fatalf("expected response.output_item.added event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.output_item.done") {
|
||||
t.Fatalf("expected response.output_item.done event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
|
||||
t.Fatalf("expected response.function_call_arguments.delta compatibility event, body=%s", body)
|
||||
t.Fatalf("expected response.function_call_arguments.delta event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected response.function_call_arguments.done compatibility event, body=%s", body)
|
||||
t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
|
||||
}
|
||||
if strings.Contains(body, "event: response.output_tool_call.delta") || strings.Contains(body, "event: response.output_tool_call.done") {
|
||||
t.Fatalf("legacy response.output_tool_call.* event must not appear, body=%s", body)
|
||||
}
|
||||
|
||||
donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done")
|
||||
if !ok {
|
||||
t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body)
|
||||
}
|
||||
if strings.TrimSpace(asString(donePayload["call_id"])) == "" {
|
||||
t.Fatalf("expected call_id in response.function_call_arguments.done payload, payload=%#v", donePayload)
|
||||
}
|
||||
if strings.TrimSpace(asString(donePayload["response_id"])) == "" {
|
||||
t.Fatalf("expected response_id in response.function_call_arguments.done payload, payload=%#v", donePayload)
|
||||
}
|
||||
doneCallID := strings.TrimSpace(asString(donePayload["call_id"]))
|
||||
if doneCallID == "" {
|
||||
t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload)
|
||||
@@ -191,9 +123,6 @@ func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("expected non-empty output in response.completed, response=%#v", responseObj)
|
||||
}
|
||||
var completedCallID string
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
@@ -213,36 +142,29 @@ func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamDetectsToolCallsFromThinkingChannel(t *testing.T) {
|
||||
func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(path, v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": path,
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/thinking_content",
|
||||
"v": "thought",
|
||||
})
|
||||
streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"})
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.reasoning_text.delta") {
|
||||
t.Fatalf("expected response.reasoning_text.delta event, body=%s", body)
|
||||
if !strings.Contains(body, "event: response.reasoning.delta") {
|
||||
t.Fatalf("expected response.reasoning.delta event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected response.function_call_arguments.done event from thinking channel, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.output_tool_call.done") {
|
||||
t.Fatalf("expected response.output_tool_call.done event from thinking channel, body=%s", body)
|
||||
if strings.Contains(body, "event: response.reasoning_text.delta") || strings.Contains(body, "event: response.reasoning_text.done") {
|
||||
t.Fatalf("did not expect response.reasoning_text.* compatibility events, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,121 +189,31 @@ func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.output_tool_call.done") {
|
||||
t.Fatalf("expected response.output_tool_call.done event, body=%s", body)
|
||||
}
|
||||
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
if len(donePayloads) != 2 {
|
||||
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
|
||||
}
|
||||
|
||||
seenNames := map[string]string{}
|
||||
for _, payload := range donePayloads {
|
||||
name := strings.TrimSpace(asString(payload["name"]))
|
||||
callID := strings.TrimSpace(asString(payload["call_id"]))
|
||||
args := strings.TrimSpace(asString(payload["arguments"]))
|
||||
if callID == "" {
|
||||
t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
|
||||
}
|
||||
if strings.Contains(args, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated arguments in done payload: %#v", payload)
|
||||
}
|
||||
if name == "search_webeval_javascript" {
|
||||
t.Fatalf("unexpected merged tool name in done payload: %#v", payload)
|
||||
}
|
||||
if name != "search_web" && name != "eval_javascript" {
|
||||
t.Fatalf("unexpected tool name in done payload: %#v", payload)
|
||||
}
|
||||
if callID == "" {
|
||||
t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
|
||||
}
|
||||
seenNames[name] = callID
|
||||
}
|
||||
if seenNames["search_web"] == "" || seenNames["eval_javascript"] == "" {
|
||||
t.Fatalf("expected done events for both tools, got %#v", seenNames)
|
||||
}
|
||||
if seenNames["search_web"] == seenNames["eval_javascript"] {
|
||||
t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
|
||||
}
|
||||
|
||||
completed, ok := extractSSEEventPayload(body, "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed event, body=%s", body)
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
functionCallIDs := map[string]string{}
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil || m["type"] != "function_call" {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(asString(m["name"]))
|
||||
callID := strings.TrimSpace(asString(m["call_id"]))
|
||||
if name != "" && callID != "" {
|
||||
functionCallIDs[name] = callID
|
||||
}
|
||||
}
|
||||
if functionCallIDs["search_web"] != seenNames["search_web"] {
|
||||
t.Fatalf("search_web call_id mismatch between done and completed: done=%q completed=%q", seenNames["search_web"], functionCallIDs["search_web"])
|
||||
}
|
||||
if functionCallIDs["eval_javascript"] != seenNames["eval_javascript"] {
|
||||
t.Fatalf("eval_javascript call_id mismatch between done and completed: done=%q completed=%q", seenNames["eval_javascript"], functionCallIDs["eval_javascript"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamMultiToolCallFromThinkingChannel(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(path, v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": path,
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
|
||||
sseLine("response/thinking_content", `{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"search_web", "eval_javascript"})
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.reasoning_text.delta") {
|
||||
t.Fatalf("expected reasoning stream events, body=%s", body)
|
||||
}
|
||||
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
if len(donePayloads) != 2 {
|
||||
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
|
||||
}
|
||||
seen := map[string]bool{}
|
||||
for _, payload := range donePayloads {
|
||||
name := strings.TrimSpace(asString(payload["name"]))
|
||||
if name == "search_webeval_javascript" {
|
||||
t.Fatalf("unexpected merged tool name in thinking channel done payload: %#v", payload)
|
||||
}
|
||||
if name != "search_web" && name != "eval_javascript" {
|
||||
t.Fatalf("unexpected tool name in thinking channel done payload: %#v", payload)
|
||||
}
|
||||
args := strings.TrimSpace(asString(payload["arguments"]))
|
||||
if strings.Contains(args, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated arguments in thinking channel done payload: %#v", payload)
|
||||
}
|
||||
seen[name] = true
|
||||
}
|
||||
if !seen["search_web"] || !seen["eval_javascript"] {
|
||||
t.Fatalf("expected both tools in thinking channel done events, got %#v", seen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T) {
|
||||
func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -394,32 +226,76 @@ func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("我来调用工具\n") +
|
||||
sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
|
||||
"data: [DONE]\n"
|
||||
streamBody := sseLine("plain text only") + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
|
||||
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
hasFunctionCall := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m != nil && m["type"] == "function_call" {
|
||||
hasFunctionCall = true
|
||||
break
|
||||
}
|
||||
if strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("did not expect response.completed after failure, body=%s", body)
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatalf("expected completed output to include function_call when mixed prose contains tool_calls payload, output=%#v", output)
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"not_in_schema","input":{"q":"go"}}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("did not expect function_call events for unknown tool, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`data: {"p":"response/content","v":"plain text only"}` + "\n" +
|
||||
`data: [DONE]` + "\n",
|
||||
)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, []string{"read_file"}, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if asString(errObj["code"]) != "tool_choice_violation" {
|
||||
t.Fatalf("expected code=tool_choice_violation, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,8 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
|
||||
if responseModel == "" {
|
||||
responseModel = resolvedModel
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
|
||||
toolPolicy := util.DefaultToolChoicePolicy()
|
||||
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
@@ -34,6 +35,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
|
||||
Messages: messagesRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
ToolChoice: toolPolicy,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
@@ -67,7 +69,17 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
|
||||
if len(messagesRaw) == 0 {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
|
||||
toolPolicy, err := parseToolChoicePolicy(req["tool_choice"], req["tools"])
|
||||
if err != nil {
|
||||
return util.StandardRequest{}, err
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
|
||||
if toolPolicy.IsNone() {
|
||||
toolNames = nil
|
||||
toolPolicy.Allowed = nil
|
||||
} else {
|
||||
toolPolicy.Allowed = namesToSet(toolNames)
|
||||
}
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
@@ -78,6 +90,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
|
||||
Messages: messagesRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
ToolChoice: toolPolicy,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
@@ -102,3 +115,212 @@ func collectOpenAIChatPassThrough(req map[string]any) map[string]any {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseToolChoicePolicy(toolChoiceRaw any, toolsRaw any) (util.ToolChoicePolicy, error) {
|
||||
policy := util.DefaultToolChoicePolicy()
|
||||
declaredNames := extractDeclaredToolNames(toolsRaw)
|
||||
declaredSet := namesToSet(declaredNames)
|
||||
if len(declaredNames) > 0 {
|
||||
policy.Allowed = declaredSet
|
||||
}
|
||||
|
||||
if toolChoiceRaw == nil {
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
switch v := toolChoiceRaw.(type) {
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "", "auto":
|
||||
policy.Mode = util.ToolChoiceAuto
|
||||
case "none":
|
||||
policy.Mode = util.ToolChoiceNone
|
||||
policy.Allowed = nil
|
||||
case "required":
|
||||
policy.Mode = util.ToolChoiceRequired
|
||||
default:
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice: %q", v)
|
||||
}
|
||||
case map[string]any:
|
||||
allowedOverride, hasAllowedOverride, err := parseAllowedToolNames(v["allowed_tools"])
|
||||
if err != nil {
|
||||
return util.ToolChoicePolicy{}, err
|
||||
}
|
||||
if hasAllowedOverride {
|
||||
filtered := make([]string, 0, len(allowedOverride))
|
||||
for _, name := range allowedOverride {
|
||||
if _, ok := declaredSet[name]; !ok {
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice.allowed_tools contains undeclared tool %q", name)
|
||||
}
|
||||
filtered = append(filtered, name)
|
||||
}
|
||||
policy.Allowed = namesToSet(filtered)
|
||||
}
|
||||
|
||||
typ := strings.ToLower(strings.TrimSpace(asString(v["type"])))
|
||||
switch typ {
|
||||
case "", "auto":
|
||||
if hasFunctionSelector(v) {
|
||||
name, err := parseForcedToolName(v)
|
||||
if err != nil {
|
||||
return util.ToolChoicePolicy{}, err
|
||||
}
|
||||
policy.Mode = util.ToolChoiceForced
|
||||
policy.ForcedName = name
|
||||
policy.Allowed = namesToSet([]string{name})
|
||||
} else {
|
||||
policy.Mode = util.ToolChoiceAuto
|
||||
}
|
||||
case "none":
|
||||
policy.Mode = util.ToolChoiceNone
|
||||
policy.Allowed = nil
|
||||
case "required":
|
||||
policy.Mode = util.ToolChoiceRequired
|
||||
case "function":
|
||||
name, err := parseForcedToolName(v)
|
||||
if err != nil {
|
||||
return util.ToolChoicePolicy{}, err
|
||||
}
|
||||
policy.Mode = util.ToolChoiceForced
|
||||
policy.ForcedName = name
|
||||
policy.Allowed = namesToSet([]string{name})
|
||||
default:
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice.type: %q", typ)
|
||||
}
|
||||
default:
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice must be a string or object")
|
||||
}
|
||||
|
||||
if policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced {
|
||||
if len(declaredNames) == 0 {
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice=%s requires non-empty tools.", policy.Mode)
|
||||
}
|
||||
}
|
||||
if policy.Mode == util.ToolChoiceForced {
|
||||
if _, ok := declaredSet[policy.ForcedName]; !ok {
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice forced function %q is not declared in tools", policy.ForcedName)
|
||||
}
|
||||
}
|
||||
if len(policy.Allowed) == 0 && (policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced) {
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice policy resolved to empty allowed tool set")
|
||||
}
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func parseForcedToolName(v map[string]any) (string, error) {
|
||||
if name := strings.TrimSpace(asString(v["name"])); name != "" {
|
||||
return name, nil
|
||||
}
|
||||
if fn, ok := v["function"].(map[string]any); ok {
|
||||
if name := strings.TrimSpace(asString(fn["name"])); name != "" {
|
||||
return name, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("tool_choice function requires name")
|
||||
}
|
||||
|
||||
func parseAllowedToolNames(raw any) ([]string, bool, error) {
|
||||
if raw == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
collectName := func(v any) string {
|
||||
if name := strings.TrimSpace(asString(v)); name != "" {
|
||||
return name
|
||||
}
|
||||
if m, ok := v.(map[string]any); ok {
|
||||
if name := strings.TrimSpace(asString(m["name"])); name != "" {
|
||||
return name
|
||||
}
|
||||
if fn, ok := m["function"].(map[string]any); ok {
|
||||
if name := strings.TrimSpace(asString(fn["name"])); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
names := []string{}
|
||||
switch x := raw.(type) {
|
||||
case []any:
|
||||
for _, item := range x {
|
||||
name := collectName(item)
|
||||
if name == "" {
|
||||
return nil, true, fmt.Errorf("tool_choice.allowed_tools contains invalid item")
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
case []string:
|
||||
for _, item := range x {
|
||||
name := strings.TrimSpace(item)
|
||||
if name == "" {
|
||||
return nil, true, fmt.Errorf("tool_choice.allowed_tools contains empty name")
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
default:
|
||||
return nil, true, fmt.Errorf("tool_choice.allowed_tools must be an array")
|
||||
}
|
||||
|
||||
if len(names) == 0 {
|
||||
return nil, true, fmt.Errorf("tool_choice.allowed_tools must not be empty")
|
||||
}
|
||||
return names, true, nil
|
||||
}
|
||||
|
||||
func hasFunctionSelector(v map[string]any) bool {
|
||||
if strings.TrimSpace(asString(v["name"])) != "" {
|
||||
return true
|
||||
}
|
||||
if fn, ok := v["function"].(map[string]any); ok {
|
||||
return strings.TrimSpace(asString(fn["name"])) != ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractDeclaredToolNames(toolsRaw any) []string {
|
||||
tools, ok := toolsRaw.([]any)
|
||||
if !ok || len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(tools))
|
||||
seen := map[string]struct{}{}
|
||||
for _, t := range tools {
|
||||
tool, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fn, _ := tool["function"].(map[string]any)
|
||||
if len(fn) == 0 {
|
||||
fn = tool
|
||||
}
|
||||
name := strings.TrimSpace(asString(fn["name"]))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
out = append(out, name)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func namesToSet(names []string) map[string]struct{} {
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]struct{}, len(names))
|
||||
for _, name := range names {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
out[trimmed] = struct{}{}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store {
|
||||
@@ -58,3 +59,95 @@ func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
|
||||
t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceRequired(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"tool_choice": "required",
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ToolChoice.Mode != util.ToolChoiceRequired {
|
||||
t.Fatalf("expected tool choice mode required, got %q", n.ToolChoice.Mode)
|
||||
}
|
||||
if len(n.ToolNames) != 1 || n.ToolNames[0] != "search" {
|
||||
t.Fatalf("unexpected tool names: %#v", n.ToolNames)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceForcedFunction(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "read_file",
|
||||
},
|
||||
},
|
||||
},
|
||||
"tool_choice": map[string]any{
|
||||
"type": "function",
|
||||
"name": "read_file",
|
||||
},
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ToolChoice.Mode != util.ToolChoiceForced {
|
||||
t.Fatalf("expected tool choice mode forced, got %q", n.ToolChoice.Mode)
|
||||
}
|
||||
if n.ToolChoice.ForcedName != "read_file" {
|
||||
t.Fatalf("expected forced tool name read_file, got %q", n.ToolChoice.ForcedName)
|
||||
}
|
||||
if len(n.ToolNames) != 1 || n.ToolNames[0] != "read_file" {
|
||||
t.Fatalf("expected filtered tool names [read_file], got %#v", n.ToolNames)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
},
|
||||
},
|
||||
},
|
||||
"tool_choice": map[string]any{
|
||||
"type": "function",
|
||||
"name": "read_file",
|
||||
},
|
||||
}
|
||||
if _, err := normalizeOpenAIResponsesRequest(store, req, ""); err == nil {
|
||||
t.Fatalf("expected forced undeclared tool to fail")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,12 +27,7 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
|
||||
"text": finalThinking,
|
||||
})
|
||||
}
|
||||
formatted := util.FormatOpenAIToolCalls(detected)
|
||||
output = append(output, toResponsesFunctionCallItems(formatted)...)
|
||||
output = append(output, map[string]any{
|
||||
"type": "tool_calls",
|
||||
"tool_calls": formatted,
|
||||
})
|
||||
output = append(output, toResponsesFunctionCallItems(detected)...)
|
||||
} else {
|
||||
content := make([]any, 0, 2)
|
||||
if finalThinking != "" {
|
||||
@@ -70,32 +65,23 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
|
||||
}
|
||||
}
|
||||
|
||||
func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
|
||||
func toResponsesFunctionCallItems(toolCalls []util.ParsedToolCall) []any {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
callID, _ := tc["id"].(string)
|
||||
if strings.TrimSpace(callID) == "" {
|
||||
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
}
|
||||
name := ""
|
||||
args := "{}"
|
||||
if fn, ok := tc["function"].(map[string]any); ok {
|
||||
if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
|
||||
name = n
|
||||
}
|
||||
if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
|
||||
args = a
|
||||
}
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
argsBytes, _ := json.Marshal(tc.Input)
|
||||
args := normalizeJSONString(string(argsBytes))
|
||||
out = append(out, map[string]any{
|
||||
"id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
"type": "function_call",
|
||||
"call_id": callID,
|
||||
"name": name,
|
||||
"arguments": normalizeJSONString(args),
|
||||
"call_id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
"name": tc.Name,
|
||||
"arguments": args,
|
||||
"status": "completed",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.created",
|
||||
@@ -11,6 +13,52 @@ func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesOutputItemAddedPayload(responseID, itemID string, outputIndex int, item map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_item.added",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"output_index": outputIndex,
|
||||
"item_id": itemID,
|
||||
"item": item,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesOutputItemDonePayload(responseID, itemID string, outputIndex int, item map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_item.done",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"output_index": outputIndex,
|
||||
"item_id": itemID,
|
||||
"item": item,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesContentPartAddedPayload(responseID, itemID string, outputIndex, contentIndex int, part map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.content_part.added",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"content_index": contentIndex,
|
||||
"part": part,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesContentPartDonePayload(responseID, itemID string, outputIndex, contentIndex int, part map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.content_part.done",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"content_index": contentIndex,
|
||||
"part": part,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_text.delta",
|
||||
@@ -29,48 +77,6 @@ func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]an
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.reasoning_text.delta",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"content_index": contentIndex,
|
||||
"delta": delta,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.reasoning_text.done",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"content_index": contentIndex,
|
||||
"text": text,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_tool_call.delta",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"tool_calls": toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_tool_call.done",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"tool_calls": toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.function_call_arguments.delta",
|
||||
@@ -96,6 +102,27 @@ func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, o
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesFailedPayload(responseID, model, message, code string) map[string]any {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
code = "api_error"
|
||||
}
|
||||
return map[string]any{
|
||||
"type": "response.failed",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"object": "response",
|
||||
"model": model,
|
||||
"status": "failed",
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": "invalid_request_error",
|
||||
"code": code,
|
||||
"param": nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
|
||||
responseID, _ := response["id"].(string)
|
||||
return map[string]any{
|
||||
|
||||
@@ -21,8 +21,8 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
|
||||
}
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 2 {
|
||||
t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
||||
}
|
||||
|
||||
first, _ := output[0].(map[string]any)
|
||||
@@ -32,35 +32,10 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
|
||||
if first["call_id"] == "" {
|
||||
t.Fatalf("expected function_call item to have call_id, got %#v", first)
|
||||
}
|
||||
second, _ := output[1].(map[string]any)
|
||||
if second["type"] != "tool_calls" {
|
||||
t.Fatalf("expected second output item type tool_calls, got %#v", second["type"])
|
||||
if first["name"] != "search" {
|
||||
t.Fatalf("unexpected function name: %#v", first["name"])
|
||||
}
|
||||
var toolCalls []map[string]any
|
||||
switch v := second["tool_calls"].(type) {
|
||||
case []map[string]any:
|
||||
toolCalls = v
|
||||
case []any:
|
||||
toolCalls = make([]map[string]any, 0, len(v))
|
||||
for _, item := range v {
|
||||
m, _ := item.(map[string]any)
|
||||
if m != nil {
|
||||
toolCalls = append(toolCalls, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected one tool call, got %#v", second["tool_calls"])
|
||||
}
|
||||
tc := toolCalls[0]
|
||||
if tc["type"] != "function" || tc["id"] == "" {
|
||||
t.Fatalf("unexpected tool call shape: %#v", tc)
|
||||
}
|
||||
fn, _ := tc["function"].(map[string]any)
|
||||
if fn["name"] != "search" {
|
||||
t.Fatalf("unexpected function name: %#v", fn["name"])
|
||||
}
|
||||
argsRaw, _ := fn["arguments"].(string)
|
||||
argsRaw, _ := first["arguments"].(string)
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(argsRaw), &args); err != nil {
|
||||
t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err)
|
||||
@@ -86,8 +61,8 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T)
|
||||
}
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 2 {
|
||||
t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "function_call" {
|
||||
@@ -163,8 +138,8 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
||||
)
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 3 {
|
||||
t.Fatalf("expected reasoning + function_call + tool_calls outputs, got %#v", obj["output"])
|
||||
if len(output) != 2 {
|
||||
t.Fatalf("expected reasoning + function_call outputs, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "reasoning" {
|
||||
@@ -174,8 +149,4 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
||||
if second["type"] != "function_call" {
|
||||
t.Fatalf("expected second output function_call, got %#v", second["type"])
|
||||
}
|
||||
third, _ := output[2].(map[string]any)
|
||||
if third["type"] != "tool_calls" {
|
||||
t.Fatalf("expected third output tool_calls, got %#v", third["type"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,12 +8,48 @@ type StandardRequest struct {
|
||||
Messages []any
|
||||
FinalPrompt string
|
||||
ToolNames []string
|
||||
ToolChoice ToolChoicePolicy
|
||||
Stream bool
|
||||
Thinking bool
|
||||
Search bool
|
||||
PassThrough map[string]any
|
||||
}
|
||||
|
||||
type ToolChoiceMode string
|
||||
|
||||
const (
|
||||
ToolChoiceAuto ToolChoiceMode = "auto"
|
||||
ToolChoiceNone ToolChoiceMode = "none"
|
||||
ToolChoiceRequired ToolChoiceMode = "required"
|
||||
ToolChoiceForced ToolChoiceMode = "forced"
|
||||
)
|
||||
|
||||
type ToolChoicePolicy struct {
|
||||
Mode ToolChoiceMode
|
||||
ForcedName string
|
||||
Allowed map[string]struct{}
|
||||
}
|
||||
|
||||
func DefaultToolChoicePolicy() ToolChoicePolicy {
|
||||
return ToolChoicePolicy{Mode: ToolChoiceAuto}
|
||||
}
|
||||
|
||||
func (p ToolChoicePolicy) IsNone() bool {
|
||||
return p.Mode == ToolChoiceNone
|
||||
}
|
||||
|
||||
func (p ToolChoicePolicy) IsRequired() bool {
|
||||
return p.Mode == ToolChoiceRequired || p.Mode == ToolChoiceForced
|
||||
}
|
||||
|
||||
func (p ToolChoicePolicy) Allows(name string) bool {
|
||||
if len(p.Allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
_, ok := p.Allowed[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r StandardRequest) CompletionPayload(sessionID string) map[string]any {
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
|
||||
@@ -10,38 +10,62 @@ type ParsedToolCall struct {
|
||||
Input map[string]any `json:"input"`
|
||||
}
|
||||
|
||||
type ToolCallParseResult struct {
|
||||
Calls []ParsedToolCall
|
||||
SawToolCallSyntax bool
|
||||
RejectedByPolicy bool
|
||||
RejectedToolNames []string
|
||||
}
|
||||
|
||||
func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
return ParseToolCallsDetailed(text, availableToolNames).Calls
|
||||
}
|
||||
|
||||
func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
|
||||
result := ToolCallParseResult{}
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
text = stripFencedCodeBlocks(text)
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
result.SawToolCallSyntax = strings.Contains(strings.ToLower(text), "tool_calls")
|
||||
|
||||
candidates := buildToolCallCandidates(text)
|
||||
var parsed []ParsedToolCall
|
||||
for _, candidate := range candidates {
|
||||
if tc := parseToolCallsPayload(candidate); len(tc) > 0 {
|
||||
parsed = tc
|
||||
result.SawToolCallSyntax = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(parsed) == 0 {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
|
||||
return filterToolCalls(parsed, availableToolNames)
|
||||
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
||||
result.Calls = calls
|
||||
result.RejectedToolNames = rejectedNames
|
||||
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
|
||||
return result
|
||||
}
|
||||
|
||||
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
return ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls
|
||||
}
|
||||
|
||||
func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
|
||||
result := ToolCallParseResult{}
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
if looksLikeToolExampleContext(trimmed) {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
result.SawToolCallSyntax = strings.Contains(strings.ToLower(trimmed), "tool_calls")
|
||||
candidates := []string{trimmed}
|
||||
for _, candidate := range candidates {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
@@ -52,24 +76,31 @@ func ParseStandaloneToolCalls(text string, availableToolNames []string) []Parsed
|
||||
continue
|
||||
}
|
||||
if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 {
|
||||
return filterToolCalls(parsed, availableToolNames)
|
||||
result.SawToolCallSyntax = true
|
||||
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
||||
result.Calls = calls
|
||||
result.RejectedToolNames = rejectedNames
|
||||
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
|
||||
return result
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
|
||||
func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []ParsedToolCall {
|
||||
func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []string) ([]ParsedToolCall, []string) {
|
||||
allowed := map[string]struct{}{}
|
||||
for _, name := range availableToolNames {
|
||||
allowed[name] = struct{}{}
|
||||
}
|
||||
out := make([]ParsedToolCall, 0, len(parsed))
|
||||
rejectedSet := map[string]struct{}{}
|
||||
for _, tc := range parsed {
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
if len(allowed) > 0 {
|
||||
if _, ok := allowed[tc.Name]; !ok {
|
||||
rejectedSet[tc.Name] = struct{}{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -78,21 +109,11 @@ func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []Par
|
||||
}
|
||||
out = append(out, tc)
|
||||
}
|
||||
// If the model clearly emitted tool_calls JSON but all names are outside the
|
||||
// declared set, keep the parsed calls as a fallback so upper layers can still
|
||||
// intercept structured tool output instead of leaking raw JSON to users.
|
||||
if len(out) == 0 && len(parsed) > 0 {
|
||||
for _, tc := range parsed {
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
if tc.Input == nil {
|
||||
tc.Input = map[string]any{}
|
||||
}
|
||||
out = append(out, tc)
|
||||
}
|
||||
rejected := make([]string, 0, len(rejectedSet))
|
||||
for name := range rejectedSet {
|
||||
rejected = append(rejected, name)
|
||||
}
|
||||
return out
|
||||
return out, rejected
|
||||
}
|
||||
|
||||
func parseToolCallsPayload(payload string) []ParsedToolCall {
|
||||
|
||||
@@ -38,14 +38,25 @@ func TestParseToolCallsWithFunctionArgumentsString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsKeepsUnknownAsFallback(t *testing.T) {
|
||||
func TestParseToolCallsRejectsUnknownToolName(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
|
||||
calls := ParseToolCalls(text, []string{"search"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected fallback 1 call, got %d", len(calls))
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected unknown tool to be rejected, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "unknown" {
|
||||
t.Fatalf("unexpected name: %s", calls[0].Name)
|
||||
}
|
||||
|
||||
func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
|
||||
res := ParseToolCallsDetailed(text, []string{"search"})
|
||||
if !res.SawToolCallSyntax {
|
||||
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
|
||||
}
|
||||
if !res.RejectedByPolicy {
|
||||
t.Fatalf("expected RejectedByPolicy=true, got %#v", res)
|
||||
}
|
||||
if len(res.Calls) != 0 {
|
||||
t.Fatalf("expected no calls after policy rejection, got %#v", res.Calls)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user