refactor: centralize assistant turn semantics and stream accumulation into new assistantturn and completionruntime packages

This commit is contained in:
CJACK
2026-05-02 23:28:43 +08:00
parent eccd8c957b
commit dc5bffdf89
24 changed files with 1215 additions and 254 deletions

View File

@@ -5,8 +5,10 @@ import (
"net/http"
"strings"
"ds2api/internal/assistantturn"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/httpapi/openai/shared"
"ds2api/internal/promptcompat"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/toolstream"
@@ -24,6 +26,7 @@ type chatStreamRuntime struct {
refFileTokens int
toolNames []string
toolsRaw any
toolChoice promptcompat.ToolChoicePolicy
thinkingEnabled bool
searchEnabled bool
@@ -89,6 +92,7 @@ func newChatStreamRuntime(
stripReferenceMarkers bool,
toolNames []string,
toolsRaw any,
toolChoice promptcompat.ToolChoicePolicy,
bufferToolContent bool,
emitEarlyToolDeltas bool,
) *chatStreamRuntime {
@@ -102,6 +106,7 @@ func newChatStreamRuntime(
finalPrompt: finalPrompt,
toolNames: toolNames,
toolsRaw: toolsRaw,
toolChoice: toolChoice,
thinkingEnabled: thinkingEnabled,
searchEnabled: searchEnabled,
stripReferenceMarkers: stripReferenceMarkers,
@@ -201,14 +206,33 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
s.finalErrorCode = ""
finalThinking := s.accumulator.Thinking.String()
finalToolDetectionThinking := s.accumulator.ToolDetectionThinking.String()
finalText := cleanVisibleOutput(s.accumulator.Text.String(), s.stripReferenceMarkers)
s.finalThinking = finalThinking
s.finalText = finalText
detected := detectAssistantToolCalls(s.accumulator.RawText.String(), finalText, s.accumulator.RawThinking.String(), finalToolDetectionThinking, s.toolNames)
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
finalText := s.accumulator.Text.String()
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
RawText: s.accumulator.RawText.String(),
VisibleText: finalText,
RawThinking: s.accumulator.RawThinking.String(),
VisibleThinking: finalThinking,
DetectionThinking: finalToolDetectionThinking,
ContentFilter: finishReason == "content_filter",
ResponseMessageID: s.responseMessageID,
AlreadyEmittedCalls: s.toolCallsEmitted,
AlreadyEmittedToolRaw: s.toolCallsDoneEmitted,
}, assistantturn.BuildOptions{
Model: s.model,
Prompt: s.finalPrompt,
RefFileTokens: s.refFileTokens,
SearchEnabled: s.searchEnabled,
StripReferenceMarkers: s.stripReferenceMarkers,
ToolNames: s.toolNames,
ToolsRaw: s.toolsRaw,
ToolChoice: s.toolChoice,
})
s.finalThinking = turn.Thinking
s.finalText = turn.Text
if len(turn.ToolCalls) > 0 && !s.toolCallsDoneEmitted {
finishReason = "tool_calls"
s.sendDelta(map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs, s.toolsRaw),
"tool_calls": formatFinalStreamToolCallsWithStableIDs(turn.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
})
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
@@ -237,11 +261,14 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
batch.flush()
}
if len(detected.Calls) > 0 || s.toolCallsEmitted {
if len(turn.ToolCalls) > 0 || s.toolCallsEmitted {
finishReason = "tool_calls"
}
if len(detected.Calls) == 0 && !s.toolCallsEmitted && strings.TrimSpace(finalText) == "" {
status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", finalText, finalThinking)
if len(turn.ToolCalls) == 0 && !s.toolCallsEmitted && strings.TrimSpace(turn.Text) == "" {
status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", turn.Text, turn.Thinking)
if turn.Error != nil {
status, message, code = turn.Error.Status, turn.Error.Message, turn.Error.Code
}
if deferEmptyOutput {
s.finalErrorStatus = status
s.finalErrorMessage = message
@@ -251,7 +278,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
s.sendFailedChunk(status, message, code)
return true
}
usage := openaifmt.BuildChatUsageForModel(s.model, s.finalPrompt, finalThinking, finalText, s.refFileTokens)
usage := chatUsageFromTurn(turn)
s.finalFinishReason = finishReason
s.finalUsage = usage
s.sendChunk(openaifmt.BuildChatStreamChunk(
@@ -265,6 +292,17 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
return true
}
func chatUsageFromTurn(turn assistantturn.Turn) map[string]any {
return map[string]any{
"prompt_tokens": turn.Usage.InputTokens,
"completion_tokens": turn.Usage.OutputTokens,
"total_tokens": turn.Usage.TotalTokens,
"completion_tokens_details": map[string]any{
"reasoning_tokens": turn.Usage.ReasoningTokens,
},
}
}
func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
if !parsed.Parsed {
return streamengine.ParsedDecision{}

View File

@@ -6,6 +6,8 @@ import (
"strings"
"testing"
"time"
"ds2api/internal/promptcompat"
)
func TestChatStreamKeepAliveEmitsEmptyChoiceDataFrame(t *testing.T) {
@@ -23,6 +25,7 @@ func TestChatStreamKeepAliveEmitsEmptyChoiceDataFrame(t *testing.T) {
true,
nil,
nil,
promptcompat.DefaultToolChoicePolicy(),
false,
false,
)
@@ -51,3 +54,34 @@ func TestChatStreamKeepAliveEmitsEmptyChoiceDataFrame(t *testing.T) {
t.Fatalf("expected empty choices heartbeat, got %#v", choices)
}
}
func TestChatStreamFinalizeEnforcesRequiredToolChoice(t *testing.T) {
rec := httptest.NewRecorder()
runtime := newChatStreamRuntime(
rec,
http.NewResponseController(rec),
true,
"chatcmpl-test",
time.Now().Unix(),
"deepseek-v4-flash",
"prompt",
false,
false,
true,
[]string{"Write"},
nil,
promptcompat.ToolChoicePolicy{Mode: promptcompat.ToolChoiceRequired},
true,
false,
)
if !runtime.finalize("stop", false) {
t.Fatalf("expected terminal error to be written")
}
if runtime.finalErrorCode != "tool_choice_violation" {
t.Fatalf("expected tool_choice_violation, got %q body=%s", runtime.finalErrorCode, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "tool_choice requires") {
t.Fatalf("expected tool choice error in stream body, got %s", rec.Body.String())
}
}

View File

@@ -7,10 +7,12 @@ import (
"strings"
"time"
"ds2api/internal/assistantturn"
"ds2api/internal/auth"
"ds2api/internal/config"
dsprotocol "ds2api/internal/deepseek/protocol"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/promptcompat"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
)
@@ -26,6 +28,7 @@ type chatNonStreamResult struct {
body map[string]any
finishReason string
responseMessageID int
outputError *assistantturn.OutputError
}
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
@@ -86,35 +89,40 @@ func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.
return chatNonStreamResult{}, false
}
result := sse.CollectStream(resp, thinkingEnabled, true)
stripReferenceMarkers := h.compatStripReferenceMarkers()
finalThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers)
finalText := cleanVisibleOutput(result.Text, stripReferenceMarkers)
if searchEnabled {
finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks)
}
detected := detectAssistantToolCalls(result.Text, finalText, result.Thinking, result.ToolDetectionThinking, toolNames)
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls, toolsRaw)
turn := assistantturn.BuildTurnFromCollected(result, assistantturn.BuildOptions{
Model: model,
Prompt: usagePrompt,
SearchEnabled: searchEnabled,
StripReferenceMarkers: h.compatStripReferenceMarkers(),
ToolNames: toolNames,
ToolsRaw: toolsRaw,
})
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, turn.Thinking, turn.Text, turn.ToolCalls, toolsRaw)
return chatNonStreamResult{
rawThinking: result.Thinking,
rawText: result.Text,
thinking: finalThinking,
thinking: turn.Thinking,
toolDetectionThinking: result.ToolDetectionThinking,
text: finalText,
text: turn.Text,
contentFilter: result.ContentFilter,
detectedCalls: len(detected.Calls),
detectedCalls: len(turn.ToolCalls),
body: respBody,
finishReason: chatFinishReason(respBody),
responseMessageID: result.ResponseMessageID,
outputError: turn.Error,
}, true
}
func (h *Handler) finishChatNonStreamResult(w http.ResponseWriter, result chatNonStreamResult, attempts int, usagePrompt string, refFileTokens int, historySession *chatHistorySession) {
if result.detectedCalls == 0 && shouldWriteUpstreamEmptyOutputError(result.text, result.thinking) {
if result.detectedCalls == 0 && strings.TrimSpace(result.text) == "" {
status, message, code := upstreamEmptyOutputDetail(result.contentFilter, result.text, result.thinking)
if result.outputError != nil {
status, message, code = result.outputError.Status, result.outputError.Message, result.outputError.Code
}
if historySession != nil {
historySession.error(status, message, code, result.thinking, result.text)
}
writeUpstreamEmptyOutputError(w, result.text, result.thinking, result.contentFilter)
writeOpenAIErrorWithCode(w, status, message, code)
config.Logger.Info("[openai_empty_retry] terminal empty output", "surface", "chat.completions", "stream", false, "retry_attempts", attempts, "success_source", "none", "content_filter", result.contentFilter)
return
}
@@ -147,8 +155,8 @@ func shouldRetryChatNonStream(result chatNonStreamResult, attempts int) bool {
strings.TrimSpace(result.thinking) == ""
}
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, historySession)
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, historySession *chatHistorySession) {
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, historySession)
if !ok {
return
}
@@ -190,7 +198,7 @@ func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request,
}
}
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
if resp.StatusCode != http.StatusOK {
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
@@ -216,6 +224,7 @@ func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Res
streamRuntime := newChatStreamRuntime(
w, rc, canFlush, completionID, time.Now().Unix(), model, finalPrompt,
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw,
toolChoice,
len(toolNames) > 0, h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
)
streamRuntime.refFileTokens = refFileTokens

View File

@@ -8,6 +8,7 @@ import (
"time"
"ds2api/internal/chathistory"
"ds2api/internal/promptcompat"
"ds2api/internal/stream"
)
@@ -48,6 +49,7 @@ func TestConsumeChatStreamAttemptMarksContextCancelledState(t *testing.T) {
true,
nil,
nil,
promptcompat.DefaultToolChoicePolicy(),
false,
false,
)

View File

@@ -80,6 +80,10 @@ func writeOpenAIError(w http.ResponseWriter, status int, message string) {
shared.WriteOpenAIError(w, status, message)
}
func writeOpenAIErrorWithCode(w http.ResponseWriter, status int, message, code string) {
shared.WriteOpenAIErrorWithCode(w, status, message, code)
}
func openAIErrorType(status int) string {
return shared.OpenAIErrorType(status)
}

View File

@@ -9,6 +9,7 @@ import (
"time"
"ds2api/internal/auth"
"ds2api/internal/completionruntime"
"ds2api/internal/config"
dsprotocol "ds2api/internal/deepseek/protocol"
openaifmt "ds2api/internal/format/openai"
@@ -76,44 +77,40 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
}
historySession := startChatHistory(h.ChatHistory, r, a, stdReq)
sessionID, err = h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
if a.UseConfigToken {
if !stdReq.Stream {
result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{
StripReferenceMarkers: h.compatStripReferenceMarkers(),
RetryEnabled: true,
})
sessionID = result.SessionID
if outErr != nil {
if historySession != nil {
historySession.error(http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.", "error", "", "")
historySession.error(outErr.Status, outErr.Message, outErr.Code, result.Turn.Thinking, result.Turn.Text)
}
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
} else {
if historySession != nil {
historySession.error(http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.", "error", "", "")
}
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code)
return
}
respBody := openaifmt.BuildChatCompletionWithToolCalls(result.SessionID, stdReq.ResponseModel, result.Turn.Prompt, result.Turn.Thinking, result.Turn.Text, result.Turn.ToolCalls, stdReq.ToolsRaw)
respBody["usage"] = chatUsageFromTurn(result.Turn)
finishReason := chatFinishReason(respBody)
if historySession != nil {
historySession.success(http.StatusOK, result.Turn.Thinking, result.Turn.Text, finishReason, chatUsageFromTurn(result.Turn))
}
writeJSON(w, http.StatusOK, respBody)
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
sessionID = start.SessionID
if outErr != nil {
if historySession != nil {
historySession.error(http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).", "error", "", "")
historySession.error(outErr.Status, outErr.Message, outErr.Code, "", "")
}
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
if historySession != nil {
historySession.error(http.StatusInternalServerError, "Failed to get completion.", "error", "", "")
}
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code)
return
}
refFileTokens := stdReq.RefFileTokens
if stdReq.Stream {
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
return
}
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
h.handleStreamWithRetry(w, r, a, start.Response, start.Payload, start.Pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, historySession)
}
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
@@ -234,6 +231,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
stripReferenceMarkers,
toolNames,
toolsRaw,
promptcompat.DefaultToolChoicePolicy(),
bufferToolContent,
emitEarlyToolDeltas,
)

