mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-08 10:25:28 +08:00
refactor: centralize assistant turn semantics and stream accumulation into new assistantturn and completionruntime packages
This commit is contained in:
@@ -4,13 +4,19 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/completionruntime"
|
||||
"ds2api/internal/config"
|
||||
claudefmt "ds2api/internal/format/claude"
|
||||
"ds2api/internal/httpapi/requestbody"
|
||||
"ds2api/internal/promptcompat"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/translatorcliproxy"
|
||||
"ds2api/internal/util"
|
||||
@@ -22,14 +28,90 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
|
||||
r.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
if h.OpenAI == nil {
|
||||
writeClaudeError(w, http.StatusInternalServerError, "OpenAI proxy backend unavailable.")
|
||||
if isClaudeVercelProxyRequest(r) && h.proxyViaOpenAI(w, r, h.Store) {
|
||||
return
|
||||
}
|
||||
if h.proxyViaOpenAI(w, r, h.Store) {
|
||||
if h.Auth == nil || h.DS == nil {
|
||||
if h.OpenAI != nil && h.proxyViaOpenAI(w, r, h.Store) {
|
||||
return
|
||||
}
|
||||
writeClaudeError(w, http.StatusInternalServerError, "Claude runtime backend unavailable.")
|
||||
return
|
||||
}
|
||||
writeClaudeError(w, http.StatusBadGateway, "Failed to proxy Claude request.")
|
||||
if h.handleClaudeDirect(w, r) {
|
||||
return
|
||||
}
|
||||
writeClaudeError(w, http.StatusBadGateway, "Failed to handle Claude request.")
|
||||
}
|
||||
|
||||
func isClaudeVercelProxyRequest(r *http.Request) bool {
|
||||
if r == nil || r.URL == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1" ||
|
||||
strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeDirect(w http.ResponseWriter, r *http.Request) bool {
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
if errors.Is(err, requestbody.ErrInvalidUTF8Body) {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
} else {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid body")
|
||||
}
|
||||
return true
|
||||
}
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(raw, &req); err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
return true
|
||||
}
|
||||
exposeThinking := false
|
||||
if enabled, ok := util.ResolveThinkingOverride(req); ok && enabled {
|
||||
exposeThinking = true
|
||||
} else if _, ok := util.ResolveThinkingOverride(req); !ok && !util.ToBool(req["stream"]) {
|
||||
req["thinking"] = map[string]any{"type": "enabled"}
|
||||
}
|
||||
norm, err := normalizeClaudeRequest(h.Store, req)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, err.Error())
|
||||
return true
|
||||
}
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusUnauthorized, err.Error())
|
||||
return true
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
if norm.Standard.Stream {
|
||||
h.handleClaudeDirectStream(w, r, a, norm.Standard)
|
||||
return true
|
||||
}
|
||||
result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, norm.Standard, completionruntime.Options{
|
||||
StripReferenceMarkers: h.compatStripReferenceMarkers(),
|
||||
RetryEnabled: true,
|
||||
})
|
||||
if outErr != nil {
|
||||
writeClaudeError(w, outErr.Status, outErr.Message)
|
||||
return true
|
||||
}
|
||||
writeJSON(w, http.StatusOK, claudefmt.BuildMessageResponseFromTurn(
|
||||
fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
norm.Standard.ResponseModel,
|
||||
result.Turn,
|
||||
exposeThinking,
|
||||
))
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeDirectStream(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, stdReq promptcompat.StandardRequest) {
|
||||
start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
|
||||
if outErr != nil {
|
||||
writeClaudeError(w, outErr.Status, outErr.Message)
|
||||
return
|
||||
}
|
||||
h.handleClaudeStreamRealtime(w, r, start.Response, stdReq.ResponseModel, stdReq.Messages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw)
|
||||
}
|
||||
|
||||
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store ConfigReader) bool {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"ds2api/internal/assistantturn"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/toolcall"
|
||||
"ds2api/internal/toolstream"
|
||||
@@ -9,7 +10,6 @@ import (
|
||||
"time"
|
||||
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (s *claudeStreamRuntime) closeThinkingBlock() {
|
||||
@@ -115,18 +115,28 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
|
||||
s.closeTextBlock()
|
||||
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
|
||||
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
|
||||
RawText: s.rawText.String(),
|
||||
VisibleText: s.text.String(),
|
||||
RawThinking: s.rawThinking.String(),
|
||||
VisibleThinking: s.thinking.String(),
|
||||
DetectionThinking: s.toolDetectionThinking.String(),
|
||||
AlreadyEmittedCalls: s.toolCallsDetected,
|
||||
AlreadyEmittedToolRaw: s.toolCallsDetected,
|
||||
}, assistantturn.BuildOptions{
|
||||
Model: s.model,
|
||||
Prompt: s.promptTokenText,
|
||||
SearchEnabled: s.searchEnabled,
|
||||
StripReferenceMarkers: s.stripReferenceMarkers,
|
||||
ToolNames: s.toolNames,
|
||||
ToolsRaw: s.toolsRaw,
|
||||
})
|
||||
finalText := turn.Text
|
||||
|
||||
if s.bufferToolContent && !s.toolCallsDetected {
|
||||
detected := toolcall.ParseStandaloneToolCallsDetailed(s.rawText.String(), s.toolNames)
|
||||
if len(detected.Calls) == 0 {
|
||||
detected = toolcall.ParseStandaloneToolCallsDetailed(s.rawThinking.String(), s.toolNames)
|
||||
}
|
||||
if len(detected.Calls) > 0 {
|
||||
normalized := toolcall.NormalizeParsedToolCallsForSchemas(detected.Calls, s.toolsRaw)
|
||||
if len(turn.ToolCalls) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for _, tc := range normalized {
|
||||
for _, tc := range turn.ToolCalls {
|
||||
idx := s.nextBlockIndex
|
||||
s.nextBlockIndex++
|
||||
s.sendToolUseBlock(idx, tc)
|
||||
@@ -161,7 +171,6 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
|
||||
outputTokens := util.CountOutputTokens(finalThinking, s.model) + util.CountOutputTokens(finalText, s.model)
|
||||
s.send("message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
@@ -169,7 +178,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"output_tokens": outputTokens,
|
||||
"output_tokens": turn.Usage.OutputTokens,
|
||||
},
|
||||
})
|
||||
s.send("message_stop", map[string]any{"type": "message_stop"})
|
||||
|
||||
@@ -33,6 +33,9 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin
|
||||
|
||||
toolsRaw := convertGeminiTools(req["tools"])
|
||||
finalPrompt, toolNames := promptcompat.BuildOpenAIPromptForAdapter(messagesRaw, toolsRaw, "", thinkingEnabled)
|
||||
if len(toolNames) == 0 && len(toolsRaw) > 0 {
|
||||
toolNames = []string{"__any_tool__"}
|
||||
}
|
||||
passThrough := collectGeminiPassThrough(req)
|
||||
|
||||
return promptcompat.StandardRequest{
|
||||
@@ -42,6 +45,7 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin
|
||||
ResponseModel: requestedModel,
|
||||
Messages: messagesRaw,
|
||||
PromptTokenText: finalPrompt,
|
||||
ToolsRaw: toolsRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
Stream: stream,
|
||||
|
||||
@@ -11,7 +11,11 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/assistantturn"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/completionruntime"
|
||||
"ds2api/internal/httpapi/requestbody"
|
||||
"ds2api/internal/promptcompat"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/toolcall"
|
||||
"ds2api/internal/translatorcliproxy"
|
||||
@@ -21,14 +25,80 @@ import (
|
||||
)
|
||||
|
||||
func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
|
||||
if h.OpenAI == nil {
|
||||
writeGeminiError(w, http.StatusInternalServerError, "OpenAI proxy backend unavailable.")
|
||||
if isGeminiVercelProxyRequest(r) && h.proxyViaOpenAI(w, r, stream) {
|
||||
return
|
||||
}
|
||||
if h.proxyViaOpenAI(w, r, stream) {
|
||||
if h.Auth == nil || h.DS == nil {
|
||||
if h.OpenAI != nil && h.proxyViaOpenAI(w, r, stream) {
|
||||
return
|
||||
}
|
||||
writeGeminiError(w, http.StatusInternalServerError, "Gemini runtime backend unavailable.")
|
||||
return
|
||||
}
|
||||
writeGeminiError(w, http.StatusBadGateway, "Failed to proxy Gemini request.")
|
||||
if h.handleGeminiDirect(w, r, stream) {
|
||||
return
|
||||
}
|
||||
writeGeminiError(w, http.StatusBadGateway, "Failed to handle Gemini request.")
|
||||
}
|
||||
|
||||
func isGeminiVercelProxyRequest(r *http.Request) bool {
|
||||
if r == nil || r.URL == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1" ||
|
||||
strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"
|
||||
}
|
||||
|
||||
func (h *Handler) handleGeminiDirect(w http.ResponseWriter, r *http.Request, stream bool) bool {
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
if errors.Is(err, requestbody.ErrInvalidUTF8Body) {
|
||||
writeGeminiError(w, http.StatusBadRequest, "invalid json")
|
||||
} else {
|
||||
writeGeminiError(w, http.StatusBadRequest, "invalid body")
|
||||
}
|
||||
return true
|
||||
}
|
||||
routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(raw, &req); err != nil {
|
||||
writeGeminiError(w, http.StatusBadRequest, "invalid json")
|
||||
return true
|
||||
}
|
||||
stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
|
||||
if err != nil {
|
||||
writeGeminiError(w, http.StatusBadRequest, err.Error())
|
||||
return true
|
||||
}
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
writeGeminiError(w, http.StatusUnauthorized, err.Error())
|
||||
return true
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
if stream {
|
||||
h.handleGeminiDirectStream(w, r, a, stdReq)
|
||||
return true
|
||||
}
|
||||
result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{
|
||||
StripReferenceMarkers: h.compatStripReferenceMarkers(),
|
||||
RetryEnabled: true,
|
||||
})
|
||||
if outErr != nil {
|
||||
writeGeminiError(w, outErr.Status, outErr.Message)
|
||||
return true
|
||||
}
|
||||
writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponseFromTurn(result.Turn))
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *Handler) handleGeminiDirectStream(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, stdReq promptcompat.StandardRequest) {
|
||||
start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
|
||||
if outErr != nil {
|
||||
writeGeminiError(w, outErr.Status, outErr.Message)
|
||||
return
|
||||
}
|
||||
h.handleStreamGenerateContent(w, r, start.Response, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw)
|
||||
}
|
||||
|
||||
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream bool) bool {
|
||||
@@ -250,6 +320,48 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiGenerateContentResponseFromTurn(turn assistantturn.Turn) map[string]any {
|
||||
parts := buildGeminiPartsFromTurn(turn)
|
||||
return map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"content": map[string]any{
|
||||
"role": "model",
|
||||
"parts": parts,
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
},
|
||||
},
|
||||
"modelVersion": turn.Model,
|
||||
"usageMetadata": map[string]any{
|
||||
"promptTokenCount": turn.Usage.InputTokens,
|
||||
"candidatesTokenCount": turn.Usage.OutputTokens,
|
||||
"totalTokenCount": turn.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiPartsFromTurn(turn assistantturn.Turn) []map[string]any {
|
||||
if len(turn.ToolCalls) > 0 {
|
||||
parts := make([]map[string]any, 0, len(turn.ToolCalls))
|
||||
for _, tc := range turn.ToolCalls {
|
||||
parts = append(parts, map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"name": tc.Name,
|
||||
"args": tc.Input,
|
||||
},
|
||||
})
|
||||
}
|
||||
return parts
|
||||
}
|
||||
text := turn.Text
|
||||
if text == "" {
|
||||
text = turn.Thinking
|
||||
}
|
||||
return []map[string]any{{"text": text}}
|
||||
}
|
||||
|
||||
//nolint:unused // retained for native Gemini non-stream handling path.
|
||||
func buildGeminiUsage(model, finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
promptTokens := util.CountPromptTokens(finalPrompt, model)
|
||||
|
||||
@@ -7,13 +7,14 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/assistantturn"
|
||||
dsprotocol "ds2api/internal/deepseek/protocol"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
)
|
||||
|
||||
//nolint:unused // retained for native Gemini stream handling path.
|
||||
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -28,7 +29,7 @@ func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Req
|
||||
|
||||
rc := http.NewResponseController(w)
|
||||
_, canFlush := w.(http.Flusher)
|
||||
runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames)
|
||||
runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw)
|
||||
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
@@ -64,9 +65,11 @@ type geminiStreamRuntime struct {
|
||||
bufferContent bool
|
||||
stripReferenceMarkers bool
|
||||
toolNames []string
|
||||
toolsRaw any
|
||||
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
accumulator *assistantturn.Accumulator
|
||||
contentFilter bool
|
||||
responseMessageID int
|
||||
}
|
||||
|
||||
//nolint:unused // retained for native Gemini stream handling path.
|
||||
@@ -80,6 +83,7 @@ func newGeminiStreamRuntime(
|
||||
searchEnabled bool,
|
||||
stripReferenceMarkers bool,
|
||||
toolNames []string,
|
||||
toolsRaw any,
|
||||
) *geminiStreamRuntime {
|
||||
return &geminiStreamRuntime{
|
||||
w: w,
|
||||
@@ -92,6 +96,12 @@ func newGeminiStreamRuntime(
|
||||
bufferContent: len(toolNames) > 0,
|
||||
stripReferenceMarkers: stripReferenceMarkers,
|
||||
toolNames: toolNames,
|
||||
toolsRaw: toolsRaw,
|
||||
accumulator: assistantturn.NewAccumulator(assistantturn.AccumulatorOptions{
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
SearchEnabled: searchEnabled,
|
||||
StripReferenceMarkers: stripReferenceMarkers,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,32 +121,24 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.ResponseMessageID > 0 {
|
||||
s.responseMessageID = parsed.ResponseMessageID
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||
if parsed.ContentFilter {
|
||||
s.contentFilter = true
|
||||
}
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
|
||||
contentSeen := false
|
||||
for _, p := range parsed.Parts {
|
||||
cleanedText := cleanVisibleOutput(p.Text, s.stripReferenceMarkers)
|
||||
if cleanedText == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(cleanedText) {
|
||||
continue
|
||||
}
|
||||
contentSeen = true
|
||||
accumulated := s.accumulator.Apply(parsed)
|
||||
for _, p := range accumulated.Parts {
|
||||
if p.Type == "thinking" {
|
||||
if s.thinkingEnabled {
|
||||
if cleanedText != "" {
|
||||
s.thinking.WriteString(cleanedText)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if cleanedText == "" {
|
||||
if p.RawText == "" || p.CitationOnly || p.VisibleText == "" {
|
||||
continue
|
||||
}
|
||||
s.text.WriteString(cleanedText)
|
||||
if s.bufferContent {
|
||||
continue
|
||||
}
|
||||
@@ -146,23 +148,38 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
||||
"index": 0,
|
||||
"content": map[string]any{
|
||||
"role": "model",
|
||||
"parts": []map[string]any{{"text": cleanedText}},
|
||||
"parts": []map[string]any{{"text": p.VisibleText}},
|
||||
},
|
||||
},
|
||||
},
|
||||
"modelVersion": s.model,
|
||||
})
|
||||
}
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
return streamengine.ParsedDecision{ContentSeen: accumulated.ContentSeen}
|
||||
}
|
||||
|
||||
//nolint:unused // retained for native Gemini stream handling path.
|
||||
func (s *geminiStreamRuntime) finalize() {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
|
||||
rawText, text, rawThinking, thinking, detectionThinking := s.accumulator.Snapshot()
|
||||
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
|
||||
RawText: rawText,
|
||||
VisibleText: text,
|
||||
RawThinking: rawThinking,
|
||||
VisibleThinking: thinking,
|
||||
DetectionThinking: detectionThinking,
|
||||
ContentFilter: s.contentFilter,
|
||||
ResponseMessageID: s.responseMessageID,
|
||||
}, assistantturn.BuildOptions{
|
||||
Model: s.model,
|
||||
Prompt: s.finalPrompt,
|
||||
SearchEnabled: s.searchEnabled,
|
||||
StripReferenceMarkers: s.stripReferenceMarkers,
|
||||
ToolNames: s.toolNames,
|
||||
ToolsRaw: s.toolsRaw,
|
||||
})
|
||||
|
||||
if s.bufferContent {
|
||||
parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
|
||||
parts := buildGeminiPartsFromTurn(turn)
|
||||
s.sendChunk(map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
@@ -190,7 +207,11 @@ func (s *geminiStreamRuntime) finalize() {
|
||||
"finishReason": "STOP",
|
||||
},
|
||||
},
|
||||
"modelVersion": s.model,
|
||||
"usageMetadata": buildGeminiUsage(s.model, s.finalPrompt, finalThinking, finalText),
|
||||
"modelVersion": s.model,
|
||||
"usageMetadata": map[string]any{
|
||||
"promptTokenCount": turn.Usage.InputTokens,
|
||||
"candidatesTokenCount": turn.Usage.OutputTokens,
|
||||
"totalTokenCount": turn.Usage.TotalTokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/assistantturn"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/httpapi/openai/shared"
|
||||
"ds2api/internal/promptcompat"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/toolstream"
|
||||
@@ -24,6 +26,7 @@ type chatStreamRuntime struct {
|
||||
refFileTokens int
|
||||
toolNames []string
|
||||
toolsRaw any
|
||||
toolChoice promptcompat.ToolChoicePolicy
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
@@ -89,6 +92,7 @@ func newChatStreamRuntime(
|
||||
stripReferenceMarkers bool,
|
||||
toolNames []string,
|
||||
toolsRaw any,
|
||||
toolChoice promptcompat.ToolChoicePolicy,
|
||||
bufferToolContent bool,
|
||||
emitEarlyToolDeltas bool,
|
||||
) *chatStreamRuntime {
|
||||
@@ -102,6 +106,7 @@ func newChatStreamRuntime(
|
||||
finalPrompt: finalPrompt,
|
||||
toolNames: toolNames,
|
||||
toolsRaw: toolsRaw,
|
||||
toolChoice: toolChoice,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
stripReferenceMarkers: stripReferenceMarkers,
|
||||
@@ -201,14 +206,33 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
|
||||
s.finalErrorCode = ""
|
||||
finalThinking := s.accumulator.Thinking.String()
|
||||
finalToolDetectionThinking := s.accumulator.ToolDetectionThinking.String()
|
||||
finalText := cleanVisibleOutput(s.accumulator.Text.String(), s.stripReferenceMarkers)
|
||||
s.finalThinking = finalThinking
|
||||
s.finalText = finalText
|
||||
detected := detectAssistantToolCalls(s.accumulator.RawText.String(), finalText, s.accumulator.RawThinking.String(), finalToolDetectionThinking, s.toolNames)
|
||||
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
|
||||
finalText := s.accumulator.Text.String()
|
||||
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
|
||||
RawText: s.accumulator.RawText.String(),
|
||||
VisibleText: finalText,
|
||||
RawThinking: s.accumulator.RawThinking.String(),
|
||||
VisibleThinking: finalThinking,
|
||||
DetectionThinking: finalToolDetectionThinking,
|
||||
ContentFilter: finishReason == "content_filter",
|
||||
ResponseMessageID: s.responseMessageID,
|
||||
AlreadyEmittedCalls: s.toolCallsEmitted,
|
||||
AlreadyEmittedToolRaw: s.toolCallsDoneEmitted,
|
||||
}, assistantturn.BuildOptions{
|
||||
Model: s.model,
|
||||
Prompt: s.finalPrompt,
|
||||
RefFileTokens: s.refFileTokens,
|
||||
SearchEnabled: s.searchEnabled,
|
||||
StripReferenceMarkers: s.stripReferenceMarkers,
|
||||
ToolNames: s.toolNames,
|
||||
ToolsRaw: s.toolsRaw,
|
||||
ToolChoice: s.toolChoice,
|
||||
})
|
||||
s.finalThinking = turn.Thinking
|
||||
s.finalText = turn.Text
|
||||
if len(turn.ToolCalls) > 0 && !s.toolCallsDoneEmitted {
|
||||
finishReason = "tool_calls"
|
||||
s.sendDelta(map[string]any{
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs, s.toolsRaw),
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(turn.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
|
||||
})
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
@@ -237,11 +261,14 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
|
||||
batch.flush()
|
||||
}
|
||||
|
||||
if len(detected.Calls) > 0 || s.toolCallsEmitted {
|
||||
if len(turn.ToolCalls) > 0 || s.toolCallsEmitted {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
if len(detected.Calls) == 0 && !s.toolCallsEmitted && strings.TrimSpace(finalText) == "" {
|
||||
status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", finalText, finalThinking)
|
||||
if len(turn.ToolCalls) == 0 && !s.toolCallsEmitted && strings.TrimSpace(turn.Text) == "" {
|
||||
status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", turn.Text, turn.Thinking)
|
||||
if turn.Error != nil {
|
||||
status, message, code = turn.Error.Status, turn.Error.Message, turn.Error.Code
|
||||
}
|
||||
if deferEmptyOutput {
|
||||
s.finalErrorStatus = status
|
||||
s.finalErrorMessage = message
|
||||
@@ -251,7 +278,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
|
||||
s.sendFailedChunk(status, message, code)
|
||||
return true
|
||||
}
|
||||
usage := openaifmt.BuildChatUsageForModel(s.model, s.finalPrompt, finalThinking, finalText, s.refFileTokens)
|
||||
usage := chatUsageFromTurn(turn)
|
||||
s.finalFinishReason = finishReason
|
||||
s.finalUsage = usage
|
||||
s.sendChunk(openaifmt.BuildChatStreamChunk(
|
||||
@@ -265,6 +292,17 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
|
||||
return true
|
||||
}
|
||||
|
||||
func chatUsageFromTurn(turn assistantturn.Turn) map[string]any {
|
||||
return map[string]any{
|
||||
"prompt_tokens": turn.Usage.InputTokens,
|
||||
"completion_tokens": turn.Usage.OutputTokens,
|
||||
"total_tokens": turn.Usage.TotalTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": turn.Usage.ReasoningTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/promptcompat"
|
||||
)
|
||||
|
||||
func TestChatStreamKeepAliveEmitsEmptyChoiceDataFrame(t *testing.T) {
|
||||
@@ -23,6 +25,7 @@ func TestChatStreamKeepAliveEmitsEmptyChoiceDataFrame(t *testing.T) {
|
||||
true,
|
||||
nil,
|
||||
nil,
|
||||
promptcompat.DefaultToolChoicePolicy(),
|
||||
false,
|
||||
false,
|
||||
)
|
||||
@@ -51,3 +54,34 @@ func TestChatStreamKeepAliveEmitsEmptyChoiceDataFrame(t *testing.T) {
|
||||
t.Fatalf("expected empty choices heartbeat, got %#v", choices)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatStreamFinalizeEnforcesRequiredToolChoice(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
runtime := newChatStreamRuntime(
|
||||
rec,
|
||||
http.NewResponseController(rec),
|
||||
true,
|
||||
"chatcmpl-test",
|
||||
time.Now().Unix(),
|
||||
"deepseek-v4-flash",
|
||||
"prompt",
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
[]string{"Write"},
|
||||
nil,
|
||||
promptcompat.ToolChoicePolicy{Mode: promptcompat.ToolChoiceRequired},
|
||||
true,
|
||||
false,
|
||||
)
|
||||
|
||||
if !runtime.finalize("stop", false) {
|
||||
t.Fatalf("expected terminal error to be written")
|
||||
}
|
||||
if runtime.finalErrorCode != "tool_choice_violation" {
|
||||
t.Fatalf("expected tool_choice_violation, got %q body=%s", runtime.finalErrorCode, rec.Body.String())
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "tool_choice requires") {
|
||||
t.Fatalf("expected tool choice error in stream body, got %s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,10 +7,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/assistantturn"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
dsprotocol "ds2api/internal/deepseek/protocol"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/promptcompat"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
)
|
||||
@@ -26,6 +28,7 @@ type chatNonStreamResult struct {
|
||||
body map[string]any
|
||||
finishReason string
|
||||
responseMessageID int
|
||||
outputError *assistantturn.OutputError
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||
@@ -86,35 +89,40 @@ func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.
|
||||
return chatNonStreamResult{}, false
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
stripReferenceMarkers := h.compatStripReferenceMarkers()
|
||||
finalThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers)
|
||||
finalText := cleanVisibleOutput(result.Text, stripReferenceMarkers)
|
||||
if searchEnabled {
|
||||
finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks)
|
||||
}
|
||||
detected := detectAssistantToolCalls(result.Text, finalText, result.Thinking, result.ToolDetectionThinking, toolNames)
|
||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls, toolsRaw)
|
||||
turn := assistantturn.BuildTurnFromCollected(result, assistantturn.BuildOptions{
|
||||
Model: model,
|
||||
Prompt: usagePrompt,
|
||||
SearchEnabled: searchEnabled,
|
||||
StripReferenceMarkers: h.compatStripReferenceMarkers(),
|
||||
ToolNames: toolNames,
|
||||
ToolsRaw: toolsRaw,
|
||||
})
|
||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, turn.Thinking, turn.Text, turn.ToolCalls, toolsRaw)
|
||||
return chatNonStreamResult{
|
||||
rawThinking: result.Thinking,
|
||||
rawText: result.Text,
|
||||
thinking: finalThinking,
|
||||
thinking: turn.Thinking,
|
||||
toolDetectionThinking: result.ToolDetectionThinking,
|
||||
text: finalText,
|
||||
text: turn.Text,
|
||||
contentFilter: result.ContentFilter,
|
||||
detectedCalls: len(detected.Calls),
|
||||
detectedCalls: len(turn.ToolCalls),
|
||||
body: respBody,
|
||||
finishReason: chatFinishReason(respBody),
|
||||
responseMessageID: result.ResponseMessageID,
|
||||
outputError: turn.Error,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (h *Handler) finishChatNonStreamResult(w http.ResponseWriter, result chatNonStreamResult, attempts int, usagePrompt string, refFileTokens int, historySession *chatHistorySession) {
|
||||
if result.detectedCalls == 0 && shouldWriteUpstreamEmptyOutputError(result.text, result.thinking) {
|
||||
if result.detectedCalls == 0 && strings.TrimSpace(result.text) == "" {
|
||||
status, message, code := upstreamEmptyOutputDetail(result.contentFilter, result.text, result.thinking)
|
||||
if result.outputError != nil {
|
||||
status, message, code = result.outputError.Status, result.outputError.Message, result.outputError.Code
|
||||
}
|
||||
if historySession != nil {
|
||||
historySession.error(status, message, code, result.thinking, result.text)
|
||||
}
|
||||
writeUpstreamEmptyOutputError(w, result.text, result.thinking, result.contentFilter)
|
||||
writeOpenAIErrorWithCode(w, status, message, code)
|
||||
config.Logger.Info("[openai_empty_retry] terminal empty output", "surface", "chat.completions", "stream", false, "retry_attempts", attempts, "success_source", "none", "content_filter", result.contentFilter)
|
||||
return
|
||||
}
|
||||
@@ -147,8 +155,8 @@ func shouldRetryChatNonStream(result chatNonStreamResult, attempts int) bool {
|
||||
strings.TrimSpace(result.thinking) == ""
|
||||
}
|
||||
|
||||
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
|
||||
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, historySession)
|
||||
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, historySession *chatHistorySession) {
|
||||
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, historySession)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -190,7 +198,7 @@ func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
|
||||
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -216,6 +224,7 @@ func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Res
|
||||
streamRuntime := newChatStreamRuntime(
|
||||
w, rc, canFlush, completionID, time.Now().Unix(), model, finalPrompt,
|
||||
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw,
|
||||
toolChoice,
|
||||
len(toolNames) > 0, h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
|
||||
)
|
||||
streamRuntime.refFileTokens = refFileTokens
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/promptcompat"
|
||||
"ds2api/internal/stream"
|
||||
)
|
||||
|
||||
@@ -48,6 +49,7 @@ func TestConsumeChatStreamAttemptMarksContextCancelledState(t *testing.T) {
|
||||
true,
|
||||
nil,
|
||||
nil,
|
||||
promptcompat.DefaultToolChoicePolicy(),
|
||||
false,
|
||||
false,
|
||||
)
|
||||
|
||||
@@ -80,6 +80,10 @@ func writeOpenAIError(w http.ResponseWriter, status int, message string) {
|
||||
shared.WriteOpenAIError(w, status, message)
|
||||
}
|
||||
|
||||
func writeOpenAIErrorWithCode(w http.ResponseWriter, status int, message, code string) {
|
||||
shared.WriteOpenAIErrorWithCode(w, status, message, code)
|
||||
}
|
||||
|
||||
func openAIErrorType(status int) string {
|
||||
return shared.OpenAIErrorType(status)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/completionruntime"
|
||||
"ds2api/internal/config"
|
||||
dsprotocol "ds2api/internal/deepseek/protocol"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
@@ -76,44 +77,40 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
historySession := startChatHistory(h.ChatHistory, r, a, stdReq)
|
||||
|
||||
sessionID, err = h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
if !stdReq.Stream {
|
||||
result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{
|
||||
StripReferenceMarkers: h.compatStripReferenceMarkers(),
|
||||
RetryEnabled: true,
|
||||
})
|
||||
sessionID = result.SessionID
|
||||
if outErr != nil {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.", "error", "", "")
|
||||
historySession.error(outErr.Status, outErr.Message, outErr.Code, result.Turn.Thinking, result.Turn.Text)
|
||||
}
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code)
|
||||
return
|
||||
}
|
||||
respBody := openaifmt.BuildChatCompletionWithToolCalls(result.SessionID, stdReq.ResponseModel, result.Turn.Prompt, result.Turn.Thinking, result.Turn.Text, result.Turn.ToolCalls, stdReq.ToolsRaw)
|
||||
respBody["usage"] = chatUsageFromTurn(result.Turn)
|
||||
finishReason := chatFinishReason(respBody)
|
||||
if historySession != nil {
|
||||
historySession.success(http.StatusOK, result.Turn.Thinking, result.Turn.Text, finishReason, chatUsageFromTurn(result.Turn))
|
||||
}
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
|
||||
start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
|
||||
sessionID = start.SessionID
|
||||
if outErr != nil {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).", "error", "", "")
|
||||
historySession.error(outErr.Status, outErr.Message, outErr.Code, "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusInternalServerError, "Failed to get completion.", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code)
|
||||
return
|
||||
}
|
||||
refFileTokens := stdReq.RefFileTokens
|
||||
if stdReq.Stream {
|
||||
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
|
||||
return
|
||||
}
|
||||
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
|
||||
h.handleStreamWithRetry(w, r, a, start.Response, start.Payload, start.Pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, historySession)
|
||||
}
|
||||
|
||||
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
|
||||
@@ -234,6 +231,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
stripReferenceMarkers,
|
||||
toolNames,
|
||||
toolsRaw,
|
||||
promptcompat.DefaultToolChoicePolicy(),
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -10,129 +9,10 @@ import (
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
dsprotocol "ds2api/internal/deepseek/protocol"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/promptcompat"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/toolcall"
|
||||
)
|
||||
|
||||
type responsesNonStreamResult struct {
|
||||
rawThinking string
|
||||
rawText string
|
||||
thinking string
|
||||
toolDetectionThinking string
|
||||
text string
|
||||
contentFilter bool
|
||||
parsed toolcall.ToolCallParseResult
|
||||
body map[string]any
|
||||
responseMessageID int
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
attempts := 0
|
||||
currentResp := resp
|
||||
usagePrompt := finalPrompt
|
||||
accumulatedThinking := ""
|
||||
accumulatedRawThinking := ""
|
||||
accumulatedToolDetectionThinking := ""
|
||||
for {
|
||||
result, ok := h.collectResponsesNonStreamAttempt(w, currentResp, responseID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
accumulatedThinking += sse.TrimContinuationOverlap(accumulatedThinking, result.thinking)
|
||||
accumulatedRawThinking += sse.TrimContinuationOverlap(accumulatedRawThinking, result.rawThinking)
|
||||
accumulatedToolDetectionThinking += sse.TrimContinuationOverlap(accumulatedToolDetectionThinking, result.toolDetectionThinking)
|
||||
result.thinking = accumulatedThinking
|
||||
result.rawThinking = accumulatedRawThinking
|
||||
result.toolDetectionThinking = accumulatedToolDetectionThinking
|
||||
result.parsed = detectAssistantToolCalls(result.rawText, result.text, result.rawThinking, result.toolDetectionThinking, toolNames)
|
||||
result.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, result.thinking, result.text, result.parsed.Calls, toolsRaw)
|
||||
if refFileTokens > 0 {
|
||||
addRefFileTokensToUsage(result.body, refFileTokens)
|
||||
}
|
||||
|
||||
if !shouldRetryResponsesNonStream(result, attempts) {
|
||||
h.finishResponsesNonStreamResult(w, result, attempts, owner, responseID, toolChoice, traceID)
|
||||
return
|
||||
}
|
||||
|
||||
attempts++
|
||||
config.Logger.Info("[openai_empty_retry] attempting synthetic retry", "surface", "responses", "stream", false, "retry_attempt", attempts, "parent_message_id", result.responseMessageID)
|
||||
retryPow, powErr := h.DS.GetPow(ctx, a, 3)
|
||||
if powErr != nil {
|
||||
config.Logger.Warn("[openai_empty_retry] retry PoW fetch failed, falling back to original PoW", "surface", "responses", "stream", false, "retry_attempt", attempts, "error", powErr)
|
||||
retryPow = pow
|
||||
}
|
||||
nextResp, err := h.DS.CallCompletion(ctx, a, clonePayloadForEmptyOutputRetry(payload, result.responseMessageID), retryPow, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
config.Logger.Warn("[openai_empty_retry] retry request failed", "surface", "responses", "stream", false, "retry_attempt", attempts, "error", err)
|
||||
return
|
||||
}
|
||||
usagePrompt = usagePromptWithEmptyOutputRetry(usagePrompt, attempts)
|
||||
currentResp = nextResp
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *http.Response, responseID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (responsesNonStreamResult, bool) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
return responsesNonStreamResult{}, false
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, false)
|
||||
stripReferenceMarkers := h.compatStripReferenceMarkers()
|
||||
sanitizedThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers)
|
||||
sanitizedText := cleanVisibleOutput(result.Text, stripReferenceMarkers)
|
||||
if searchEnabled {
|
||||
sanitizedText = replaceCitationMarkersWithLinks(sanitizedText, result.CitationLinks)
|
||||
}
|
||||
textParsed := detectAssistantToolCalls(result.Text, sanitizedText, result.Thinking, result.ToolDetectionThinking, toolNames)
|
||||
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, sanitizedThinking, sanitizedText, textParsed.Calls, toolsRaw)
|
||||
return responsesNonStreamResult{
|
||||
rawThinking: result.Thinking,
|
||||
rawText: result.Text,
|
||||
thinking: sanitizedThinking,
|
||||
toolDetectionThinking: result.ToolDetectionThinking,
|
||||
text: sanitizedText,
|
||||
contentFilter: result.ContentFilter,
|
||||
parsed: textParsed,
|
||||
body: responseObj,
|
||||
responseMessageID: result.ResponseMessageID,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (h *Handler) finishResponsesNonStreamResult(w http.ResponseWriter, result responsesNonStreamResult, attempts int, owner, responseID string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
if len(result.parsed.Calls) == 0 && writeUpstreamEmptyOutputError(w, result.text, result.thinking, result.contentFilter) {
|
||||
config.Logger.Info("[openai_empty_retry] terminal empty output", "surface", "responses", "stream", false, "retry_attempts", attempts, "success_source", "none", "content_filter", result.contentFilter)
|
||||
return
|
||||
}
|
||||
logResponsesToolPolicyRejection(traceID, toolChoice, result.parsed, "text")
|
||||
if toolChoice.IsRequired() && len(result.parsed.Calls) == 0 {
|
||||
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
|
||||
return
|
||||
}
|
||||
h.getResponseStore().put(owner, responseID, result.body)
|
||||
writeJSON(w, http.StatusOK, result.body)
|
||||
source := "first_attempt"
|
||||
if attempts > 0 {
|
||||
source = "synthetic_retry"
|
||||
}
|
||||
config.Logger.Info("[openai_empty_retry] completed", "surface", "responses", "stream", false, "retry_attempts", attempts, "success_source", source)
|
||||
}
|
||||
|
||||
func shouldRetryResponsesNonStream(result responsesNonStreamResult, attempts int) bool {
|
||||
return emptyOutputRetryEnabled() &&
|
||||
attempts < emptyOutputRetryMaxAttempts() &&
|
||||
!result.contentFilter &&
|
||||
len(result.parsed.Calls) == 0 &&
|
||||
strings.TrimSpace(result.text) == "" &&
|
||||
strings.TrimSpace(result.thinking) == ""
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, resp, owner, responseID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, traceID)
|
||||
if !ok {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/completionruntime"
|
||||
"ds2api/internal/config"
|
||||
dsprotocol "ds2api/internal/deepseek/protocol"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
@@ -92,34 +93,31 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
if !stdReq.Stream {
|
||||
result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{
|
||||
StripReferenceMarkers: h.compatStripReferenceMarkers(),
|
||||
RetryEnabled: true,
|
||||
})
|
||||
if outErr != nil {
|
||||
writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, stdReq.ResponseModel, result.Turn.Prompt, result.Turn.Thinking, result.Turn.Text, result.Turn.ToolCalls, stdReq.ToolsRaw)
|
||||
responseObj["usage"] = responsesUsageFromTurn(result.Turn)
|
||||
h.getResponseStore().put(owner, responseID, responseObj)
|
||||
writeJSON(w, http.StatusOK, responseObj)
|
||||
return
|
||||
}
|
||||
|
||||
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
refFileTokens := stdReq.RefFileTokens
|
||||
if stdReq.Stream {
|
||||
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||
start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
|
||||
if outErr != nil {
|
||||
writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code)
|
||||
return
|
||||
}
|
||||
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||
|
||||
refFileTokens := stdReq.RefFileTokens
|
||||
h.handleResponsesStreamWithRetry(w, r, a, start.Response, start.Payload, start.Pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"ds2api/internal/assistantturn"
|
||||
"ds2api/internal/toolcall"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -159,9 +160,29 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput
|
||||
|
||||
finalThinking := s.accumulator.Thinking.String()
|
||||
finalToolDetectionThinking := s.accumulator.ToolDetectionThinking.String()
|
||||
finalText := cleanVisibleOutput(s.accumulator.Text.String(), s.stripReferenceMarkers)
|
||||
textParsed := detectAssistantToolCalls(s.accumulator.RawText.String(), finalText, s.accumulator.RawThinking.String(), finalToolDetectionThinking, s.toolNames)
|
||||
detected := textParsed.Calls
|
||||
finalText := s.accumulator.Text.String()
|
||||
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
|
||||
RawText: s.accumulator.RawText.String(),
|
||||
VisibleText: finalText,
|
||||
RawThinking: s.accumulator.RawThinking.String(),
|
||||
VisibleThinking: finalThinking,
|
||||
DetectionThinking: finalToolDetectionThinking,
|
||||
ContentFilter: finishReason == "content_filter",
|
||||
ResponseMessageID: s.responseMessageID,
|
||||
AlreadyEmittedCalls: s.toolCallsEmitted,
|
||||
AlreadyEmittedToolRaw: s.toolCallsDoneEmitted,
|
||||
}, assistantturn.BuildOptions{
|
||||
Model: s.model,
|
||||
Prompt: s.finalPrompt,
|
||||
RefFileTokens: s.refFileTokens,
|
||||
SearchEnabled: s.searchEnabled,
|
||||
StripReferenceMarkers: s.stripReferenceMarkers,
|
||||
ToolNames: s.toolNames,
|
||||
ToolsRaw: s.toolsRaw,
|
||||
ToolChoice: s.toolChoice,
|
||||
})
|
||||
textParsed := turn.ParsedToolCalls
|
||||
detected := turn.ToolCalls
|
||||
s.logToolPolicyRejections(textParsed)
|
||||
|
||||
if len(detected) > 0 {
|
||||
@@ -173,12 +194,15 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput
|
||||
|
||||
s.closeMessageItem()
|
||||
|
||||
if s.toolChoice.IsRequired() && len(detected) == 0 {
|
||||
s.failResponse(http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
|
||||
if turn.Error != nil && turn.Error.Code == "tool_choice_violation" {
|
||||
s.failResponse(turn.Error.Status, turn.Error.Message, turn.Error.Code)
|
||||
return true
|
||||
}
|
||||
if len(detected) == 0 && strings.TrimSpace(finalText) == "" {
|
||||
status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", finalText, finalThinking)
|
||||
if len(detected) == 0 && strings.TrimSpace(turn.Text) == "" {
|
||||
status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", turn.Text, turn.Thinking)
|
||||
if turn.Error != nil {
|
||||
status, message, code = turn.Error.Status, turn.Error.Message, turn.Error.Code
|
||||
}
|
||||
if deferEmptyOutput {
|
||||
s.finalErrorStatus = status
|
||||
s.finalErrorMessage = message
|
||||
@@ -190,7 +214,7 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput
|
||||
}
|
||||
s.closeIncompleteFunctionItems()
|
||||
|
||||
obj := s.buildCompletedResponseObject(finalThinking, finalText, detected)
|
||||
obj := s.buildCompletedResponseObject(turn.Thinking, turn.Text, detected)
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(obj)
|
||||
}
|
||||
@@ -199,6 +223,14 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput
|
||||
return true
|
||||
}
|
||||
|
||||
func responsesUsageFromTurn(turn assistantturn.Turn) map[string]any {
|
||||
return map[string]any{
|
||||
"input_tokens": turn.Usage.InputTokens,
|
||||
"output_tokens": turn.Usage.OutputTokens,
|
||||
"total_tokens": turn.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed toolcall.ToolCallParseResult) {
|
||||
logRejected := func(parsed toolcall.ToolCallParseResult, channel string) {
|
||||
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)
|
||||
|
||||
Reference in New Issue
Block a user