fix: Ensure incomplete tool call items are properly closed and required tool choice failures are correctly handled for malformed payloads.

This commit is contained in:
CJACK
2026-02-22 21:27:42 +08:00
parent a9403c5392
commit 6c318f1910
3 changed files with 124 additions and 1 deletions

View File

@@ -39,6 +39,7 @@ type responsesStreamRuntime struct {
streamToolCallIDs map[int]string
functionItemIDs map[int]string
functionOutputIDs map[int]int
functionArgs map[int]string
functionDone map[int]bool
functionAdded map[int]bool
functionNames map[int]string
@@ -84,6 +85,7 @@ func newResponsesStreamRuntime(
streamToolCallIDs: map[int]string{},
functionItemIDs: map[int]string{},
functionOutputIDs: map[int]int{},
functionArgs: map[int]string{},
functionDone: map[int]bool{},
functionAdded: map[int]bool{},
functionNames: map[int]string{},
@@ -120,7 +122,7 @@ func (s *responsesStreamRuntime) finalize() {
s.closeMessageItem()
if s.toolChoice.IsRequired() && !s.hasFunctionCallDone() {
if s.toolChoice.IsRequired() && len(detected) == 0 {
s.failed = true
message := "tool_choice requires at least one valid tool call."
failedResp := map[string]any{
@@ -145,6 +147,7 @@ func (s *responsesStreamRuntime) finalize() {
s.sendDone()
return
}
s.closeIncompleteFunctionItems()
obj := s.buildCompletedResponseObject(finalThinking, finalText, detected)
if s.persistResponse != nil {

View File

@@ -188,6 +188,7 @@ func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDe
if strings.TrimSpace(d.Arguments) == "" {
continue
}
s.functionArgs[d.Index] += d.Arguments
outputIndex := s.ensureFunctionOutputIndex(d.Index)
itemID := s.ensureFunctionItemID(d.Index)
callID := s.ensureToolCallID(d.Index)
@@ -212,6 +213,7 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT
callID := s.ensureToolCallID(idx)
argsBytes, _ := json.Marshal(tc.Input)
args := string(argsBytes)
s.functionArgs[idx] = args
s.sendEvent(
"response.function_call_arguments.done",
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, args),
@@ -233,6 +235,54 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT
}
}
func (s *responsesStreamRuntime) closeIncompleteFunctionItems() {
if len(s.functionAdded) == 0 {
return
}
indices := make([]int, 0, len(s.functionAdded))
for idx, added := range s.functionAdded {
if !added || s.functionDone[idx] {
continue
}
indices = append(indices, idx)
}
if len(indices) == 0 {
return
}
sort.Ints(indices)
for _, idx := range indices {
name := strings.TrimSpace(s.functionNames[idx])
if name == "" {
continue
}
args := strings.TrimSpace(s.functionArgs[idx])
if args == "" {
args = "{}"
}
outputIndex := s.ensureFunctionOutputIndex(idx)
itemID := s.ensureFunctionItemID(idx)
callID := s.ensureToolCallID(idx)
s.sendEvent(
"response.function_call_arguments.done",
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, name, args),
)
item := map[string]any{
"id": itemID,
"type": "function_call",
"call_id": callID,
"name": name,
"arguments": args,
"status": "completed",
}
s.sendEvent(
"response.output_item.done",
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
)
s.functionDone[idx] = true
s.toolCallsDoneEmitted = true
}
}
func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, finalText string, calls []util.ParsedToolCall) map[string]any {
type indexedItem struct {
index int

View File

@@ -360,6 +360,42 @@ func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
}
}
func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
// invalid JSON (NaN) can still trigger incremental tool deltas before final parse rejects it
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
t.Fatalf("expected response.function_call_arguments.delta event for malformed payload, body=%s", body)
}
if !strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("expected runtime to close in-progress function_call with done event, body=%s", body)
}
if !strings.Contains(body, "event: response.output_item.done") {
t.Fatalf("expected runtime to close function output item, body=%s", body)
}
if !strings.Contains(body, "event: response.completed") {
t.Fatalf("expected response.completed event, body=%s", body)
}
}
func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
@@ -394,6 +430,40 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
}
}
func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
policy := util.ToolChoicePolicy{
Mode: util.ToolChoiceRequired,
Allowed: map[string]struct{}{"read_file": {}},
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.failed") {
t.Fatalf("expected response.failed event, body=%s", body)
}
if strings.Contains(body, "event: response.completed") {
t.Fatalf("did not expect response.completed, body=%s", body)
}
}
func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)