// Mirror of https://github.com/CJackHwang/ds2api.git
// (synced 2026-05-04 08:25:26 +08:00; 194 lines, 8.0 KiB, Go).
package completionruntime
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"ds2api/internal/assistantturn"
|
|
"ds2api/internal/auth"
|
|
"ds2api/internal/config"
|
|
dsclient "ds2api/internal/deepseek/client"
|
|
"ds2api/internal/httpapi/openai/history"
|
|
"ds2api/internal/httpapi/openai/shared"
|
|
"ds2api/internal/promptcompat"
|
|
"ds2api/internal/sse"
|
|
)
|
|
|
|
type DeepSeekCaller interface {
|
|
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
|
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
|
UploadFile(ctx context.Context, a *auth.RequestAuth, req dsclient.UploadFileRequest, maxAttempts int) (*dsclient.UploadFileResult, error)
|
|
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
|
|
}
|
|
|
|
type Options struct {
|
|
StripReferenceMarkers bool
|
|
MaxAttempts int
|
|
RetryEnabled bool
|
|
RetryMaxAttempts int
|
|
CurrentInputFile history.CurrentInputConfigReader
|
|
}
|
|
|
|
type NonStreamResult struct {
|
|
SessionID string
|
|
Payload map[string]any
|
|
Turn assistantturn.Turn
|
|
Attempts int
|
|
}
|
|
|
|
type StartResult struct {
|
|
SessionID string
|
|
Payload map[string]any
|
|
Pow string
|
|
Response *http.Response
|
|
Request promptcompat.StandardRequest
|
|
}
|
|
|
|
func StartCompletion(ctx context.Context, ds DeepSeekCaller, a *auth.RequestAuth, stdReq promptcompat.StandardRequest, opts Options) (StartResult, *assistantturn.OutputError) {
|
|
maxAttempts := opts.MaxAttempts
|
|
if maxAttempts <= 0 {
|
|
maxAttempts = 3
|
|
}
|
|
var prepErr *assistantturn.OutputError
|
|
stdReq, prepErr = prepareCurrentInputFile(ctx, ds, a, stdReq, opts)
|
|
if prepErr != nil {
|
|
return StartResult{Request: stdReq}, prepErr
|
|
}
|
|
sessionID, err := ds.CreateSession(ctx, a, maxAttempts)
|
|
if err != nil {
|
|
return StartResult{Request: stdReq}, authOutputError(a)
|
|
}
|
|
pow, err := ds.GetPow(ctx, a, maxAttempts)
|
|
if err != nil {
|
|
return StartResult{SessionID: sessionID, Request: stdReq}, &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: "Failed to get PoW (invalid token or unknown error).", Code: "error"}
|
|
}
|
|
payload := stdReq.CompletionPayload(sessionID)
|
|
resp, err := ds.CallCompletion(ctx, a, payload, pow, maxAttempts)
|
|
if err != nil {
|
|
return StartResult{SessionID: sessionID, Payload: payload, Pow: pow, Request: stdReq}, &assistantturn.OutputError{Status: http.StatusInternalServerError, Message: "Failed to get completion.", Code: "error"}
|
|
}
|
|
return StartResult{SessionID: sessionID, Payload: payload, Pow: pow, Response: resp, Request: stdReq}, nil
|
|
}
|
|
|
|
func prepareCurrentInputFile(ctx context.Context, ds DeepSeekCaller, a *auth.RequestAuth, stdReq promptcompat.StandardRequest, opts Options) (promptcompat.StandardRequest, *assistantturn.OutputError) {
|
|
if opts.CurrentInputFile == nil || stdReq.CurrentInputFileApplied {
|
|
return stdReq, nil
|
|
}
|
|
out, err := (history.Service{Store: opts.CurrentInputFile, DS: ds}).ApplyCurrentInputFile(ctx, a, stdReq)
|
|
if err != nil {
|
|
status, message := history.MapError(err)
|
|
return out, &assistantturn.OutputError{Status: status, Message: message, Code: "error"}
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func ExecuteNonStreamWithRetry(ctx context.Context, ds DeepSeekCaller, a *auth.RequestAuth, stdReq promptcompat.StandardRequest, opts Options) (NonStreamResult, *assistantturn.OutputError) {
|
|
start, startErr := StartCompletion(ctx, ds, a, stdReq, opts)
|
|
if startErr != nil {
|
|
return NonStreamResult{SessionID: start.SessionID, Payload: start.Payload}, startErr
|
|
}
|
|
stdReq = start.Request
|
|
maxAttempts := opts.MaxAttempts
|
|
if maxAttempts <= 0 {
|
|
maxAttempts = 3
|
|
}
|
|
sessionID := start.SessionID
|
|
payload := start.Payload
|
|
pow := start.Pow
|
|
|
|
attempts := 0
|
|
currentResp := start.Response
|
|
usagePrompt := stdReq.PromptTokenText
|
|
accumulatedThinking := ""
|
|
accumulatedRawThinking := ""
|
|
accumulatedToolDetectionThinking := ""
|
|
for {
|
|
turn, outErr := collectAttempt(currentResp, stdReq, usagePrompt, opts)
|
|
if outErr != nil {
|
|
return NonStreamResult{SessionID: sessionID, Payload: payload, Attempts: attempts}, outErr
|
|
}
|
|
accumulatedThinking += sse.TrimContinuationOverlap(accumulatedThinking, turn.Thinking)
|
|
accumulatedRawThinking += sse.TrimContinuationOverlap(accumulatedRawThinking, turn.RawThinking)
|
|
accumulatedToolDetectionThinking += sse.TrimContinuationOverlap(accumulatedToolDetectionThinking, turn.DetectionThinking)
|
|
turn.Thinking = accumulatedThinking
|
|
turn.RawThinking = accumulatedRawThinking
|
|
turn.DetectionThinking = accumulatedToolDetectionThinking
|
|
turn = assistantturn.BuildTurnFromCollected(sse.CollectResult{
|
|
Text: turn.RawText,
|
|
Thinking: turn.RawThinking,
|
|
ToolDetectionThinking: turn.DetectionThinking,
|
|
ContentFilter: turn.ContentFilter,
|
|
CitationLinks: turn.CitationLinks,
|
|
ResponseMessageID: turn.ResponseMessageID,
|
|
}, buildOptions(stdReq, usagePrompt, opts))
|
|
|
|
retryMax := opts.RetryMaxAttempts
|
|
if retryMax <= 0 {
|
|
retryMax = shared.EmptyOutputRetryMaxAttempts()
|
|
}
|
|
if !opts.RetryEnabled || !assistantturn.ShouldRetryEmptyOutput(turn, attempts, retryMax) {
|
|
return NonStreamResult{SessionID: sessionID, Payload: payload, Turn: turn, Attempts: attempts}, turn.Error
|
|
}
|
|
|
|
attempts++
|
|
config.Logger.Info("[completion_runtime_empty_retry] attempting synthetic retry", "surface", stdReq.Surface, "stream", false, "retry_attempt", attempts, "parent_message_id", turn.ResponseMessageID)
|
|
retryPow, powErr := ds.GetPow(ctx, a, maxAttempts)
|
|
if powErr != nil {
|
|
config.Logger.Warn("[completion_runtime_empty_retry] retry PoW fetch failed, falling back to original PoW", "surface", stdReq.Surface, "retry_attempt", attempts, "error", powErr)
|
|
retryPow = pow
|
|
}
|
|
retryPayload := shared.ClonePayloadForEmptyOutputRetry(payload, turn.ResponseMessageID)
|
|
nextResp, err := ds.CallCompletion(ctx, a, retryPayload, retryPow, maxAttempts)
|
|
if err != nil {
|
|
return NonStreamResult{SessionID: sessionID, Payload: payload, Turn: turn, Attempts: attempts}, &assistantturn.OutputError{Status: http.StatusInternalServerError, Message: "Failed to get completion.", Code: "error"}
|
|
}
|
|
usagePrompt = shared.UsagePromptWithEmptyOutputRetry(usagePrompt, attempts)
|
|
currentResp = nextResp
|
|
}
|
|
}
|
|
|
|
func collectAttempt(resp *http.Response, stdReq promptcompat.StandardRequest, usagePrompt string, opts Options) (assistantturn.Turn, *assistantturn.OutputError) {
|
|
defer func() {
|
|
if err := resp.Body.Close(); err != nil {
|
|
config.Logger.Warn("[completion_runtime] response body close failed", "surface", stdReq.Surface, "error", err)
|
|
}
|
|
}()
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
message := strings.TrimSpace(string(body))
|
|
if message == "" {
|
|
message = http.StatusText(resp.StatusCode)
|
|
}
|
|
return assistantturn.Turn{}, &assistantturn.OutputError{Status: resp.StatusCode, Message: message, Code: "error"}
|
|
}
|
|
result := sse.CollectStream(resp, stdReq.Thinking, false)
|
|
return assistantturn.BuildTurnFromCollected(result, buildOptions(stdReq, usagePrompt, opts)), nil
|
|
}
|
|
|
|
func buildOptions(stdReq promptcompat.StandardRequest, prompt string, opts Options) assistantturn.BuildOptions {
|
|
return assistantturn.BuildOptions{
|
|
Model: stdReq.ResponseModel,
|
|
Prompt: prompt,
|
|
RefFileTokens: stdReq.RefFileTokens,
|
|
SearchEnabled: stdReq.Search,
|
|
StripReferenceMarkers: opts.StripReferenceMarkers,
|
|
ToolNames: stdReq.ToolNames,
|
|
ToolsRaw: stdReq.ToolsRaw,
|
|
ToolChoice: stdReq.ToolChoice,
|
|
}
|
|
}
|
|
|
|
func authOutputError(a *auth.RequestAuth) *assistantturn.OutputError {
|
|
if a != nil && a.UseConfigToken {
|
|
return &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: "Account token is invalid. Please re-login the account in admin.", Code: "error"}
|
|
}
|
|
return &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: "Invalid token. If this should be a DS2API key, add it to config.keys first.", Code: "error"}
|
|
}
|
|
|
|
func Errorf(status int, format string, args ...any) *assistantturn.OutputError {
|
|
return &assistantturn.OutputError{Status: status, Message: fmt.Sprintf(format, args...), Code: "error"}
|
|
}
|