mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-04 16:35:27 +08:00
Fix SSE keep-alive passthrough, content-filter stop, and usage token propagation
This commit is contained in:
@@ -90,6 +90,11 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
result.Text,
|
||||
stdReq.ToolNames,
|
||||
)
|
||||
if result.OutputTokens > 0 {
|
||||
if usage, ok := respBody["usage"].(map[string]any); ok {
|
||||
usage["output_tokens"] = result.OutputTokens
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ type claudeStreamRuntime struct {
|
||||
messageID string
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
outputTokens int
|
||||
|
||||
nextBlockIndex int
|
||||
thinkingBlockOpen bool
|
||||
@@ -66,6 +67,9 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.OutputTokens > 0 {
|
||||
s.outputTokens = parsed.OutputTokens
|
||||
}
|
||||
if parsed.ErrorMessage != "" {
|
||||
s.upstreamErr = parsed.ErrorMessage
|
||||
return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}
|
||||
|
||||
@@ -108,6 +108,9 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
}
|
||||
|
||||
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
|
||||
if s.outputTokens > 0 {
|
||||
outputTokens = s.outputTokens
|
||||
}
|
||||
s.send("message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
|
||||
@@ -174,12 +174,12 @@ func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *ht
|
||||
}
|
||||
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
|
||||
writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames, result.OutputTokens))
|
||||
}
|
||||
|
||||
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string, outputTokens int) map[string]any {
|
||||
parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
|
||||
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
|
||||
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText, outputTokens)
|
||||
return map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
@@ -196,10 +196,14 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
func buildGeminiUsage(finalPrompt, finalThinking, finalText string, outputTokens int) map[string]any {
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
if outputTokens > 0 {
|
||||
completionTokens = outputTokens
|
||||
reasoningTokens = 0
|
||||
}
|
||||
return map[string]any{
|
||||
"promptTokenCount": promptTokens,
|
||||
"candidatesTokenCount": reasoningTokens + completionTokens,
|
||||
|
||||
@@ -64,6 +64,7 @@ type geminiStreamRuntime struct {
|
||||
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
outputTokens int
|
||||
}
|
||||
|
||||
func newGeminiStreamRuntime(
|
||||
@@ -103,6 +104,9 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.OutputTokens > 0 {
|
||||
s.outputTokens = parsed.OutputTokens
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
@@ -176,6 +180,6 @@ func (s *geminiStreamRuntime) finalize() {
|
||||
},
|
||||
},
|
||||
"modelVersion": s.model,
|
||||
"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
|
||||
"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText, s.outputTokens),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ type chatStreamRuntime struct {
|
||||
streamToolNames map[int]string
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
outputTokens int
|
||||
}
|
||||
|
||||
func newChatStreamRuntime(
|
||||
@@ -165,12 +166,19 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
if len(detected.Calls) > 0 || s.toolCallsEmitted {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText)
|
||||
if s.outputTokens > 0 {
|
||||
usage["completion_tokens"] = s.outputTokens
|
||||
if prompt, ok := usage["prompt_tokens"].(int); ok {
|
||||
usage["total_tokens"] = prompt + s.outputTokens
|
||||
}
|
||||
}
|
||||
s.sendChunk(openaifmt.BuildChatStreamChunk(
|
||||
s.completionID,
|
||||
s.created,
|
||||
s.model,
|
||||
[]map[string]any{openaifmt.BuildChatStreamFinishChoice(0, finishReason)},
|
||||
openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText),
|
||||
usage,
|
||||
))
|
||||
s.sendDone()
|
||||
}
|
||||
@@ -179,6 +187,9 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.OutputTokens > 0 {
|
||||
s.outputTokens = parsed.OutputTokens
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" {
|
||||
return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("content_filter")}
|
||||
}
|
||||
|
||||
@@ -107,6 +107,14 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
|
||||
finalThinking := result.Thinking
|
||||
finalText := sanitizeLeakedOutput(result.Text)
|
||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
if result.OutputTokens > 0 {
|
||||
if usage, ok := respBody["usage"].(map[string]any); ok {
|
||||
usage["completion_tokens"] = result.OutputTokens
|
||||
if prompt, ok := usage["prompt_tokens"].(int); ok {
|
||||
usage["total_tokens"] = prompt + result.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
|
||||
@@ -124,6 +124,14 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
}
|
||||
|
||||
responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, sanitizedText, toolNames)
|
||||
if result.OutputTokens > 0 {
|
||||
if usage, ok := responseObj["usage"].(map[string]any); ok {
|
||||
usage["output_tokens"] = result.OutputTokens
|
||||
if input, ok := usage["input_tokens"].(int); ok {
|
||||
usage["total_tokens"] = input + result.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
h.getResponseStore().put(owner, responseID, responseObj)
|
||||
writeJSON(w, http.StatusOK, responseObj)
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ type responsesStreamRuntime struct {
|
||||
messagePartAdded bool
|
||||
sequence int
|
||||
failed bool
|
||||
outputTokens int
|
||||
|
||||
persistResponse func(obj map[string]any)
|
||||
}
|
||||
@@ -144,6 +145,14 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
s.closeIncompleteFunctionItems()
|
||||
|
||||
obj := s.buildCompletedResponseObject(finalThinking, finalText, detected)
|
||||
if s.outputTokens > 0 {
|
||||
if usage, ok := obj["usage"].(map[string]any); ok {
|
||||
usage["output_tokens"] = s.outputTokens
|
||||
if input, ok := usage["input_tokens"].(int); ok {
|
||||
usage["total_tokens"] = input + s.outputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(obj)
|
||||
}
|
||||
@@ -172,6 +181,9 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.OutputTokens > 0 {
|
||||
s.outputTokens = parsed.OutputTokens
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
|
||||
@@ -10,8 +10,9 @@ import (
|
||||
// CollectResult holds the aggregated text and thinking content from a
|
||||
// DeepSeek SSE stream, consumed to completion (non-streaming use case).
|
||||
type CollectResult struct {
|
||||
Text string
|
||||
Thinking string
|
||||
Text string
|
||||
Thinking string
|
||||
OutputTokens int
|
||||
}
|
||||
|
||||
// CollectStream fully consumes a DeepSeek SSE response and separates
|
||||
@@ -26,6 +27,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
}
|
||||
text := strings.Builder{}
|
||||
thinking := strings.Builder{}
|
||||
outputTokens := 0
|
||||
currentType := "text"
|
||||
if thinkingEnabled {
|
||||
currentType = "thinking"
|
||||
@@ -37,8 +39,14 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
return true
|
||||
}
|
||||
if result.Stop {
|
||||
if result.OutputTokens > 0 {
|
||||
outputTokens = result.OutputTokens
|
||||
}
|
||||
return false
|
||||
}
|
||||
if result.OutputTokens > 0 {
|
||||
outputTokens = result.OutputTokens
|
||||
}
|
||||
for _, p := range result.Parts {
|
||||
if p.Type == "thinking" {
|
||||
thinking.WriteString(p.Text)
|
||||
@@ -48,5 +56,5 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
}
|
||||
return true
|
||||
})
|
||||
return CollectResult{Text: text.String(), Thinking: thinking.String()}
|
||||
return CollectResult{Text: text.String(), Thinking: thinking.String(), OutputTokens: outputTokens}
|
||||
}
|
||||
|
||||
@@ -138,3 +138,15 @@ func TestCollectStreamStatusFinished(t *testing.T) {
|
||||
t.Fatalf("expected 'Hello', got %q", result.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamStopsOnContentFilterStatus(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"safe\"}\n" +
|
||||
"data: {\"p\":\"response/status\",\"v\":\"CONTENT_FILTER\"}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"blocked\"}\n",
|
||||
)
|
||||
result := CollectStream(resp, false, false)
|
||||
if result.Text != "safe" {
|
||||
t.Fatalf("expected stream to stop before blocked tail, got %q", result.Text)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ type LineResult struct {
|
||||
ErrorMessage string
|
||||
Parts []ContentPart
|
||||
NextType string
|
||||
OutputTokens int
|
||||
}
|
||||
|
||||
// ParseDeepSeekContentLine centralizes one-line DeepSeek SSE parsing for both
|
||||
@@ -39,6 +40,16 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
|
||||
NextType: currentType,
|
||||
}
|
||||
}
|
||||
if hasContentFilterStatus(chunk) {
|
||||
return LineResult{
|
||||
Parsed: true,
|
||||
Stop: true,
|
||||
ContentFilter: true,
|
||||
ErrorMessage: "content filtered by upstream",
|
||||
NextType: currentType,
|
||||
OutputTokens: extractAccumulatedTokenUsage(chunk),
|
||||
}
|
||||
}
|
||||
parts, finished, nextType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
||||
parts = filterLeakedContentFilterParts(parts)
|
||||
return LineResult{
|
||||
@@ -46,5 +57,6 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
|
||||
Stop: finished,
|
||||
Parts: parts,
|
||||
NextType: nextType,
|
||||
OutputTokens: extractAccumulatedTokenUsage(chunk),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +26,20 @@ func TestParseDeepSeekContentLineContentFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineContentFilterStatus(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/status","v":"CONTENT_FILTER"}`), false, "text")
|
||||
if !res.Parsed || !res.Stop || !res.ContentFilter {
|
||||
t.Fatalf("expected status-based content-filter stop result: %#v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineCapturesAccumulatedTokenUsage(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response","o":"BATCH","v":[{"p":"accumulated_token_usage","v":1383},{"p":"quasi_status","v":"FINISHED"}]}`), false, "text")
|
||||
if res.OutputTokens != 1383 {
|
||||
t.Fatalf("expected output token usage 1383, got %d", res.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineContent(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"hi"}`), false, "text")
|
||||
if !res.Parsed || res.Stop {
|
||||
|
||||
@@ -3,6 +3,7 @@ package sse
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/deepseek"
|
||||
@@ -287,3 +288,86 @@ func extractContentRecursive(items []any, defaultType string) ([]ContentPart, bo
|
||||
func IsCitation(text string) bool {
|
||||
return bytes.HasPrefix([]byte(strings.TrimSpace(text)), []byte("[citation:"))
|
||||
}
|
||||
|
||||
func hasContentFilterStatus(chunk map[string]any) bool {
|
||||
return hasContentFilterValue(chunk)
|
||||
}
|
||||
|
||||
func hasContentFilterValue(v any) bool {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return strings.EqualFold(strings.TrimSpace(x), "content_filter")
|
||||
case []any:
|
||||
for _, item := range x {
|
||||
if hasContentFilterValue(item) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case map[string]any:
|
||||
if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "status") {
|
||||
if s, _ := x["v"].(string); strings.EqualFold(strings.TrimSpace(s), "content_filter") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, vv := range x {
|
||||
if hasContentFilterValue(vv) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractAccumulatedTokenUsage(chunk map[string]any) int {
|
||||
return findAccumulatedTokenUsage(chunk)
|
||||
}
|
||||
|
||||
func findAccumulatedTokenUsage(v any) int {
|
||||
switch x := v.(type) {
|
||||
case map[string]any:
|
||||
if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "accumulated_token_usage") {
|
||||
if n, ok := toInt(x["v"]); ok && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
if n, ok := toInt(x["accumulated_token_usage"]); ok && n > 0 {
|
||||
return n
|
||||
}
|
||||
for _, vv := range x {
|
||||
if n := findAccumulatedTokenUsage(vv); n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for _, item := range x {
|
||||
if n := findAccumulatedTokenUsage(item); n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func toInt(v any) (int, bool) {
|
||||
switch x := v.(type) {
|
||||
case int:
|
||||
return x, true
|
||||
case int32:
|
||||
return int(x), true
|
||||
case int64:
|
||||
return int(x), true
|
||||
case float64:
|
||||
if math.IsNaN(x) || math.IsInf(x, 0) {
|
||||
return 0, false
|
||||
}
|
||||
return int(x), true
|
||||
case json.Number:
|
||||
i, err := x.Int64()
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int(i), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +62,18 @@ func (w *OpenAIStreamTranslatorWriter) Write(p []byte) (int, error) {
|
||||
if len(trimmed) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix(trimmed, []byte(":")) {
|
||||
if _, err := w.dst.Write(trimmed); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
if _, err := w.dst.Write([]byte("\n\n")); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
if f, ok := w.dst.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -42,3 +42,16 @@ func TestOpenAIStreamTranslatorWriterGemini(t *testing.T) {
|
||||
t.Fatalf("expected gemini stream payload, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamTranslatorWriterPreservesKeepAliveComment(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
w := NewOpenAIStreamTranslatorWriter(rec, sdktranslator.FormatGemini, "gemini-2.5-pro", []byte(`{}`), []byte(`{}`))
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
_, _ = w.Write([]byte(": keep-alive\n\n"))
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, ": keep-alive\n\n") {
|
||||
t.Fatalf("expected keep-alive comment passthrough, got %q", body)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user