diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index 81da92d..e4b1de8 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -113,15 +113,10 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res return } result := sse.CollectStream(resp, thinkingEnabled, true) - textParsed := util.ParseToolCallsDetailed(result.Text, toolNames) - thinkingParsed := util.ParseToolCallsDetailed(result.Thinking, toolNames) + textParsed := util.ParseStandaloneToolCallsDetailed(result.Text, toolNames) logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text") - logResponsesToolPolicyRejection(traceID, toolChoice, thinkingParsed, "thinking") callCount := len(textParsed.Calls) - if callCount == 0 { - callCount = len(thinkingParsed.Calls) - } if toolChoice.IsRequired() && callCount == 0 { writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation") return diff --git a/internal/adapter/openai/responses_stream_runtime_core.go b/internal/adapter/openai/responses_stream_runtime_core.go index 02303d0..e8ec6df 100644 --- a/internal/adapter/openai/responses_stream_runtime_core.go +++ b/internal/adapter/openai/responses_stream_runtime_core.go @@ -102,16 +102,11 @@ func (s *responsesStreamRuntime) finalize() { if s.bufferToolContent { s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true) - s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false) } - textParsed := util.ParseToolCallsDetailed(finalText, s.toolNames) - thinkingParsed := util.ParseToolCallsDetailed(finalThinking, s.toolNames) + textParsed := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames) detected := textParsed.Calls - if len(detected) == 0 { - detected = thinkingParsed.Calls - } - s.logToolPolicyRejections(textParsed, thinkingParsed) + s.logToolPolicyRejections(textParsed) if len(detected) > 0 { s.toolCallsEmitted = true @@ -157,7 +152,7 @@ func (s *responsesStreamRuntime) finalize() { s.sendDone() } -func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) { +func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed util.ToolCallParseResult) { logRejected := func(parsed util.ToolCallParseResult, channel string) { rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames) if !parsed.RejectedByPolicy || len(rejected) == 0 { @@ -172,7 +167,6 @@ func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingPar ) } logRejected(textParsed, "text") - logRejected(thinkingParsed, "thinking") } func (s *responsesStreamRuntime) hasFunctionCallDone() bool { @@ -207,9 +201,6 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa } s.thinking.WriteString(p.Text) s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text)) - if s.bufferToolContent { - s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false) - } continue } diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go index 29cb2a1..90ade96 100644 --- a/internal/adapter/openai/responses_stream_test.go +++ b/internal/adapter/openai/responses_stream_test.go @@ -263,7 +263,7 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) { } } -func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *testing.T) { +func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) rec := httptest.NewRecorder() @@ -288,23 +288,12 @@ func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *tes h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added") - if len(addedPayloads) < 2 { - t.Fatalf("expected message + function_call output_item.added events, got %d body=%s", len(addedPayloads), rec.Body.String()) + if len(addedPayloads) != 1 { + t.Fatalf("expected only one message output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String()) } - - indexes := map[int]struct{}{} - typeByIndex := map[int]string{} - addedIDs := map[string]string{} - for _, payload := range addedPayloads { - item, _ := payload["item"].(map[string]any) - itemType := strings.TrimSpace(asString(item["type"])) - outputIndex := int(asFloat(payload["output_index"])) - if _, exists := indexes[outputIndex]; exists { - t.Fatalf("found duplicated output_index=%d for item types=%q and %q payload=%#v", outputIndex, typeByIndex[outputIndex], itemType, payload) - } - indexes[outputIndex] = struct{}{} - typeByIndex[outputIndex] = itemType - addedIDs[itemType] = strings.TrimSpace(asString(payload["item_id"])) + item, _ := addedPayloads[0]["item"].(map[string]any) + if asString(item["type"]) != "message" { + t.Fatalf("expected only message output item in strict mode, got %#v", item) } completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") @@ -313,21 +302,15 @@ func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *tes } responseObj, _ := completedPayload["response"].(map[string]any) output, _ := responseObj["output"].([]any) - found := map[string]bool{} for _, item := range output { m, _ := item.(map[string]any) - itemType := strings.TrimSpace(asString(m["type"])) - itemID := strings.TrimSpace(asString(m["id"])) - if itemType == "" || itemID == "" { + if m == nil { continue } - if wantID := strings.TrimSpace(addedIDs[itemType]); wantID != "" && wantID == itemID { - found[itemType] = true + if asString(m["type"]) == "function_call" { + t.Fatalf("did not expect function_call output for mixed prose tool example, output=%#v", output) } } - if !found["message"] || !found["function_call"] { - t.Fatalf("expected completed output to contain streamed message/function_call item ids, found=%#v output=%#v", found, output) - } } func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { @@ -424,6 +407,42 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) { } } +func TestHandleResponsesStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(path, value string) string { + b, _ := json.Marshal(map[string]any{ + "p": path, + "v": value, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + + sseLine("response/content", "plain text only") + + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + policy := util.ToolChoicePolicy{ + Mode: util.ToolChoiceRequired, + Allowed: map[string]struct{}{"read_file": {}}, + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "") + body := rec.Body.String() + if !strings.Contains(body, "event: response.failed") { + t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body) + } + if strings.Contains(body, "event: response.completed") { + t.Fatalf("did not expect response.completed after failure, body=%s", body) + } +} + func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) @@ -510,6 +529,33 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) { } } +func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) { + h := &Handler{} + rec := httptest.NewRecorder() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader( + `data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" + + `data: {"p":"response/content","v":"plain text only"}` + "\n" + + `data: [DONE]` + "\n", + )), + } + policy := util.ToolChoicePolicy{ + Mode: util.ToolChoiceRequired, + Allowed: map[string]struct{}{"read_file": {}}, + } + + h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, []string{"read_file"}, policy, "") + if rec.Code != http.StatusUnprocessableEntity { + t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String()) + } + out := decodeJSONBody(t, rec.Body.String()) + errObj, _ := out["error"].(map[string]any) + if asString(errObj["code"]) != "tool_choice_violation" { + t.Fatalf("expected code=tool_choice_violation, got %#v", out) + } +} + func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { h := &Handler{} rec := httptest.NewRecorder()