mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-09 02:45:29 +08:00
refactor: centralize assistant turn semantics and stream accumulation into new assistantturn and completionruntime packages
This commit is contained in:
170
internal/completionruntime/nonstream.go
Normal file
170
internal/completionruntime/nonstream.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package completionruntime
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/assistantturn"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/httpapi/openai/shared"
|
||||
"ds2api/internal/promptcompat"
|
||||
"ds2api/internal/sse"
|
||||
)
|
||||
|
||||
type DeepSeekCaller interface {
|
||||
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
StripReferenceMarkers bool
|
||||
MaxAttempts int
|
||||
RetryEnabled bool
|
||||
RetryMaxAttempts int
|
||||
}
|
||||
|
||||
type NonStreamResult struct {
|
||||
SessionID string
|
||||
Payload map[string]any
|
||||
Turn assistantturn.Turn
|
||||
Attempts int
|
||||
}
|
||||
|
||||
type StartResult struct {
|
||||
SessionID string
|
||||
Payload map[string]any
|
||||
Pow string
|
||||
Response *http.Response
|
||||
}
|
||||
|
||||
func StartCompletion(ctx context.Context, ds DeepSeekCaller, a *auth.RequestAuth, stdReq promptcompat.StandardRequest, opts Options) (StartResult, *assistantturn.OutputError) {
|
||||
maxAttempts := opts.MaxAttempts
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = 3
|
||||
}
|
||||
sessionID, err := ds.CreateSession(ctx, a, maxAttempts)
|
||||
if err != nil {
|
||||
return StartResult{}, authOutputError(a)
|
||||
}
|
||||
pow, err := ds.GetPow(ctx, a, maxAttempts)
|
||||
if err != nil {
|
||||
return StartResult{SessionID: sessionID}, &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: "Failed to get PoW (invalid token or unknown error).", Code: "error"}
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := ds.CallCompletion(ctx, a, payload, pow, maxAttempts)
|
||||
if err != nil {
|
||||
return StartResult{SessionID: sessionID, Payload: payload, Pow: pow}, &assistantturn.OutputError{Status: http.StatusInternalServerError, Message: "Failed to get completion.", Code: "error"}
|
||||
}
|
||||
return StartResult{SessionID: sessionID, Payload: payload, Pow: pow, Response: resp}, nil
|
||||
}
|
||||
|
||||
func ExecuteNonStreamWithRetry(ctx context.Context, ds DeepSeekCaller, a *auth.RequestAuth, stdReq promptcompat.StandardRequest, opts Options) (NonStreamResult, *assistantturn.OutputError) {
|
||||
start, startErr := StartCompletion(ctx, ds, a, stdReq, opts)
|
||||
if startErr != nil {
|
||||
return NonStreamResult{SessionID: start.SessionID, Payload: start.Payload}, startErr
|
||||
}
|
||||
maxAttempts := opts.MaxAttempts
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = 3
|
||||
}
|
||||
sessionID := start.SessionID
|
||||
payload := start.Payload
|
||||
pow := start.Pow
|
||||
|
||||
attempts := 0
|
||||
currentResp := start.Response
|
||||
usagePrompt := stdReq.PromptTokenText
|
||||
accumulatedThinking := ""
|
||||
accumulatedRawThinking := ""
|
||||
accumulatedToolDetectionThinking := ""
|
||||
for {
|
||||
turn, outErr := collectAttempt(currentResp, stdReq, usagePrompt, opts)
|
||||
if outErr != nil {
|
||||
return NonStreamResult{SessionID: sessionID, Payload: payload, Attempts: attempts}, outErr
|
||||
}
|
||||
accumulatedThinking += sse.TrimContinuationOverlap(accumulatedThinking, turn.Thinking)
|
||||
accumulatedRawThinking += sse.TrimContinuationOverlap(accumulatedRawThinking, turn.RawThinking)
|
||||
accumulatedToolDetectionThinking += sse.TrimContinuationOverlap(accumulatedToolDetectionThinking, turn.DetectionThinking)
|
||||
turn.Thinking = accumulatedThinking
|
||||
turn.RawThinking = accumulatedRawThinking
|
||||
turn.DetectionThinking = accumulatedToolDetectionThinking
|
||||
turn = assistantturn.BuildTurnFromCollected(sse.CollectResult{
|
||||
Text: turn.RawText,
|
||||
Thinking: turn.RawThinking,
|
||||
ToolDetectionThinking: turn.DetectionThinking,
|
||||
ContentFilter: turn.ContentFilter,
|
||||
CitationLinks: turn.CitationLinks,
|
||||
ResponseMessageID: turn.ResponseMessageID,
|
||||
}, buildOptions(stdReq, usagePrompt, opts))
|
||||
|
||||
retryMax := opts.RetryMaxAttempts
|
||||
if retryMax <= 0 {
|
||||
retryMax = shared.EmptyOutputRetryMaxAttempts()
|
||||
}
|
||||
if !opts.RetryEnabled || !assistantturn.ShouldRetryEmptyOutput(turn, attempts, retryMax) {
|
||||
return NonStreamResult{SessionID: sessionID, Payload: payload, Turn: turn, Attempts: attempts}, turn.Error
|
||||
}
|
||||
|
||||
attempts++
|
||||
config.Logger.Info("[completion_runtime_empty_retry] attempting synthetic retry", "surface", stdReq.Surface, "stream", false, "retry_attempt", attempts, "parent_message_id", turn.ResponseMessageID)
|
||||
retryPow, powErr := ds.GetPow(ctx, a, maxAttempts)
|
||||
if powErr != nil {
|
||||
config.Logger.Warn("[completion_runtime_empty_retry] retry PoW fetch failed, falling back to original PoW", "surface", stdReq.Surface, "retry_attempt", attempts, "error", powErr)
|
||||
retryPow = pow
|
||||
}
|
||||
retryPayload := shared.ClonePayloadForEmptyOutputRetry(payload, turn.ResponseMessageID)
|
||||
nextResp, err := ds.CallCompletion(ctx, a, retryPayload, retryPow, maxAttempts)
|
||||
if err != nil {
|
||||
return NonStreamResult{SessionID: sessionID, Payload: payload, Turn: turn, Attempts: attempts}, &assistantturn.OutputError{Status: http.StatusInternalServerError, Message: "Failed to get completion.", Code: "error"}
|
||||
}
|
||||
usagePrompt = shared.UsagePromptWithEmptyOutputRetry(usagePrompt, attempts)
|
||||
currentResp = nextResp
|
||||
}
|
||||
}
|
||||
|
||||
func collectAttempt(resp *http.Response, stdReq promptcompat.StandardRequest, usagePrompt string, opts Options) (assistantturn.Turn, *assistantturn.OutputError) {
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
config.Logger.Warn("[completion_runtime] response body close failed", "surface", stdReq.Surface, "error", err)
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
message := strings.TrimSpace(string(body))
|
||||
if message == "" {
|
||||
message = http.StatusText(resp.StatusCode)
|
||||
}
|
||||
return assistantturn.Turn{}, &assistantturn.OutputError{Status: resp.StatusCode, Message: message, Code: "error"}
|
||||
}
|
||||
result := sse.CollectStream(resp, stdReq.Thinking, false)
|
||||
return assistantturn.BuildTurnFromCollected(result, buildOptions(stdReq, usagePrompt, opts)), nil
|
||||
}
|
||||
|
||||
func buildOptions(stdReq promptcompat.StandardRequest, prompt string, opts Options) assistantturn.BuildOptions {
|
||||
return assistantturn.BuildOptions{
|
||||
Model: stdReq.ResponseModel,
|
||||
Prompt: prompt,
|
||||
RefFileTokens: stdReq.RefFileTokens,
|
||||
SearchEnabled: stdReq.Search,
|
||||
StripReferenceMarkers: opts.StripReferenceMarkers,
|
||||
ToolNames: stdReq.ToolNames,
|
||||
ToolsRaw: stdReq.ToolsRaw,
|
||||
ToolChoice: stdReq.ToolChoice,
|
||||
}
|
||||
}
|
||||
|
||||
func authOutputError(a *auth.RequestAuth) *assistantturn.OutputError {
|
||||
if a != nil && a.UseConfigToken {
|
||||
return &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: "Account token is invalid. Please re-login the account in admin.", Code: "error"}
|
||||
}
|
||||
return &assistantturn.OutputError{Status: http.StatusUnauthorized, Message: "Invalid token. If this should be a DS2API key, add it to config.keys first.", Code: "error"}
|
||||
}
|
||||
|
||||
func Errorf(status int, format string, args ...any) *assistantturn.OutputError {
|
||||
return &assistantturn.OutputError{Status: status, Message: fmt.Sprintf(format, args...), Code: "error"}
|
||||
}
|
||||
120
internal/completionruntime/nonstream_test.go
Normal file
120
internal/completionruntime/nonstream_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package completionruntime
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/promptcompat"
|
||||
)
|
||||
|
||||
type fakeDeepSeekCaller struct {
|
||||
responses []*http.Response
|
||||
payloads []map[string]any
|
||||
}
|
||||
|
||||
func (f *fakeDeepSeekCaller) CreateSession(context.Context, *auth.RequestAuth, int) (string, error) {
|
||||
return "session-1", nil
|
||||
}
|
||||
|
||||
func (f *fakeDeepSeekCaller) GetPow(context.Context, *auth.RequestAuth, int) (string, error) {
|
||||
return "pow", nil
|
||||
}
|
||||
|
||||
func (f *fakeDeepSeekCaller) CallCompletion(_ context.Context, _ *auth.RequestAuth, payload map[string]any, _ string, _ int) (*http.Response, error) {
|
||||
f.payloads = append(f.payloads, payload)
|
||||
if len(f.responses) == 0 {
|
||||
return sseHTTPResponse(http.StatusOK, `data: {"p":"response/content","v":"fallback"}`), nil
|
||||
}
|
||||
resp := f.responses[0]
|
||||
f.responses = f.responses[1:]
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func TestExecuteNonStreamWithRetryBuildsCanonicalTurn(t *testing.T) {
|
||||
ds := &fakeDeepSeekCaller{responses: []*http.Response{sseHTTPResponse(
|
||||
http.StatusOK,
|
||||
`data: {"response_message_id":42,"p":"response/content","v":"<tool_calls><invoke name=\"Write\"><parameter name=\"content\">{\"x\":1}</parameter></invoke></tool_calls>"}`,
|
||||
)}}
|
||||
stdReq := promptcompat.StandardRequest{
|
||||
Surface: "test",
|
||||
ResponseModel: "deepseek-v4-flash",
|
||||
PromptTokenText: "prompt",
|
||||
FinalPrompt: "final prompt",
|
||||
ToolNames: []string{"Write"},
|
||||
ToolsRaw: []any{map[string]any{
|
||||
"name": "Write",
|
||||
"input_schema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, &auth.RequestAuth{}, stdReq, Options{})
|
||||
if outErr != nil {
|
||||
t.Fatalf("unexpected output error: %#v", outErr)
|
||||
}
|
||||
if result.SessionID != "session-1" {
|
||||
t.Fatalf("session mismatch: %q", result.SessionID)
|
||||
}
|
||||
if got := result.Turn.ResponseMessageID; got != 42 {
|
||||
t.Fatalf("response message id mismatch: %d", got)
|
||||
}
|
||||
if len(result.Turn.ToolCalls) != 1 {
|
||||
t.Fatalf("expected one tool call, got %d", len(result.Turn.ToolCalls))
|
||||
}
|
||||
if _, ok := result.Turn.ToolCalls[0].Input["content"].(string); !ok {
|
||||
t.Fatalf("expected schema-normalized string argument, got %#v", result.Turn.ToolCalls[0].Input["content"])
|
||||
}
|
||||
if result.Turn.Usage.InputTokens == 0 || result.Turn.Usage.TotalTokens == 0 {
|
||||
t.Fatalf("expected usage to be populated, got %#v", result.Turn.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteNonStreamWithRetryUsesParentMessageForEmptyRetry(t *testing.T) {
|
||||
ds := &fakeDeepSeekCaller{responses: []*http.Response{
|
||||
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":77,"p":"response/status","v":"FINISHED"}`),
|
||||
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":78,"p":"response/content","v":"ok"}`),
|
||||
}}
|
||||
stdReq := promptcompat.StandardRequest{
|
||||
Surface: "test",
|
||||
ResponseModel: "deepseek-v4-flash",
|
||||
PromptTokenText: "prompt",
|
||||
FinalPrompt: "final prompt",
|
||||
}
|
||||
|
||||
result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, &auth.RequestAuth{}, stdReq, Options{RetryEnabled: true})
|
||||
if outErr != nil {
|
||||
t.Fatalf("unexpected output error: %#v", outErr)
|
||||
}
|
||||
if result.Attempts != 1 {
|
||||
t.Fatalf("expected one retry, got %d", result.Attempts)
|
||||
}
|
||||
if len(ds.payloads) != 2 {
|
||||
t.Fatalf("expected two completion calls, got %d", len(ds.payloads))
|
||||
}
|
||||
if got := ds.payloads[1]["parent_message_id"]; got != 77 {
|
||||
t.Fatalf("retry parent_message_id mismatch: %#v", got)
|
||||
}
|
||||
if result.Turn.Text != "ok" {
|
||||
t.Fatalf("retry text mismatch: %q", result.Turn.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func sseHTTPResponse(status int, lines ...string) *http.Response {
|
||||
body := strings.Join(lines, "\n")
|
||||
if !strings.HasSuffix(body, "\n") {
|
||||
body += "\n"
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user