diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go index d59ea66..563b0f2 100644 --- a/internal/adapter/openai/chat_stream_runtime.go +++ b/internal/adapter/openai/chat_stream_runtime.go @@ -190,7 +190,10 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD if parsed.OutputTokens > 0 { s.outputTokens = parsed.OutputTokens } - if parsed.ContentFilter || parsed.ErrorMessage != "" { + if parsed.ContentFilter { + return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReasonHandlerRequested} + } + if parsed.ErrorMessage != "" { return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("content_filter")} } if parsed.Stop { diff --git a/internal/adapter/openai/stream_status_test.go b/internal/adapter/openai/stream_status_test.go index c76d881..2a3584b 100644 --- a/internal/adapter/openai/stream_status_test.go +++ b/internal/adapter/openai/stream_status_test.go @@ -183,3 +183,53 @@ func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) { t.Fatalf("expected function_call output item, got %#v", output) } } + +func TestChatCompletionsStreamContentFilterStopsNormallyWithoutLeak(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse( + `data: {"p":"response/content","v":"合法前缀"}`, + `data: {"p":"response/status","v":"CONTENT_FILTER","accumulated_token_usage":77}`, + `data: {"p":"response/content","v":"CONTENT_FILTER你好,这个问题我暂时无法回答,让我们换个话题再聊聊吧。"}`, + )}, + } + r := chi.NewRouter() + r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}` + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 || statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200, got %#v", statuses) + } + if strings.Contains(rec.Body.String(), "这个问题我暂时无法回答") { + t.Fatalf("expected leaked content-filter suffix to be hidden, body=%s", rec.Body.String()) + } + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if len(frames) == 0 { + t.Fatalf("expected at least one json frame, body=%s", rec.Body.String()) + } + last := frames[len(frames)-1] + choices, _ := last["choices"].([]any) + if len(choices) != 1 { + t.Fatalf("expected one choice in final frame, got %#v", last) + } + choice, _ := choices[0].(map[string]any) + if choice["finish_reason"] != "stop" { + t.Fatalf("expected finish_reason=stop for content-filter upstream stop, got %#v", choice["finish_reason"]) + } +} diff --git a/internal/sse/line.go b/internal/sse/line.go index b91edc7..1d9ddae 100644 --- a/internal/sse/line.go +++ b/internal/sse/line.go @@ -36,8 +36,8 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri Parsed: true, Stop: true, ContentFilter: true, - ErrorMessage: "content filtered by upstream", NextType: currentType, + OutputTokens: extractAccumulatedTokenUsage(chunk), } } if hasContentFilterStatus(chunk) { @@ -45,7 +45,6 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri Parsed: true, Stop: true, ContentFilter: true, - ErrorMessage: "content filtered by upstream", NextType: currentType, OutputTokens: extractAccumulatedTokenUsage(chunk), } diff --git a/internal/sse/line_edge_test.go b/internal/sse/line_edge_test.go index 2ae53a6..4d507fc 100644 --- a/internal/sse/line_edge_test.go +++ b/internal/sse/line_edge_test.go @@ -40,8 +40,8 @@ func TestParseDeepSeekContentLineContentFilterMessage(t *testing.T) { if !res.ContentFilter { t.Fatal("expected content filter flag") } - if res.ErrorMessage == "" { - t.Fatal("expected error message on content filter") + if res.ErrorMessage != "" { + t.Fatalf("expected empty error message on content filter, got %q", res.ErrorMessage) } } diff --git a/internal/sse/line_test.go b/internal/sse/line_test.go index a226034..7f2baa6 100644 --- a/internal/sse/line_test.go +++ b/internal/sse/line_test.go @@ -26,6 +26,19 @@ func TestParseDeepSeekContentLineContentFilter(t *testing.T) { } } +func TestParseDeepSeekContentLineContentFilterCodeIncludesOutputTokens(t *testing.T) { + res := ParseDeepSeekContentLine( + []byte(`data: {"code":"content_filter","accumulated_token_usage":99}`), + false, "text", + ) + if !res.Parsed || !res.Stop || !res.ContentFilter { + t.Fatalf("expected content-filter stop result: %#v", res) + } + if res.OutputTokens != 99 { + t.Fatalf("expected output token usage 99, got %d", res.OutputTokens) + } +} + func TestParseDeepSeekContentLineContentFilterStatus(t *testing.T) { res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/status","v":"CONTENT_FILTER"}`), false, "text") if !res.Parsed || !res.Stop || !res.ContentFilter { @@ -79,3 +92,13 @@ func TestParseDeepSeekContentLineTrimsFromContentFilterKeyword(t *testing.T) { t.Fatalf("unexpected parts after filter: %#v", res.Parts) } } + +func TestParseDeepSeekContentLineContentTextEqualContentFilterDoesNotStop(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"content_filter"}`), false, "text") + if !res.Parsed { + t.Fatalf("expected parsed result: %#v", res) + } + if res.Stop || res.ContentFilter { + t.Fatalf("did not expect content-filter stop for content text: %#v", res) + } +} diff --git a/internal/sse/parser.go b/internal/sse/parser.go index 725ac1f..1074a34 100644 --- a/internal/sse/parser.go +++ b/internal/sse/parser.go @@ -290,16 +290,17 @@ func IsCitation(text string) bool { } func hasContentFilterStatus(chunk map[string]any) bool { - return hasContentFilterValue(chunk) + if code, _ := chunk["code"].(string); strings.EqualFold(strings.TrimSpace(code), "content_filter") { + return true + } + return hasContentFilterStatusValue(chunk) } -func hasContentFilterValue(v any) bool { +func hasContentFilterStatusValue(v any) bool { switch x := v.(type) { - case string: - return strings.EqualFold(strings.TrimSpace(x), "content_filter") case []any: for _, item := range x { - if hasContentFilterValue(item) { + if hasContentFilterStatusValue(item) { return true } } @@ -309,8 +310,11 @@ func hasContentFilterValue(v any) bool { return true } } + if code, _ := x["code"].(string); strings.EqualFold(strings.TrimSpace(code), "content_filter") { + return true + } for _, vv := range x { - if hasContentFilterValue(vv) { + if hasContentFilterStatusValue(vv) { return true } }