View File

@@ -1,7 +1,6 @@
package responses
import (
"context"
"io"
"net/http"
"strings"
@@ -10,129 +9,10 @@ import (
"ds2api/internal/auth"
"ds2api/internal/config"
dsprotocol "ds2api/internal/deepseek/protocol"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/promptcompat"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/toolcall"
)
type responsesNonStreamResult struct {
rawThinking string
rawText string
thinking string
toolDetectionThinking string
text string
contentFilter bool
parsed toolcall.ToolCallParseResult
body map[string]any
responseMessageID int
}
func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
attempts := 0
currentResp := resp
usagePrompt := finalPrompt
accumulatedThinking := ""
accumulatedRawThinking := ""
accumulatedToolDetectionThinking := ""
for {
result, ok := h.collectResponsesNonStreamAttempt(w, currentResp, responseID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw)
if !ok {
return
}
accumulatedThinking += sse.TrimContinuationOverlap(accumulatedThinking, result.thinking)
accumulatedRawThinking += sse.TrimContinuationOverlap(accumulatedRawThinking, result.rawThinking)
accumulatedToolDetectionThinking += sse.TrimContinuationOverlap(accumulatedToolDetectionThinking, result.toolDetectionThinking)
result.thinking = accumulatedThinking
result.rawThinking = accumulatedRawThinking
result.toolDetectionThinking = accumulatedToolDetectionThinking
result.parsed = detectAssistantToolCalls(result.rawText, result.text, result.rawThinking, result.toolDetectionThinking, toolNames)
result.body = openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, result.thinking, result.text, result.parsed.Calls, toolsRaw)
if refFileTokens > 0 {
addRefFileTokensToUsage(result.body, refFileTokens)
}
if !shouldRetryResponsesNonStream(result, attempts) {
h.finishResponsesNonStreamResult(w, result, attempts, owner, responseID, toolChoice, traceID)
return
}
attempts++
config.Logger.Info("[openai_empty_retry] attempting synthetic retry", "surface", "responses", "stream", false, "retry_attempt", attempts, "parent_message_id", result.responseMessageID)
retryPow, powErr := h.DS.GetPow(ctx, a, 3)
if powErr != nil {
config.Logger.Warn("[openai_empty_retry] retry PoW fetch failed, falling back to original PoW", "surface", "responses", "stream", false, "retry_attempt", attempts, "error", powErr)
retryPow = pow
}
nextResp, err := h.DS.CallCompletion(ctx, a, clonePayloadForEmptyOutputRetry(payload, result.responseMessageID), retryPow, 3)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
config.Logger.Warn("[openai_empty_retry] retry request failed", "surface", "responses", "stream", false, "retry_attempt", attempts, "error", err)
return
}
usagePrompt = usagePromptWithEmptyOutputRetry(usagePrompt, attempts)
currentResp = nextResp
}
}
func (h *Handler) collectResponsesNonStreamAttempt(w http.ResponseWriter, resp *http.Response, responseID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (responsesNonStreamResult, bool) {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body)))
return responsesNonStreamResult{}, false
}
result := sse.CollectStream(resp, thinkingEnabled, false)
stripReferenceMarkers := h.compatStripReferenceMarkers()
sanitizedThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers)
sanitizedText := cleanVisibleOutput(result.Text, stripReferenceMarkers)
if searchEnabled {
sanitizedText = replaceCitationMarkersWithLinks(sanitizedText, result.CitationLinks)
}
textParsed := detectAssistantToolCalls(result.Text, sanitizedText, result.Thinking, result.ToolDetectionThinking, toolNames)
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, model, usagePrompt, sanitizedThinking, sanitizedText, textParsed.Calls, toolsRaw)
return responsesNonStreamResult{
rawThinking: result.Thinking,
rawText: result.Text,
thinking: sanitizedThinking,
toolDetectionThinking: result.ToolDetectionThinking,
text: sanitizedText,
contentFilter: result.ContentFilter,
parsed: textParsed,
body: responseObj,
responseMessageID: result.ResponseMessageID,
}, true
}
func (h *Handler) finishResponsesNonStreamResult(w http.ResponseWriter, result responsesNonStreamResult, attempts int, owner, responseID string, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
if len(result.parsed.Calls) == 0 && writeUpstreamEmptyOutputError(w, result.text, result.thinking, result.contentFilter) {
config.Logger.Info("[openai_empty_retry] terminal empty output", "surface", "responses", "stream", false, "retry_attempts", attempts, "success_source", "none", "content_filter", result.contentFilter)
return
}
logResponsesToolPolicyRejection(traceID, toolChoice, result.parsed, "text")
if toolChoice.IsRequired() && len(result.parsed.Calls) == 0 {
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
return
}
h.getResponseStore().put(owner, responseID, result.body)
writeJSON(w, http.StatusOK, result.body)
source := "first_attempt"
if attempts > 0 {
source = "synthetic_retry"
}
config.Logger.Info("[openai_empty_retry] completed", "surface", "responses", "stream", false, "retry_attempts", attempts, "success_source", source)
}
func shouldRetryResponsesNonStream(result responsesNonStreamResult, attempts int) bool {
return emptyOutputRetryEnabled() &&
attempts < emptyOutputRetryMaxAttempts() &&
!result.contentFilter &&
len(result.parsed.Calls) == 0 &&
strings.TrimSpace(result.text) == "" &&
strings.TrimSpace(result.thinking) == ""
}
func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, resp, owner, responseID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, traceID)
if !ok {

View File

@@ -12,6 +12,7 @@ import (
"github.com/google/uuid"
"ds2api/internal/auth"
"ds2api/internal/completionruntime"
"ds2api/internal/config"
dsprotocol "ds2api/internal/deepseek/protocol"
openaifmt "ds2api/internal/format/openai"
@@ -92,34 +93,31 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
return
}
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
if a.UseConfigToken {
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
} else {
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if !stdReq.Stream {
result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{
StripReferenceMarkers: h.compatStripReferenceMarkers(),
RetryEnabled: true,
})
if outErr != nil {
writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code)
return
}
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
responseObj := openaifmt.BuildResponseObjectWithToolCalls(responseID, stdReq.ResponseModel, result.Turn.Prompt, result.Turn.Thinking, result.Turn.Text, result.Turn.ToolCalls, stdReq.ToolsRaw)
responseObj["usage"] = responsesUsageFromTurn(result.Turn)
h.getResponseStore().put(owner, responseID, responseObj)
writeJSON(w, http.StatusOK, responseObj)
return
}
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
refFileTokens := stdReq.RefFileTokens
if stdReq.Stream {
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
if outErr != nil {
writeOpenAIErrorWithCode(w, outErr.Status, outErr.Message, outErr.Code)
return
}
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
refFileTokens := stdReq.RefFileTokens
h.handleResponsesStreamWithRetry(w, r, a, start.Response, start.Payload, start.Pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, refFileTokens, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
}
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {

View File

@@ -1,6 +1,7 @@
package responses
import (
"ds2api/internal/assistantturn"
"ds2api/internal/toolcall"
"net/http"
"strings"
@@ -159,9 +160,29 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput
finalThinking := s.accumulator.Thinking.String()
finalToolDetectionThinking := s.accumulator.ToolDetectionThinking.String()
finalText := cleanVisibleOutput(s.accumulator.Text.String(), s.stripReferenceMarkers)
textParsed := detectAssistantToolCalls(s.accumulator.RawText.String(), finalText, s.accumulator.RawThinking.String(), finalToolDetectionThinking, s.toolNames)
detected := textParsed.Calls
finalText := s.accumulator.Text.String()
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
RawText: s.accumulator.RawText.String(),
VisibleText: finalText,
RawThinking: s.accumulator.RawThinking.String(),
VisibleThinking: finalThinking,
DetectionThinking: finalToolDetectionThinking,
ContentFilter: finishReason == "content_filter",
ResponseMessageID: s.responseMessageID,
AlreadyEmittedCalls: s.toolCallsEmitted,
AlreadyEmittedToolRaw: s.toolCallsDoneEmitted,
}, assistantturn.BuildOptions{
Model: s.model,
Prompt: s.finalPrompt,
RefFileTokens: s.refFileTokens,
SearchEnabled: s.searchEnabled,
StripReferenceMarkers: s.stripReferenceMarkers,
ToolNames: s.toolNames,
ToolsRaw: s.toolsRaw,
ToolChoice: s.toolChoice,
})
textParsed := turn.ParsedToolCalls
detected := turn.ToolCalls
s.logToolPolicyRejections(textParsed)
if len(detected) > 0 {
@@ -173,12 +194,15 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput
s.closeMessageItem()
if s.toolChoice.IsRequired() && len(detected) == 0 {
s.failResponse(http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
if turn.Error != nil && turn.Error.Code == "tool_choice_violation" {
s.failResponse(turn.Error.Status, turn.Error.Message, turn.Error.Code)
return true
}
if len(detected) == 0 && strings.TrimSpace(finalText) == "" {
status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", finalText, finalThinking)
if len(detected) == 0 && strings.TrimSpace(turn.Text) == "" {
status, message, code := upstreamEmptyOutputDetail(finishReason == "content_filter", turn.Text, turn.Thinking)
if turn.Error != nil {
status, message, code = turn.Error.Status, turn.Error.Message, turn.Error.Code
}
if deferEmptyOutput {
s.finalErrorStatus = status
s.finalErrorMessage = message
@@ -190,7 +214,7 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput
}
s.closeIncompleteFunctionItems()
obj := s.buildCompletedResponseObject(finalThinking, finalText, detected)
obj := s.buildCompletedResponseObject(turn.Thinking, turn.Text, detected)
if s.persistResponse != nil {
s.persistResponse(obj)
}
@@ -199,6 +223,14 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput
return true
}
func responsesUsageFromTurn(turn assistantturn.Turn) map[string]any {
return map[string]any{
"input_tokens": turn.Usage.InputTokens,
"output_tokens": turn.Usage.OutputTokens,
"total_tokens": turn.Usage.TotalTokens,
}
}
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed toolcall.ToolCallParseResult) {
logRejected := func(parsed toolcall.ToolCallParseResult, channel string) {
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)