refactor: centralize assistant turn semantics and stream accumulation into new assistantturn and completionruntime packages
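The hunks below consume the new assistantturn package without showing its source. For orientation only, the fields read by the Gemini response builders (turn.Model, turn.Text, turn.Thinking, turn.ToolCalls, turn.Usage) suggest a shape roughly like the sketch here; the exact definitions are assumptions inferred from usage in this diff, not the package's actual code.

// Hypothetical sketch of the assistantturn data types, inferred from the
// call sites in this commit. Field names and types are assumptions.
package assistantturn

// Usage mirrors the token counts read by the Gemini usageMetadata mapping.
type Usage struct {
	InputTokens  int
	OutputTokens int
	TotalTokens  int
}

// ToolCall matches the fields used to build Gemini functionCall parts.
type ToolCall struct {
	Name  string
	Input map[string]any
}

// Turn is the normalized assistant turn consumed by the response builders.
type Turn struct {
	Model     string
	Text      string
	Thinking  string
	ToolCalls []ToolCall
	Usage     Usage
}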
@@ -33,6 +33,9 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin
	toolsRaw := convertGeminiTools(req["tools"])
	finalPrompt, toolNames := promptcompat.BuildOpenAIPromptForAdapter(messagesRaw, toolsRaw, "", thinkingEnabled)
	if len(toolNames) == 0 && len(toolsRaw) > 0 {
		toolNames = []string{"__any_tool__"}
	}
	passThrough := collectGeminiPassThrough(req)

	return promptcompat.StandardRequest{

@@ -42,6 +45,7 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin
		ResponseModel: requestedModel,
		Messages: messagesRaw,
		PromptTokenText: finalPrompt,
		ToolsRaw: toolsRaw,
		FinalPrompt: finalPrompt,
		ToolNames: toolNames,
		Stream: stream,

@@ -11,7 +11,11 @@ import (

	"github.com/go-chi/chi/v5"

	"ds2api/internal/assistantturn"
	"ds2api/internal/auth"
	"ds2api/internal/completionruntime"
	"ds2api/internal/httpapi/requestbody"
	"ds2api/internal/promptcompat"
	"ds2api/internal/sse"
	"ds2api/internal/toolcall"
	"ds2api/internal/translatorcliproxy"

@@ -21,14 +25,80 @@ import (
)

func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
	if h.OpenAI == nil {
		writeGeminiError(w, http.StatusInternalServerError, "OpenAI proxy backend unavailable.")
	if isGeminiVercelProxyRequest(r) && h.proxyViaOpenAI(w, r, stream) {
		return
	}
	if h.proxyViaOpenAI(w, r, stream) {
	if h.Auth == nil || h.DS == nil {
		if h.OpenAI != nil && h.proxyViaOpenAI(w, r, stream) {
			return
		}
		writeGeminiError(w, http.StatusInternalServerError, "Gemini runtime backend unavailable.")
		return
	}
	writeGeminiError(w, http.StatusBadGateway, "Failed to proxy Gemini request.")
	if h.handleGeminiDirect(w, r, stream) {
		return
	}
	writeGeminiError(w, http.StatusBadGateway, "Failed to handle Gemini request.")
}

func isGeminiVercelProxyRequest(r *http.Request) bool {
	if r == nil || r.URL == nil {
		return false
	}
	return strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1" ||
		strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"
}

func (h *Handler) handleGeminiDirect(w http.ResponseWriter, r *http.Request, stream bool) bool {
	raw, err := io.ReadAll(r.Body)
	if err != nil {
		if errors.Is(err, requestbody.ErrInvalidUTF8Body) {
			writeGeminiError(w, http.StatusBadRequest, "invalid json")
		} else {
			writeGeminiError(w, http.StatusBadRequest, "invalid body")
		}
		return true
	}
	routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
	var req map[string]any
	if err := json.Unmarshal(raw, &req); err != nil {
		writeGeminiError(w, http.StatusBadRequest, "invalid json")
		return true
	}
	stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
	if err != nil {
		writeGeminiError(w, http.StatusBadRequest, err.Error())
		return true
	}
	a, err := h.Auth.Determine(r)
	if err != nil {
		writeGeminiError(w, http.StatusUnauthorized, err.Error())
		return true
	}
	defer h.Auth.Release(a)
	if stream {
		h.handleGeminiDirectStream(w, r, a, stdReq)
		return true
	}
	result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{
		StripReferenceMarkers: h.compatStripReferenceMarkers(),
		RetryEnabled: true,
	})
	if outErr != nil {
		writeGeminiError(w, outErr.Status, outErr.Message)
		return true
	}
	writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponseFromTurn(result.Turn))
	return true
}

func (h *Handler) handleGeminiDirectStream(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, stdReq promptcompat.StandardRequest) {
	start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
	if outErr != nil {
		writeGeminiError(w, outErr.Status, outErr.Message)
		return
	}
	h.handleStreamGenerateContent(w, r, start.Response, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw)
}

func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream bool) bool {

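The non-stream and stream paths above now run through completionruntime instead of assembling turns inline. As a rough orientation, the call sites imply a surface along the following lines; the type names, fields, and signatures here are inferred from usage in this diff and are assumptions, not the package's actual definitions.

// Hypothetical sketch of the completionruntime surface implied by the call
// sites above. The handlers appear to call (signatures are assumptions):
//
//	ExecuteNonStreamWithRetry(ctx, ds, auth, stdReq, Options) (Result, *OutError)
//	StartCompletion(ctx, ds, auth, stdReq, Options) (StartResult, *OutError)
package completionruntime

import (
	"net/http"

	"ds2api/internal/assistantturn"
)

// Options carries the per-request toggles passed in by the handlers.
type Options struct {
	StripReferenceMarkers bool
	RetryEnabled          bool
}

// OutError pairs an HTTP status with a message; the handlers forward it
// verbatim via writeGeminiError(w, outErr.Status, outErr.Message).
type OutError struct {
	Status  int
	Message string
}

// Result is what the non-stream call returns on success; the handler maps
// Result.Turn into a Gemini generateContent response.
type Result struct {
	Turn assistantturn.Turn
}

// StartResult exposes the upstream response that the stream runtime consumes.
type StartResult struct {
	Response *http.Response
}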
@@ -250,6 +320,48 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final
	}
}

func buildGeminiGenerateContentResponseFromTurn(turn assistantturn.Turn) map[string]any {
	parts := buildGeminiPartsFromTurn(turn)
	return map[string]any{
		"candidates": []map[string]any{
			{
				"index": 0,
				"content": map[string]any{
					"role": "model",
					"parts": parts,
				},
				"finishReason": "STOP",
			},
		},
		"modelVersion": turn.Model,
		"usageMetadata": map[string]any{
			"promptTokenCount": turn.Usage.InputTokens,
			"candidatesTokenCount": turn.Usage.OutputTokens,
			"totalTokenCount": turn.Usage.TotalTokens,
		},
	}
}

func buildGeminiPartsFromTurn(turn assistantturn.Turn) []map[string]any {
	if len(turn.ToolCalls) > 0 {
		parts := make([]map[string]any, 0, len(turn.ToolCalls))
		for _, tc := range turn.ToolCalls {
			parts = append(parts, map[string]any{
				"functionCall": map[string]any{
					"name": tc.Name,
					"args": tc.Input,
				},
			})
		}
		return parts
	}
	text := turn.Text
	if text == "" {
		text = turn.Thinking
	}
	return []map[string]any{{"text": text}}
}

//nolint:unused // retained for native Gemini non-stream handling path.
func buildGeminiUsage(model, finalPrompt, finalThinking, finalText string) map[string]any {
	promptTokens := util.CountPromptTokens(finalPrompt, model)

@@ -7,13 +7,14 @@ import (
	"strings"
	"time"

	"ds2api/internal/assistantturn"
	dsprotocol "ds2api/internal/deepseek/protocol"
	"ds2api/internal/sse"
	streamengine "ds2api/internal/stream"
)

//nolint:unused // retained for native Gemini stream handling path.
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) {
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)

@@ -28,7 +29,7 @@ func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Req

	rc := http.NewResponseController(w)
	_, canFlush := w.(http.Flusher)
	runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames)
	runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw)

	initialType := "text"
	if thinkingEnabled {

@@ -64,9 +65,11 @@ type geminiStreamRuntime struct
	bufferContent bool
	stripReferenceMarkers bool
	toolNames []string
	toolsRaw any

	thinking strings.Builder
	text strings.Builder
	accumulator *assistantturn.Accumulator
	contentFilter bool
	responseMessageID int
}

//nolint:unused // retained for native Gemini stream handling path.

@@ -80,6 +83,7 @@ func newGeminiStreamRuntime(
	searchEnabled bool,
	stripReferenceMarkers bool,
	toolNames []string,
	toolsRaw any,
) *geminiStreamRuntime {
	return &geminiStreamRuntime{
		w: w,

@@ -92,6 +96,12 @@ func newGeminiStreamRuntime(
		bufferContent: len(toolNames) > 0,
		stripReferenceMarkers: stripReferenceMarkers,
		toolNames: toolNames,
		toolsRaw: toolsRaw,
		accumulator: assistantturn.NewAccumulator(assistantturn.AccumulatorOptions{
			ThinkingEnabled: thinkingEnabled,
			SearchEnabled: searchEnabled,
			StripReferenceMarkers: stripReferenceMarkers,
		}),
	}
}

@@ -111,32 +121,24 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
	if !parsed.Parsed {
		return streamengine.ParsedDecision{}
	}
	if parsed.ResponseMessageID > 0 {
		s.responseMessageID = parsed.ResponseMessageID
	}
	if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
		if parsed.ContentFilter {
			s.contentFilter = true
		}
		return streamengine.ParsedDecision{Stop: true}
	}

	contentSeen := false
	for _, p := range parsed.Parts {
		cleanedText := cleanVisibleOutput(p.Text, s.stripReferenceMarkers)
		if cleanedText == "" {
			continue
		}
		if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(cleanedText) {
			continue
		}
		contentSeen = true
	accumulated := s.accumulator.Apply(parsed)
	for _, p := range accumulated.Parts {
		if p.Type == "thinking" {
			if s.thinkingEnabled {
				if cleanedText != "" {
					s.thinking.WriteString(cleanedText)
				}
			}
			continue
		}
		if cleanedText == "" {
		if p.RawText == "" || p.CitationOnly || p.VisibleText == "" {
			continue
		}
		s.text.WriteString(cleanedText)
		if s.bufferContent {
			continue
		}

@@ -146,23 +148,38 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
				"index": 0,
				"content": map[string]any{
					"role": "model",
					"parts": []map[string]any{{"text": cleanedText}},
					"parts": []map[string]any{{"text": p.VisibleText}},
				},
			},
		},
		"modelVersion": s.model,
		})
	}
	return streamengine.ParsedDecision{ContentSeen: contentSeen}
	return streamengine.ParsedDecision{ContentSeen: accumulated.ContentSeen}
}

//nolint:unused // retained for native Gemini stream handling path.
func (s *geminiStreamRuntime) finalize() {
	finalThinking := s.thinking.String()
	finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
	rawText, text, rawThinking, thinking, detectionThinking := s.accumulator.Snapshot()
	turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
		RawText: rawText,
		VisibleText: text,
		RawThinking: rawThinking,
		VisibleThinking: thinking,
		DetectionThinking: detectionThinking,
		ContentFilter: s.contentFilter,
		ResponseMessageID: s.responseMessageID,
	}, assistantturn.BuildOptions{
		Model: s.model,
		Prompt: s.finalPrompt,
		SearchEnabled: s.searchEnabled,
		StripReferenceMarkers: s.stripReferenceMarkers,
		ToolNames: s.toolNames,
		ToolsRaw: s.toolsRaw,
	})

	if s.bufferContent {
		parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
		parts := buildGeminiPartsFromTurn(turn)
		s.sendChunk(map[string]any{
			"candidates": []map[string]any{
				{

@@ -190,7 +207,11 @@ func (s *geminiStreamRuntime) finalize()
				"finishReason": "STOP",
			},
		},
		"modelVersion": s.model,
		"usageMetadata": buildGeminiUsage(s.model, s.finalPrompt, finalThinking, finalText),
		"modelVersion": s.model,
		"usageMetadata": map[string]any{
			"promptTokenCount": turn.Usage.InputTokens,
			"candidatesTokenCount": turn.Usage.OutputTokens,
			"totalTokenCount": turn.Usage.TotalTokens,
		},
	})
}
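The stream runtime above now delegates chunk accumulation to assistantturn.Accumulator and derives the final turn from a snapshot rather than from its own string builders. Below is a condensed sketch of that lifecycle, assembled only from the call sites in this diff; the helper name accumulateTurn, the package name, and the option values are hypothetical.

// Hypothetical outline of the accumulation lifecycle inferred from the stream
// runtime above; option and snapshot field names are taken from the call
// sites in this diff, but this is a sketch, not the package's real code.
package example

import (
	"ds2api/internal/assistantturn"
	"ds2api/internal/sse"
)

func accumulateTurn(lines []sse.LineResult, model, prompt string) assistantturn.Turn {
	acc := assistantturn.NewAccumulator(assistantturn.AccumulatorOptions{
		ThinkingEnabled:       true,
		SearchEnabled:         false,
		StripReferenceMarkers: true,
	})
	for _, parsed := range lines {
		// Apply folds one parsed SSE line into the running turn; the runtime
		// above also inspects the returned parts to emit streaming chunks.
		_ = acc.Apply(parsed)
	}
	rawText, text, rawThinking, thinking, detectionThinking := acc.Snapshot()
	return assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
		RawText:           rawText,
		VisibleText:       text,
		RawThinking:       rawThinking,
		VisibleThinking:   thinking,
		DetectionThinking: detectionThinking,
	}, assistantturn.BuildOptions{
		Model:  model,
		Prompt: prompt,
	})
}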