refactor: unify empty-output retry logic into shared completionruntime package and normalize protocol adapter boundary.

This commit is contained in:
CJACK
2026-05-10 00:10:53 +08:00
parent 067cf465bb
commit 7c66742a19
32 changed files with 930 additions and 371 deletions

View File

@@ -1,6 +1,7 @@
package gemini
import (
"context"
"encoding/json"
"io"
"net/http"
@@ -8,6 +9,8 @@ import (
"time"
"ds2api/internal/assistantturn"
"ds2api/internal/auth"
"ds2api/internal/completionruntime"
dsprotocol "ds2api/internal/deepseek/protocol"
"ds2api/internal/responsehistory"
"ds2api/internal/sse"
@@ -54,7 +57,7 @@ func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Req
}, streamengine.ConsumeHooks{
OnParsed: runtime.onParsed,
OnFinalize: func(_ streamengine.StopReason, _ error) {
runtime.finalize()
runtime.finalize(false)
},
})
}
@@ -78,9 +81,83 @@ type geminiStreamRuntime struct {
accumulator *assistantturn.Accumulator
contentFilter bool
responseMessageID int
finalErrorStatus int
finalErrorMessage string
finalErrorCode string
history *responsehistory.Session
}
// handleStreamGenerateContentWithRetry streams a Gemini generate-content
// response to the client and hands attempt orchestration to
// completionruntime.ExecuteStreamWithRetry.
//
// When the upstream response is not 200 OK nothing is streamed: the body is
// read, recorded on historySession (if non-nil) and written back through
// writeGeminiError. Otherwise the SSE response headers are set, a
// geminiStreamRuntime is created, and the retry executor is invoked with hooks
// bound to that runtime.
func (h *Handler) handleStreamGenerateContentWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *responsehistory.Session) {
	if resp.StatusCode != http.StatusOK {
		defer func() { _ = resp.Body.Close() }()
		// Best effort: a read error still reports the upstream status with
		// whatever bytes were obtained.
		raw, _ := io.ReadAll(resp.Body)
		upstreamMessage := strings.TrimSpace(string(raw))
		if historySession != nil {
			historySession.Error(resp.StatusCode, upstreamMessage, "error", "", "")
		}
		writeGeminiError(w, resp.StatusCode, upstreamMessage)
		return
	}

	// Server-sent-events response headers, applied in a fixed order.
	sseHeaders := [][2]string{
		{"Content-Type", "text/event-stream"},
		{"Cache-Control", "no-cache, no-transform"},
		{"Connection", "keep-alive"},
		{"X-Accel-Buffering", "no"},
	}
	for _, kv := range sseHeaders {
		w.Header().Set(kv[0], kv[1])
	}

	rc := http.NewResponseController(w)
	_, canFlush := w.(http.Flusher)
	runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, stripReferenceMarkersEnabled(), toolNames, toolsRaw, historySession)

	ctx := r.Context()
	retryOptions := completionruntime.StreamRetryOptions{
		Surface:      "gemini.generate_content",
		Stream:       true,
		RetryEnabled: true,
		MaxAttempts:  3,
		UsagePrompt:  finalPrompt,
	}
	retryHooks := completionruntime.StreamRetryHooks{
		// Each attempt is drained into the same runtime so state carries over.
		ConsumeAttempt: func(currentResp *http.Response, allowDeferEmpty bool) (bool, bool) {
			return h.consumeGeminiStreamAttempt(ctx, currentResp, runtime, thinkingEnabled, allowDeferEmpty)
		},
		// Finalizing with deferral disabled always writes a terminal chunk.
		Finalize: func(_ int) {
			runtime.finalize(false)
		},
		ParentMessageID: func() int {
			return runtime.responseMessageID
		},
		OnRetryPrompt: func(prompt string) {
			runtime.finalPrompt = prompt
		},
		OnRetryFailure: func(status int, message, _ string) {
			runtime.sendErrorChunk(status, strings.TrimSpace(message))
		},
	}
	completionruntime.ExecuteStreamWithRetry(ctx, h.DS, a, resp, payload, pow, retryOptions, retryHooks)
}
// consumeGeminiStreamAttempt drains a single upstream SSE response into
// runtime and then finalizes the turn, passing allowDeferEmpty through to
// runtime.finalize. The response body is closed before returning.
//
// The two results are always complementary: the first is whatever
// runtime.finalize returned, the second is its negation.
// NOTE(review): the second value presumably tells the retry executor that
// another attempt may be made — confirm against completionruntime.StreamRetryHooks.
func (h *Handler) consumeGeminiStreamAttempt(ctx context.Context, resp *http.Response, runtime *geminiStreamRuntime, thinkingEnabled bool, allowDeferEmpty bool) (bool, bool) {
	defer func() { _ = resp.Body.Close() }()

	var firstSegment string
	if thinkingEnabled {
		firstSegment = "thinking"
	} else {
		firstSegment = "text"
	}

	cfg := streamengine.ConsumeConfig{
		Context:             ctx,
		Body:                resp.Body,
		ThinkingEnabled:     thinkingEnabled,
		InitialType:         firstSegment,
		KeepAliveInterval:   time.Duration(dsprotocol.KeepAliveTimeout) * time.Second,
		IdleTimeout:         time.Duration(dsprotocol.StreamIdleTimeout) * time.Second,
		MaxKeepAliveNoInput: dsprotocol.MaxKeepaliveCount,
	}
	hooks := streamengine.ConsumeHooks{
		OnParsed: runtime.onParsed,
		// Deliberately a no-op: finalization happens below so that
		// allowDeferEmpty can be honoured and its result returned.
		OnFinalize: func(streamengine.StopReason, error) {},
	}
	streamengine.ConsumeSSE(cfg, hooks)

	done := runtime.finalize(allowDeferEmpty)
	return done, !done
}
//nolint:unused // retained for native Gemini stream handling path.
func newGeminiStreamRuntime(
w http.ResponseWriter,
@@ -127,6 +204,35 @@ func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
}
}
// sendErrorChunk emits one stream chunk carrying an "error" object with the
// numeric HTTP status ("code"), a human-readable "message" and a symbolic
// "status" name.
//
// The message is whitespace-trimmed; when that leaves it empty the standard
// HTTP status text for the code is used instead. The symbolic name is
// UNAUTHENTICATED (401), PERMISSION_DENIED (403), RESOURCE_EXHAUSTED (429),
// NOT_FOUND (404), INTERNAL for any status >= 500, and INVALID_ARGUMENT for
// everything else.
func (s *geminiStreamRuntime) sendErrorChunk(status int, message string) {
	text := strings.TrimSpace(message)
	if text == "" {
		text = http.StatusText(status)
	}

	var statusName string
	switch {
	case status == http.StatusUnauthorized:
		statusName = "UNAUTHENTICATED"
	case status == http.StatusForbidden:
		statusName = "PERMISSION_DENIED"
	case status == http.StatusTooManyRequests:
		statusName = "RESOURCE_EXHAUSTED"
	case status == http.StatusNotFound:
		statusName = "NOT_FOUND"
	case status >= 500:
		statusName = "INTERNAL"
	default:
		statusName = "INVALID_ARGUMENT"
	}

	errorBody := map[string]any{
		"code":    status,
		"message": text,
		"status":  statusName,
	}
	s.sendChunk(map[string]any{"error": errorBody})
}
//nolint:unused // retained for native Gemini stream handling path.
func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
if !parsed.Parsed {
@@ -192,7 +298,7 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
}
//nolint:unused // retained for native Gemini stream handling path.
func (s *geminiStreamRuntime) finalize() {
func (s *geminiStreamRuntime) finalize(deferEmptyOutput bool) bool {
rawText, text, rawThinking, thinking, detectionThinking := s.accumulator.Snapshot()
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
RawText: rawText,
@@ -211,6 +317,19 @@ func (s *geminiStreamRuntime) finalize() {
ToolsRaw: s.toolsRaw,
})
outcome := assistantturn.FinalizeTurn(turn, assistantturn.FinalizeOptions{})
if outcome.ShouldFail {
if deferEmptyOutput {
s.finalErrorStatus = outcome.Error.Status
s.finalErrorMessage = outcome.Error.Message
s.finalErrorCode = outcome.Error.Code
return false
}
if s.history != nil {
s.history.Error(outcome.Error.Status, outcome.Error.Message, outcome.Error.Code, responsehistory.ThinkingForArchive(turn.RawThinking, turn.DetectionThinking, turn.Thinking), responsehistory.TextForArchive(turn.RawText, turn.Text))
}
s.sendErrorChunk(outcome.Error.Status, outcome.Error.Message)
return true
}
if s.history != nil {
s.history.Success(
http.StatusOK,
@@ -257,4 +376,5 @@ func (s *geminiStreamRuntime) finalize() {
"totalTokenCount": outcome.Usage.TotalTokens,
},
})
return true
}