Fix SSE keep-alive passthrough, content-filter stop, and usage token propagation

This commit is contained in:
CJACK.
2026-04-02 23:58:36 +08:00
parent 443fa4ad8e
commit e958bf7e40
16 changed files with 223 additions and 9 deletions

View File

@@ -10,8 +10,9 @@ import (
// CollectResult holds the aggregated text and thinking content from a
// DeepSeek SSE stream, consumed to completion (non-streaming use case).
type CollectResult struct {
Text string
Thinking string
Text string
Thinking string
OutputTokens int
}
// CollectStream fully consumes a DeepSeek SSE response and separates
@@ -26,6 +27,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
}
text := strings.Builder{}
thinking := strings.Builder{}
outputTokens := 0
currentType := "text"
if thinkingEnabled {
currentType = "thinking"
@@ -37,8 +39,14 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
return true
}
if result.Stop {
if result.OutputTokens > 0 {
outputTokens = result.OutputTokens
}
return false
}
if result.OutputTokens > 0 {
outputTokens = result.OutputTokens
}
for _, p := range result.Parts {
if p.Type == "thinking" {
thinking.WriteString(p.Text)
@@ -48,5 +56,5 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
}
return true
})
return CollectResult{Text: text.String(), Thinking: thinking.String()}
return CollectResult{Text: text.String(), Thinking: thinking.String(), OutputTokens: outputTokens}
}

View File

@@ -138,3 +138,15 @@ func TestCollectStreamStatusFinished(t *testing.T) {
t.Fatalf("expected 'Hello', got %q", result.Text)
}
}
func TestCollectStreamStopsOnContentFilterStatus(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/content\",\"v\":\"safe\"}\n" +
"data: {\"p\":\"response/status\",\"v\":\"CONTENT_FILTER\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\"blocked\"}\n",
)
result := CollectStream(resp, false, false)
if result.Text != "safe" {
t.Fatalf("expected stream to stop before blocked tail, got %q", result.Text)
}
}

View File

@@ -10,6 +10,7 @@ type LineResult struct {
ErrorMessage string
Parts []ContentPart
NextType string
OutputTokens int
}
// ParseDeepSeekContentLine centralizes one-line DeepSeek SSE parsing for both
@@ -39,6 +40,16 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
NextType: currentType,
}
}
if hasContentFilterStatus(chunk) {
return LineResult{
Parsed: true,
Stop: true,
ContentFilter: true,
ErrorMessage: "content filtered by upstream",
NextType: currentType,
OutputTokens: extractAccumulatedTokenUsage(chunk),
}
}
parts, finished, nextType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
parts = filterLeakedContentFilterParts(parts)
return LineResult{
@@ -46,5 +57,6 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
Stop: finished,
Parts: parts,
NextType: nextType,
OutputTokens: extractAccumulatedTokenUsage(chunk),
}
}

View File

@@ -26,6 +26,20 @@ func TestParseDeepSeekContentLineContentFilter(t *testing.T) {
}
}
func TestParseDeepSeekContentLineContentFilterStatus(t *testing.T) {
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/status","v":"CONTENT_FILTER"}`), false, "text")
if !res.Parsed || !res.Stop || !res.ContentFilter {
t.Fatalf("expected status-based content-filter stop result: %#v", res)
}
}
func TestParseDeepSeekContentLineCapturesAccumulatedTokenUsage(t *testing.T) {
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response","o":"BATCH","v":[{"p":"accumulated_token_usage","v":1383},{"p":"quasi_status","v":"FINISHED"}]}`), false, "text")
if res.OutputTokens != 1383 {
t.Fatalf("expected output token usage 1383, got %d", res.OutputTokens)
}
}
func TestParseDeepSeekContentLineContent(t *testing.T) {
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"hi"}`), false, "text")
if !res.Parsed || res.Stop {

View File

@@ -3,6 +3,7 @@ package sse
import (
"bytes"
"encoding/json"
"math"
"strings"
"ds2api/internal/deepseek"
@@ -287,3 +288,86 @@ func extractContentRecursive(items []any, defaultType string) ([]ContentPart, bo
func IsCitation(text string) bool {
return bytes.HasPrefix([]byte(strings.TrimSpace(text)), []byte("[citation:"))
}
func hasContentFilterStatus(chunk map[string]any) bool {
return hasContentFilterValue(chunk)
}
func hasContentFilterValue(v any) bool {
switch x := v.(type) {
case string:
return strings.EqualFold(strings.TrimSpace(x), "content_filter")
case []any:
for _, item := range x {
if hasContentFilterValue(item) {
return true
}
}
case map[string]any:
if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "status") {
if s, _ := x["v"].(string); strings.EqualFold(strings.TrimSpace(s), "content_filter") {
return true
}
}
for _, vv := range x {
if hasContentFilterValue(vv) {
return true
}
}
}
return false
}
func extractAccumulatedTokenUsage(chunk map[string]any) int {
return findAccumulatedTokenUsage(chunk)
}
func findAccumulatedTokenUsage(v any) int {
switch x := v.(type) {
case map[string]any:
if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "accumulated_token_usage") {
if n, ok := toInt(x["v"]); ok && n > 0 {
return n
}
}
if n, ok := toInt(x["accumulated_token_usage"]); ok && n > 0 {
return n
}
for _, vv := range x {
if n := findAccumulatedTokenUsage(vv); n > 0 {
return n
}
}
case []any:
for _, item := range x {
if n := findAccumulatedTokenUsage(item); n > 0 {
return n
}
}
}
return 0
}
func toInt(v any) (int, bool) {
switch x := v.(type) {
case int:
return x, true
case int32:
return int(x), true
case int64:
return int(x), true
case float64:
if math.IsNaN(x) || math.IsInf(x, 0) {
return 0, false
}
return int(x), true
case json.Number:
i, err := x.Int64()
if err != nil {
return 0, false
}
return int(i), true
default:
return 0, false
}
}