feat: Improve OpenAI tool call handling by passing unknown tool calls as content and filtering streamed tool calls by schema.

This commit is contained in:
CJACK
2026-02-22 19:33:52 +08:00
parent 312728c8b6
commit ae7dce0b32
26 changed files with 1109 additions and 501 deletions

View File

@@ -27,12 +27,7 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
"text": finalThinking,
})
}
formatted := util.FormatOpenAIToolCalls(detected)
output = append(output, toResponsesFunctionCallItems(formatted)...)
output = append(output, map[string]any{
"type": "tool_calls",
"tool_calls": formatted,
})
output = append(output, toResponsesFunctionCallItems(detected)...)
} else {
content := make([]any, 0, 2)
if finalThinking != "" {
@@ -70,32 +65,23 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
}
}
func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
func toResponsesFunctionCallItems(toolCalls []util.ParsedToolCall) []any {
if len(toolCalls) == 0 {
return nil
}
out := make([]any, 0, len(toolCalls))
for _, tc := range toolCalls {
callID, _ := tc["id"].(string)
if strings.TrimSpace(callID) == "" {
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
}
name := ""
args := "{}"
if fn, ok := tc["function"].(map[string]any); ok {
if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
name = n
}
if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
args = a
}
if strings.TrimSpace(tc.Name) == "" {
continue
}
argsBytes, _ := json.Marshal(tc.Input)
args := normalizeJSONString(string(argsBytes))
out = append(out, map[string]any{
"id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"type": "function_call",
"call_id": callID,
"name": name,
"arguments": normalizeJSONString(args),
"call_id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"name": tc.Name,
"arguments": args,
"status": "completed",
})
}

View File

@@ -1,5 +1,7 @@
package openai
import "strings"
func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
return map[string]any{
"type": "response.created",
@@ -11,6 +13,52 @@ func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
}
}
func BuildResponsesOutputItemAddedPayload(responseID, itemID string, outputIndex int, item map[string]any) map[string]any {
return map[string]any{
"type": "response.output_item.added",
"id": responseID,
"response_id": responseID,
"output_index": outputIndex,
"item_id": itemID,
"item": item,
}
}
func BuildResponsesOutputItemDonePayload(responseID, itemID string, outputIndex int, item map[string]any) map[string]any {
return map[string]any{
"type": "response.output_item.done",
"id": responseID,
"response_id": responseID,
"output_index": outputIndex,
"item_id": itemID,
"item": item,
}
}
func BuildResponsesContentPartAddedPayload(responseID, itemID string, outputIndex, contentIndex int, part map[string]any) map[string]any {
return map[string]any{
"type": "response.content_part.added",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"part": part,
}
}
func BuildResponsesContentPartDonePayload(responseID, itemID string, outputIndex, contentIndex int, part map[string]any) map[string]any {
return map[string]any{
"type": "response.content_part.done",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"part": part,
}
}
func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
return map[string]any{
"type": "response.output_text.delta",
@@ -29,48 +77,6 @@ func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]an
}
}
func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
return map[string]any{
"type": "response.reasoning_text.delta",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"delta": delta,
}
}
func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
return map[string]any{
"type": "response.reasoning_text.done",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"text": text,
}
}
func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
return map[string]any{
"type": "response.output_tool_call.delta",
"id": responseID,
"response_id": responseID,
"tool_calls": toolCalls,
}
}
func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
return map[string]any{
"type": "response.output_tool_call.done",
"id": responseID,
"response_id": responseID,
"tool_calls": toolCalls,
}
}
func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
return map[string]any{
"type": "response.function_call_arguments.delta",
@@ -96,6 +102,27 @@ func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, o
}
}
func BuildResponsesFailedPayload(responseID, model, message, code string) map[string]any {
code = strings.TrimSpace(code)
if code == "" {
code = "api_error"
}
return map[string]any{
"type": "response.failed",
"id": responseID,
"response_id": responseID,
"object": "response",
"model": model,
"status": "failed",
"error": map[string]any{
"message": message,
"type": "invalid_request_error",
"code": code,
"param": nil,
},
}
}
func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
responseID, _ := response["id"].(string)
return map[string]any{

View File

@@ -21,8 +21,8 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
}
output, _ := obj["output"].([]any)
if len(output) != 2 {
t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
if len(output) != 1 {
t.Fatalf("expected function_call output only, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
@@ -32,35 +32,10 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
if first["call_id"] == "" {
t.Fatalf("expected function_call item to have call_id, got %#v", first)
}
second, _ := output[1].(map[string]any)
if second["type"] != "tool_calls" {
t.Fatalf("expected second output item type tool_calls, got %#v", second["type"])
if first["name"] != "search" {
t.Fatalf("unexpected function name: %#v", first["name"])
}
var toolCalls []map[string]any
switch v := second["tool_calls"].(type) {
case []map[string]any:
toolCalls = v
case []any:
toolCalls = make([]map[string]any, 0, len(v))
for _, item := range v {
m, _ := item.(map[string]any)
if m != nil {
toolCalls = append(toolCalls, m)
}
}
}
if len(toolCalls) != 1 {
t.Fatalf("expected one tool call, got %#v", second["tool_calls"])
}
tc := toolCalls[0]
if tc["type"] != "function" || tc["id"] == "" {
t.Fatalf("unexpected tool call shape: %#v", tc)
}
fn, _ := tc["function"].(map[string]any)
if fn["name"] != "search" {
t.Fatalf("unexpected function name: %#v", fn["name"])
}
argsRaw, _ := fn["arguments"].(string)
argsRaw, _ := first["arguments"].(string)
var args map[string]any
if err := json.Unmarshal([]byte(argsRaw), &args); err != nil {
t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err)
@@ -86,8 +61,8 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T)
}
output, _ := obj["output"].([]any)
if len(output) != 2 {
t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
if len(output) != 1 {
t.Fatalf("expected function_call output only, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "function_call" {
@@ -163,8 +138,8 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
)
output, _ := obj["output"].([]any)
if len(output) != 3 {
t.Fatalf("expected reasoning + function_call + tool_calls outputs, got %#v", obj["output"])
if len(output) != 2 {
t.Fatalf("expected reasoning + function_call outputs, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "reasoning" {
@@ -174,8 +149,4 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
if second["type"] != "function_call" {
t.Fatalf("expected second output function_call, got %#v", second["type"])
}
third, _ := output[2].(map[string]any)
if third["type"] != "tool_calls" {
t.Fatalf("expected third output tool_calls, got %#v", third["type"])
}
}