refactor: centralize assistant turn semantics and stream accumulation into new assistantturn and completionruntime packages

CJACK
2026-05-02 23:28:43 +08:00
parent eccd8c957b
commit dc5bffdf89
24 changed files with 1215 additions and 254 deletions

View File

@@ -33,6 +33,9 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin
     toolsRaw := convertGeminiTools(req["tools"])
     finalPrompt, toolNames := promptcompat.BuildOpenAIPromptForAdapter(messagesRaw, toolsRaw, "", thinkingEnabled)
+    if len(toolNames) == 0 && len(toolsRaw) > 0 {
+        toolNames = []string{"__any_tool__"}
+    }
     passThrough := collectGeminiPassThrough(req)
     return promptcompat.StandardRequest{
@@ -42,6 +45,7 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin
         ResponseModel:   requestedModel,
         Messages:        messagesRaw,
         PromptTokenText: finalPrompt,
+        ToolsRaw:        toolsRaw,
         FinalPrompt:     finalPrompt,
         ToolNames:       toolNames,
         Stream:          stream,
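
Note on the `__any_tool__` sentinel above: downstream, the stream runtime decides whether to buffer output for tool-call detection purely by checking len(toolNames) > 0 (see newGeminiStreamRuntime later in this commit), so a request that declares tools whose names promptcompat cannot extract would otherwise silently lose tool-call handling. A minimal, self-contained sketch of that gate; the toolsDeclared flag is illustrative, not the repo's code:

package main

import "fmt"

func main() {
    toolNames := []string{}              // promptcompat extracted no names...
    toolsDeclared := true                // ...but the request did declare tools
    if len(toolNames) == 0 && toolsDeclared {
        toolNames = []string{"__any_tool__"} // sentinel from the diff above
    }
    bufferContent := len(toolNames) > 0 // mirrors newGeminiStreamRuntime's gate
    fmt.Println(bufferContent)          // true
}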

View File

@@ -11,7 +11,11 @@ import (
"github.com/go-chi/chi/v5"
"ds2api/internal/assistantturn"
"ds2api/internal/auth"
"ds2api/internal/completionruntime"
"ds2api/internal/httpapi/requestbody"
"ds2api/internal/promptcompat"
"ds2api/internal/sse"
"ds2api/internal/toolcall"
"ds2api/internal/translatorcliproxy"
@@ -21,14 +25,80 @@ import (
 )

 func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
-    if h.OpenAI == nil {
-        writeGeminiError(w, http.StatusInternalServerError, "OpenAI proxy backend unavailable.")
+    if isGeminiVercelProxyRequest(r) && h.proxyViaOpenAI(w, r, stream) {
         return
     }
-    if h.proxyViaOpenAI(w, r, stream) {
+    if h.Auth == nil || h.DS == nil {
+        if h.OpenAI != nil && h.proxyViaOpenAI(w, r, stream) {
+            return
+        }
+        writeGeminiError(w, http.StatusInternalServerError, "Gemini runtime backend unavailable.")
+        return
+    }
+    if h.handleGeminiDirect(w, r, stream) {
         return
     }
-    writeGeminiError(w, http.StatusBadGateway, "Failed to proxy Gemini request.")
+    writeGeminiError(w, http.StatusBadGateway, "Failed to handle Gemini request.")
 }
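
The rewritten entry point now degrades in a fixed order: Vercel stream prepare/release handshakes go straight to the OpenAI proxy, the native path requires both the Auth and DS backends, and the proxy remains a fallback when they are missing. A sketch of that ladder reduced to plain booleans; every identifier below is illustrative, not the handler's real API:

package main

import "fmt"

// route condenses handleGenerateContent's fallback order into a pure function.
func route(vercelHandshake, proxyOK, haveAuth, haveDS, haveOpenAI, directOK bool) string {
    if vercelHandshake && proxyOK {
        return "proxy (stream handshake)"
    }
    if !haveAuth || !haveDS {
        if haveOpenAI && proxyOK {
            return "proxy (native backend missing)"
        }
        return "500 Gemini runtime backend unavailable."
    }
    if directOK {
        return "native Gemini path"
    }
    return "502 Failed to handle Gemini request."
}

func main() {
    fmt.Println(route(false, false, true, true, true, true)) // native Gemini path
}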
+
+func isGeminiVercelProxyRequest(r *http.Request) bool {
+    if r == nil || r.URL == nil {
+        return false
+    }
+    return strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1" ||
+        strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"
+}
+
+func (h *Handler) handleGeminiDirect(w http.ResponseWriter, r *http.Request, stream bool) bool {
+    raw, err := io.ReadAll(r.Body)
+    if err != nil {
+        if errors.Is(err, requestbody.ErrInvalidUTF8Body) {
+            writeGeminiError(w, http.StatusBadRequest, "invalid json")
+        } else {
+            writeGeminiError(w, http.StatusBadRequest, "invalid body")
+        }
+        return true
+    }
+    routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
+    var req map[string]any
+    if err := json.Unmarshal(raw, &req); err != nil {
+        writeGeminiError(w, http.StatusBadRequest, "invalid json")
+        return true
+    }
+    stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
+    if err != nil {
+        writeGeminiError(w, http.StatusBadRequest, err.Error())
+        return true
+    }
+    a, err := h.Auth.Determine(r)
+    if err != nil {
+        writeGeminiError(w, http.StatusUnauthorized, err.Error())
+        return true
+    }
+    defer h.Auth.Release(a)
+    if stream {
+        h.handleGeminiDirectStream(w, r, a, stdReq)
+        return true
+    }
+    result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{
+        StripReferenceMarkers: h.compatStripReferenceMarkers(),
+        RetryEnabled:          true,
+    })
+    if outErr != nil {
+        writeGeminiError(w, outErr.Status, outErr.Message)
+        return true
+    }
+    writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponseFromTurn(result.Turn))
+    return true
+}
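
Note that handleGeminiDirect returns true on every path once it has consumed the request body, so the caller's 502 fallback only fires if the function is never entered. The two completionruntime call sites in this file also pin down a rough surface for the new package; the sketch below is inferred from those call sites alone (the ds/a/req parameter types are reduced to any here) and is not the package's verified API:

package completionruntime

import (
    "context"
    "net/http"
)

// Turn stands in for assistantturn.Turn (text, thinking, tool calls, usage).
type Turn struct{}

type Options struct {
    StripReferenceMarkers bool
    RetryEnabled          bool
}

// OutError carries an HTTP status so handlers can map failures straight
// onto Gemini error responses, as both call sites above do.
type OutError struct {
    Status  int
    Message string
}

type Result struct{ Turn Turn }

type Start struct {
    Response *http.Response // upstream SSE response; the caller closes it
}

// Signatures as inferred from the call sites; real parameter types are
// h.DS, *auth.RequestAuth, and promptcompat.StandardRequest.
func ExecuteNonStreamWithRetry(ctx context.Context, ds, a, req any, opts Options) (Result, *OutError) {
    return Result{}, nil
}

func StartCompletion(ctx context.Context, ds, a, req any, opts Options) (Start, *OutError) {
    return Start{}, nil
}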
+
+func (h *Handler) handleGeminiDirectStream(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, stdReq promptcompat.StandardRequest) {
+    start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
+    if outErr != nil {
+        writeGeminiError(w, outErr.Status, outErr.Message)
+        return
+    }
+    h.handleStreamGenerateContent(w, r, start.Response, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw)
+}
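
A runnable check of the handshake detector; the function body is copied from the diff above so the snippet stands alone, and the model path is made up:

package main

import (
    "fmt"
    "net/http"
    "net/http/httptest"
    "strings"
)

// Copied from the diff above for a self-contained example.
func isGeminiVercelProxyRequest(r *http.Request) bool {
    if r == nil || r.URL == nil {
        return false
    }
    return strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1" ||
        strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"
}

func main() {
    // Only the reserved query flags matter, not the path.
    prep := httptest.NewRequest(http.MethodPost,
        "/v1beta/models/some-model:streamGenerateContent?__stream_prepare=1", nil)
    plain := httptest.NewRequest(http.MethodPost,
        "/v1beta/models/some-model:generateContent", nil)
    fmt.Println(isGeminiVercelProxyRequest(prep))  // true
    fmt.Println(isGeminiVercelProxyRequest(plain)) // false
}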

 func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream bool) bool {
@@ -250,6 +320,48 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final
     }
 }
+
+func buildGeminiGenerateContentResponseFromTurn(turn assistantturn.Turn) map[string]any {
+    parts := buildGeminiPartsFromTurn(turn)
+    return map[string]any{
+        "candidates": []map[string]any{
+            {
+                "index": 0,
+                "content": map[string]any{
+                    "role":  "model",
+                    "parts": parts,
+                },
+                "finishReason": "STOP",
+            },
+        },
+        "modelVersion": turn.Model,
+        "usageMetadata": map[string]any{
+            "promptTokenCount":     turn.Usage.InputTokens,
+            "candidatesTokenCount": turn.Usage.OutputTokens,
+            "totalTokenCount":      turn.Usage.TotalTokens,
+        },
+    }
+}
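
The field accesses in this file (turn.Model, turn.Text, turn.Thinking, turn.ToolCalls, turn.Usage.*) imply a minimal shape for assistantturn.Turn. The sketch below is inferred from those accesses only; the real struct almost certainly carries more:

package assistantturn

// Usage matches the three counters read by the Gemini usageMetadata builders.
type Usage struct {
    InputTokens  int
    OutputTokens int
    TotalTokens  int
}

// ToolCall matches the tc.Name / tc.Input accesses in buildGeminiPartsFromTurn.
type ToolCall struct {
    Name  string
    Input any // rendered as Gemini functionCall.args; concrete type unverified
}

type Turn struct {
    Model     string
    Text      string
    Thinking  string
    ToolCalls []ToolCall
    Usage     Usage
}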
+
+func buildGeminiPartsFromTurn(turn assistantturn.Turn) []map[string]any {
+    if len(turn.ToolCalls) > 0 {
+        parts := make([]map[string]any, 0, len(turn.ToolCalls))
+        for _, tc := range turn.ToolCalls {
+            parts = append(parts, map[string]any{
+                "functionCall": map[string]any{
+                    "name": tc.Name,
+                    "args": tc.Input,
+                },
+            })
+        }
+        return parts
+    }
+    text := turn.Text
+    if text == "" {
+        text = turn.Thinking
+    }
+    return []map[string]any{{"text": text}}
+}
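
buildGeminiPartsFromTurn encodes a strict priority: tool calls win, then visible text, and an all-thinking turn falls back to emitting its thinking as text so the candidate is never empty. A distilled, runnable illustration (types reduced to strings; not the repo's code):

package main

import "fmt"

// pickPart mirrors the selection order: tool calls > text > thinking-as-text.
func pickPart(toolCalls []string, text, thinking string) string {
    if len(toolCalls) > 0 {
        return "functionCall:" + toolCalls[0]
    }
    if text == "" {
        text = thinking
    }
    return "text:" + text
}

func main() {
    fmt.Println(pickPart([]string{"lookup"}, "hi", ""))  // functionCall:lookup
    fmt.Println(pickPart(nil, "", "only thinking here")) // text:only thinking here
}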

 //nolint:unused // retained for native Gemini non-stream handling path.
 func buildGeminiUsage(model, finalPrompt, finalThinking, finalText string) map[string]any {
     promptTokens := util.CountPromptTokens(finalPrompt, model)

View File

@@ -7,13 +7,14 @@ import (
"strings"
"time"
"ds2api/internal/assistantturn"
dsprotocol "ds2api/internal/deepseek/protocol"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
)
//nolint:unused // retained for native Gemini stream handling path.
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -28,7 +29,7 @@ func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Req
     rc := http.NewResponseController(w)
     _, canFlush := w.(http.Flusher)
-    runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames)
+    runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw)

     initialType := "text"
     if thinkingEnabled {
@@ -64,9 +65,11 @@ type geminiStreamRuntime struct {
     bufferContent         bool
     stripReferenceMarkers bool
     toolNames             []string
+    toolsRaw              any
     thinking              strings.Builder
     text                  strings.Builder
+    accumulator           *assistantturn.Accumulator
     contentFilter         bool
     responseMessageID     int
 }
 //nolint:unused // retained for native Gemini stream handling path.
@@ -80,6 +83,7 @@ func newGeminiStreamRuntime(
     searchEnabled bool,
     stripReferenceMarkers bool,
     toolNames []string,
+    toolsRaw any,
 ) *geminiStreamRuntime {
     return &geminiStreamRuntime{
         w:                     w,
@@ -92,6 +96,12 @@ func newGeminiStreamRuntime(
         bufferContent:         len(toolNames) > 0,
         stripReferenceMarkers: stripReferenceMarkers,
         toolNames:             toolNames,
+        toolsRaw:              toolsRaw,
+        accumulator: assistantturn.NewAccumulator(assistantturn.AccumulatorOptions{
+            ThinkingEnabled:       thinkingEnabled,
+            SearchEnabled:         searchEnabled,
+            StripReferenceMarkers: stripReferenceMarkers,
+        }),
     }
 }
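
The three accumulator call sites in this file (NewAccumulator here, Apply and Snapshot below) suggest the surface sketched next. Everything beyond the option and field names visible in the diff is assumption, not the package's verified API:

package assistantturn

type AccumulatorOptions struct {
    ThinkingEnabled       bool
    SearchEnabled         bool
    StripReferenceMarkers bool
}

// Part mirrors the fields onParsed reads: the accumulator pre-classifies each
// delta so the stream runtime no longer calls cleanVisibleOutput or
// sse.IsCitation itself.
type Part struct {
    Type         string // "thinking" or "text"
    RawText      string // delta before cleaning
    VisibleText  string // delta after reference-marker stripping
    CitationOnly bool   // search-mode citation chunks, suppressed from output
}

type Applied struct {
    Parts       []Part
    ContentSeen bool
}

// Accumulator keeps the running raw/visible text and thinking buffers.
type Accumulator struct{}

func NewAccumulator(o AccumulatorOptions) *Accumulator { return &Accumulator{} }

// Apply folds one parsed SSE line into the running turn and returns the
// cleaned delta parts for streaming (sse.LineResult in the real code).
func (a *Accumulator) Apply(parsed any) Applied { return Applied{} }

// Snapshot returns (rawText, visibleText, rawThinking, visibleThinking,
// detectionThinking), matching the five-way destructuring in finalize.
func (a *Accumulator) Snapshot() (string, string, string, string, string) {
    return "", "", "", "", ""
}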
@@ -111,32 +121,24 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
     if !parsed.Parsed {
         return streamengine.ParsedDecision{}
     }
     if parsed.ResponseMessageID > 0 {
         s.responseMessageID = parsed.ResponseMessageID
     }
     if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
         if parsed.ContentFilter {
             s.contentFilter = true
         }
         return streamengine.ParsedDecision{Stop: true}
     }
-    contentSeen := false
-    for _, p := range parsed.Parts {
-        cleanedText := cleanVisibleOutput(p.Text, s.stripReferenceMarkers)
-        if cleanedText == "" {
-            continue
-        }
-        if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(cleanedText) {
-            continue
-        }
-        contentSeen = true
+    accumulated := s.accumulator.Apply(parsed)
+    for _, p := range accumulated.Parts {
         if p.Type == "thinking" {
-            if s.thinkingEnabled {
-                if cleanedText != "" {
-                    s.thinking.WriteString(cleanedText)
-                }
-            }
             continue
         }
-        if cleanedText == "" {
+        if p.RawText == "" || p.CitationOnly || p.VisibleText == "" {
             continue
         }
-        s.text.WriteString(cleanedText)
         if s.bufferContent {
             continue
         }
@@ -146,23 +148,38 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
"index": 0,
"content": map[string]any{
"role": "model",
"parts": []map[string]any{{"text": cleanedText}},
"parts": []map[string]any{{"text": p.VisibleText}},
},
},
},
"modelVersion": s.model,
})
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}
return streamengine.ParsedDecision{ContentSeen: accumulated.ContentSeen}
}

 //nolint:unused // retained for native Gemini stream handling path.
 func (s *geminiStreamRuntime) finalize() {
-    finalThinking := s.thinking.String()
-    finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
+    rawText, text, rawThinking, thinking, detectionThinking := s.accumulator.Snapshot()
+    turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
+        RawText:           rawText,
+        VisibleText:       text,
+        RawThinking:       rawThinking,
+        VisibleThinking:   thinking,
+        DetectionThinking: detectionThinking,
+        ContentFilter:     s.contentFilter,
+        ResponseMessageID: s.responseMessageID,
+    }, assistantturn.BuildOptions{
+        Model:                 s.model,
+        Prompt:                s.finalPrompt,
+        SearchEnabled:         s.searchEnabled,
+        StripReferenceMarkers: s.stripReferenceMarkers,
+        ToolNames:             s.toolNames,
+        ToolsRaw:              s.toolsRaw,
+    })
     if s.bufferContent {
-        parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
+        parts := buildGeminiPartsFromTurn(turn)
         s.sendChunk(map[string]any{
             "candidates": []map[string]any{
                 {
@@ -190,7 +207,11 @@ func (s *geminiStreamRuntime) finalize() {
"finishReason": "STOP",
},
},
"modelVersion": s.model,
"usageMetadata": buildGeminiUsage(s.model, s.finalPrompt, finalThinking, finalText),
"modelVersion": s.model,
"usageMetadata": map[string]any{
"promptTokenCount": turn.Usage.InputTokens,
"candidatesTokenCount": turn.Usage.OutputTokens,
"totalTokenCount": turn.Usage.TotalTokens,
},
})
}
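
The finalize rewrite splits responsibilities cleanly: StreamSnapshot carries what the stream actually observed (raw and visible text/thinking, content-filter flag, message ID), while BuildOptions carries request metadata (model, prompt, tool declarations) so the builder can detect tool calls in the buffered text and compute usage, yielding the same assistantturn.Turn the non-stream path gets from completionruntime. A signature sketch inferred from this single call site; every semantic noted in comments is an assumption:

package assistantturn

// Turn stands in for the fuller Turn sketch shown earlier in this commit.
type Turn struct{}

// StreamSnapshot: what the stream observed, field names from the diff.
type StreamSnapshot struct {
    RawText           string
    RawThinking       string
    VisibleText       string
    VisibleThinking   string
    DetectionThinking string // thinking variant used for tool-call detection (assumed)
    ContentFilter     bool
    ResponseMessageID int
}

// BuildOptions: request metadata the builder needs, field names from the diff.
type BuildOptions struct {
    Model                 string
    Prompt                string
    SearchEnabled         bool
    StripReferenceMarkers bool
    ToolNames             []string
    ToolsRaw              any
}

// BuildTurnFromStreamSnapshot presumably parses the buffered text for tool
// calls (hence ToolNames/ToolsRaw) and derives Usage from Prompt plus the
// visible output; only the signature shape is grounded in the call site.
func BuildTurnFromStreamSnapshot(s StreamSnapshot, o BuildOptions) Turn { return Turn{} }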