mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 00:45:29 +08:00
fix: Ensure incomplete tool call items are properly closed and required tool choice failures are correctly handled for malformed payloads.
This commit is contained in:
@@ -39,6 +39,7 @@ type responsesStreamRuntime struct {
|
||||
streamToolCallIDs map[int]string
|
||||
functionItemIDs map[int]string
|
||||
functionOutputIDs map[int]int
|
||||
functionArgs map[int]string
|
||||
functionDone map[int]bool
|
||||
functionAdded map[int]bool
|
||||
functionNames map[int]string
|
||||
@@ -84,6 +85,7 @@ func newResponsesStreamRuntime(
|
||||
streamToolCallIDs: map[int]string{},
|
||||
functionItemIDs: map[int]string{},
|
||||
functionOutputIDs: map[int]int{},
|
||||
functionArgs: map[int]string{},
|
||||
functionDone: map[int]bool{},
|
||||
functionAdded: map[int]bool{},
|
||||
functionNames: map[int]string{},
|
||||
@@ -120,7 +122,7 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
|
||||
s.closeMessageItem()
|
||||
|
||||
if s.toolChoice.IsRequired() && !s.hasFunctionCallDone() {
|
||||
if s.toolChoice.IsRequired() && len(detected) == 0 {
|
||||
s.failed = true
|
||||
message := "tool_choice requires at least one valid tool call."
|
||||
failedResp := map[string]any{
|
||||
@@ -145,6 +147,7 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
s.sendDone()
|
||||
return
|
||||
}
|
||||
s.closeIncompleteFunctionItems()
|
||||
|
||||
obj := s.buildCompletedResponseObject(finalThinking, finalText, detected)
|
||||
if s.persistResponse != nil {
|
||||
|
||||
@@ -188,6 +188,7 @@ func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDe
|
||||
if strings.TrimSpace(d.Arguments) == "" {
|
||||
continue
|
||||
}
|
||||
s.functionArgs[d.Index] += d.Arguments
|
||||
outputIndex := s.ensureFunctionOutputIndex(d.Index)
|
||||
itemID := s.ensureFunctionItemID(d.Index)
|
||||
callID := s.ensureToolCallID(d.Index)
|
||||
@@ -212,6 +213,7 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT
|
||||
callID := s.ensureToolCallID(idx)
|
||||
argsBytes, _ := json.Marshal(tc.Input)
|
||||
args := string(argsBytes)
|
||||
s.functionArgs[idx] = args
|
||||
s.sendEvent(
|
||||
"response.function_call_arguments.done",
|
||||
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, args),
|
||||
@@ -233,6 +235,54 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) closeIncompleteFunctionItems() {
|
||||
if len(s.functionAdded) == 0 {
|
||||
return
|
||||
}
|
||||
indices := make([]int, 0, len(s.functionAdded))
|
||||
for idx, added := range s.functionAdded {
|
||||
if !added || s.functionDone[idx] {
|
||||
continue
|
||||
}
|
||||
indices = append(indices, idx)
|
||||
}
|
||||
if len(indices) == 0 {
|
||||
return
|
||||
}
|
||||
sort.Ints(indices)
|
||||
for _, idx := range indices {
|
||||
name := strings.TrimSpace(s.functionNames[idx])
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
args := strings.TrimSpace(s.functionArgs[idx])
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
outputIndex := s.ensureFunctionOutputIndex(idx)
|
||||
itemID := s.ensureFunctionItemID(idx)
|
||||
callID := s.ensureToolCallID(idx)
|
||||
s.sendEvent(
|
||||
"response.function_call_arguments.done",
|
||||
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, name, args),
|
||||
)
|
||||
item := map[string]any{
|
||||
"id": itemID,
|
||||
"type": "function_call",
|
||||
"call_id": callID,
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
"status": "completed",
|
||||
}
|
||||
s.sendEvent(
|
||||
"response.output_item.done",
|
||||
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
|
||||
)
|
||||
s.functionDone[idx] = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, finalText string, calls []util.ParsedToolCall) map[string]any {
|
||||
type indexedItem struct {
|
||||
index int
|
||||
|
||||
@@ -360,6 +360,42 @@ func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
// invalid JSON (NaN) can still trigger incremental tool deltas before final parse rejects it
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
|
||||
t.Fatalf("expected response.function_call_arguments.delta event for malformed payload, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected runtime to close in-progress function_call with done event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.output_item.done") {
|
||||
t.Fatalf("expected runtime to close function output item, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("expected response.completed event, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
@@ -394,6 +430,40 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
t.Fatalf("expected response.failed event, body=%s", body)
|
||||
}
|
||||
if strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("did not expect response.completed, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
Reference in New Issue
Block a user