mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-13 04:38:00 +08:00
refactor: unify empty-output retry logic into shared completionruntime package and normalize protocol adapter boundary.
This commit is contained in:
@@ -137,7 +137,7 @@ func (h *Handler) handleGeminiDirectStream(w http.ResponseWriter, r *http.Reques
|
||||
return
|
||||
}
|
||||
streamReq := start.Request
|
||||
h.handleStreamGenerateContent(w, r, start.Response, streamReq.ResponseModel, streamReq.PromptTokenText, streamReq.Thinking, streamReq.Search, streamReq.ToolNames, streamReq.ToolsRaw, historySession)
|
||||
h.handleStreamGenerateContentWithRetry(w, r, a, start.Response, start.Payload, start.Pow, streamReq.ResponseModel, streamReq.PromptTokenText, streamReq.Thinking, streamReq.Search, streamReq.ToolNames, streamReq.ToolsRaw, historySession)
|
||||
}
|
||||
|
||||
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream bool) bool {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -8,6 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
"ds2api/internal/assistantturn"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/completionruntime"
|
||||
dsprotocol "ds2api/internal/deepseek/protocol"
|
||||
"ds2api/internal/responsehistory"
|
||||
"ds2api/internal/sse"
|
||||
@@ -54,7 +57,7 @@ func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Req
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnParsed: runtime.onParsed,
|
||||
OnFinalize: func(_ streamengine.StopReason, _ error) {
|
||||
runtime.finalize()
|
||||
runtime.finalize(false)
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -78,9 +81,83 @@ type geminiStreamRuntime struct {
|
||||
accumulator *assistantturn.Accumulator
|
||||
contentFilter bool
|
||||
responseMessageID int
|
||||
finalErrorStatus int
|
||||
finalErrorMessage string
|
||||
finalErrorCode string
|
||||
history *responsehistory.Session
|
||||
}
|
||||
|
||||
func (h *Handler) handleStreamGenerateContentWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *responsehistory.Session) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if historySession != nil {
|
||||
historySession.Error(resp.StatusCode, strings.TrimSpace(string(body)), "error", "", "")
|
||||
}
|
||||
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
rc := http.NewResponseController(w)
|
||||
_, canFlush := w.(http.Flusher)
|
||||
runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, stripReferenceMarkersEnabled(), toolNames, toolsRaw, historySession)
|
||||
|
||||
completionruntime.ExecuteStreamWithRetry(r.Context(), h.DS, a, resp, payload, pow, completionruntime.StreamRetryOptions{
|
||||
Surface: "gemini.generate_content",
|
||||
Stream: true,
|
||||
RetryEnabled: true,
|
||||
MaxAttempts: 3,
|
||||
UsagePrompt: finalPrompt,
|
||||
}, completionruntime.StreamRetryHooks{
|
||||
ConsumeAttempt: func(currentResp *http.Response, allowDeferEmpty bool) (bool, bool) {
|
||||
return h.consumeGeminiStreamAttempt(r.Context(), currentResp, runtime, thinkingEnabled, allowDeferEmpty)
|
||||
},
|
||||
Finalize: func(_ int) {
|
||||
runtime.finalize(false)
|
||||
},
|
||||
ParentMessageID: func() int {
|
||||
return runtime.responseMessageID
|
||||
},
|
||||
OnRetryPrompt: func(prompt string) {
|
||||
runtime.finalPrompt = prompt
|
||||
},
|
||||
OnRetryFailure: func(status int, message, _ string) {
|
||||
runtime.sendErrorChunk(status, strings.TrimSpace(message))
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) consumeGeminiStreamAttempt(ctx context.Context, resp *http.Response, runtime *geminiStreamRuntime, thinkingEnabled bool, allowDeferEmpty bool) (bool, bool) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: ctx,
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: time.Duration(dsprotocol.KeepAliveTimeout) * time.Second,
|
||||
IdleTimeout: time.Duration(dsprotocol.StreamIdleTimeout) * time.Second,
|
||||
MaxKeepAliveNoInput: dsprotocol.MaxKeepaliveCount,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnParsed: runtime.onParsed,
|
||||
OnFinalize: func(_ streamengine.StopReason, _ error) {
|
||||
},
|
||||
})
|
||||
terminalWritten := runtime.finalize(allowDeferEmpty)
|
||||
if terminalWritten {
|
||||
return true, false
|
||||
}
|
||||
return false, true
|
||||
}
|
||||
|
||||
//nolint:unused // retained for native Gemini stream handling path.
|
||||
func newGeminiStreamRuntime(
|
||||
w http.ResponseWriter,
|
||||
@@ -127,6 +204,35 @@ func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *geminiStreamRuntime) sendErrorChunk(status int, message string) {
|
||||
msg := strings.TrimSpace(message)
|
||||
if msg == "" {
|
||||
msg = http.StatusText(status)
|
||||
}
|
||||
errorStatus := "INVALID_ARGUMENT"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
errorStatus = "UNAUTHENTICATED"
|
||||
case http.StatusForbidden:
|
||||
errorStatus = "PERMISSION_DENIED"
|
||||
case http.StatusTooManyRequests:
|
||||
errorStatus = "RESOURCE_EXHAUSTED"
|
||||
case http.StatusNotFound:
|
||||
errorStatus = "NOT_FOUND"
|
||||
default:
|
||||
if status >= 500 {
|
||||
errorStatus = "INTERNAL"
|
||||
}
|
||||
}
|
||||
s.sendChunk(map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": status,
|
||||
"message": msg,
|
||||
"status": errorStatus,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:unused // retained for native Gemini stream handling path.
|
||||
func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
@@ -192,7 +298,7 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
||||
}
|
||||
|
||||
//nolint:unused // retained for native Gemini stream handling path.
|
||||
func (s *geminiStreamRuntime) finalize() {
|
||||
func (s *geminiStreamRuntime) finalize(deferEmptyOutput bool) bool {
|
||||
rawText, text, rawThinking, thinking, detectionThinking := s.accumulator.Snapshot()
|
||||
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
|
||||
RawText: rawText,
|
||||
@@ -211,6 +317,19 @@ func (s *geminiStreamRuntime) finalize() {
|
||||
ToolsRaw: s.toolsRaw,
|
||||
})
|
||||
outcome := assistantturn.FinalizeTurn(turn, assistantturn.FinalizeOptions{})
|
||||
if outcome.ShouldFail {
|
||||
if deferEmptyOutput {
|
||||
s.finalErrorStatus = outcome.Error.Status
|
||||
s.finalErrorMessage = outcome.Error.Message
|
||||
s.finalErrorCode = outcome.Error.Code
|
||||
return false
|
||||
}
|
||||
if s.history != nil {
|
||||
s.history.Error(outcome.Error.Status, outcome.Error.Message, outcome.Error.Code, responsehistory.ThinkingForArchive(turn.RawThinking, turn.DetectionThinking, turn.Thinking), responsehistory.TextForArchive(turn.RawText, turn.Text))
|
||||
}
|
||||
s.sendErrorChunk(outcome.Error.Status, outcome.Error.Message)
|
||||
return true
|
||||
}
|
||||
if s.history != nil {
|
||||
s.history.Success(
|
||||
http.StatusOK,
|
||||
@@ -257,4 +376,5 @@ func (s *geminiStreamRuntime) finalize() {
|
||||
"totalTokenCount": outcome.Usage.TotalTokens,
|
||||
},
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user