Files
ds2api/internal/completionruntime/nonstream_test.go

271 lines
9.8 KiB
Go

package completionruntime
import (
"context"
"io"
"net/http"
"strings"
"testing"
"ds2api/internal/account"
"ds2api/internal/auth"
"ds2api/internal/config"
dsclient "ds2api/internal/deepseek/client"
"ds2api/internal/promptcompat"
)
// fakeDeepSeekCaller is a scripted test double for the upstream calls made by
// the completion runtime. It serves canned completion responses and records
// the arguments it receives so tests can assert on them afterwards.
type fakeDeepSeekCaller struct {
	// responses is the queue of canned completion responses; CallCompletion
	// hands them out front-first, one per call.
	responses []*http.Response

	// payloads collects every completion payload in call order.
	payloads []map[string]any

	// uploads collects every request passed to UploadFile.
	uploads []dsclient.UploadFileRequest

	// completionAccounts collects the AccountID of each non-nil auth passed
	// to CallCompletion, in call order.
	completionAccounts []string

	// sessionByAccount makes CreateSession derive the session ID from the
	// caller's account ID instead of returning a fixed one.
	sessionByAccount bool
}
// currentInputRuntimeConfig is a stub configuration that always turns the
// current-input-file feature on with a zero minimum character threshold.
type currentInputRuntimeConfig struct{}

// CurrentInputFileEnabled always reports true.
func (currentInputRuntimeConfig) CurrentInputFileEnabled() bool {
	return true
}

// CurrentInputFileMinChars always returns 0.
func (currentInputRuntimeConfig) CurrentInputFileMinChars() int {
	return 0
}
// CreateSession returns a fake session ID and never fails. The ID is
// "session-1" unless sessionByAccount is set and the auth carries a non-empty
// AccountID, in which case it is "session-" followed by that account ID.
func (f *fakeDeepSeekCaller) CreateSession(_ context.Context, a *auth.RequestAuth, _ int) (string, error) {
	// Fall back to the fixed ID whenever per-account sessions are off or
	// there is no account to key on.
	if !f.sessionByAccount || a == nil || a.AccountID == "" {
		return "session-1", nil
	}
	return "session-" + a.AccountID, nil
}
// GetPow ignores all of its arguments and returns the fixed string "pow" with
// a nil error.
func (f *fakeDeepSeekCaller) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	const fixedPow = "pow"
	return fixedPow, nil
}
// UploadFile records the upload request on the fake and returns a result with
// the fixed file ID "file-runtime-1". It never fails.
func (f *fakeDeepSeekCaller) UploadFile(_ context.Context, _ *auth.RequestAuth, req dsclient.UploadFileRequest, _ int) (*dsclient.UploadFileResult, error) {
	f.uploads = append(f.uploads, req)

	uploaded := &dsclient.UploadFileResult{ID: "file-runtime-1"}
	return uploaded, nil
}
// CallCompletion records the payload (and the account ID, when an auth is
// supplied) and then returns the next scripted response. Once the script is
// exhausted it returns a 200 response whose content is "fallback". It never
// returns an error.
func (f *fakeDeepSeekCaller) CallCompletion(_ context.Context, a *auth.RequestAuth, payload map[string]any, _ string, _ int) (*http.Response, error) {
	f.payloads = append(f.payloads, payload)
	if a != nil {
		f.completionAccounts = append(f.completionAccounts, a.AccountID)
	}

	if len(f.responses) > 0 {
		// Pop the head of the queue.
		next, remaining := f.responses[0], f.responses[1:]
		f.responses = remaining
		return next, nil
	}

	return sseHTTPResponse(http.StatusOK, `data: {"p":"response/content","v":"fallback"}`), nil
}
// TestExecuteNonStreamWithRetryBuildsCanonicalTurn feeds a single scripted
// tool-call reply through ExecuteNonStreamWithRetry and checks the resulting
// turn: session ID, response message ID, exactly one tool call whose "content"
// argument is a string, and non-zero usage counters.
func TestExecuteNonStreamWithRetryBuildsCanonicalTurn(t *testing.T) {
	// The "content" parameter arrives as a JSON object literal, while the
	// tool schema below declares it as a string.
	const toolCallEvent = `data: {"response_message_id":42,"p":"response/content","v":"<tool_calls><invoke name=\"Write\"><parameter name=\"content\">{\"x\":1}</parameter></invoke></tool_calls>"}`
	caller := &fakeDeepSeekCaller{
		responses: []*http.Response{sseHTTPResponse(http.StatusOK, toolCallEvent)},
	}

	writeTool := map[string]any{
		"name": "Write",
		"input_schema": map[string]any{
			"type": "object",
			"properties": map[string]any{
				"content": map[string]any{"type": "string"},
			},
		},
	}
	request := promptcompat.StandardRequest{
		Surface:         "test",
		ResponseModel:   "deepseek-v4-flash",
		PromptTokenText: "prompt",
		FinalPrompt:     "final prompt",
		ToolNames:       []string{"Write"},
		ToolsRaw:        []any{writeTool},
	}

	res, outErr := ExecuteNonStreamWithRetry(context.Background(), caller, &auth.RequestAuth{}, request, Options{})
	if outErr != nil {
		t.Fatalf("unexpected output error: %#v", outErr)
	}
	if res.SessionID != "session-1" {
		t.Fatalf("session mismatch: %q", res.SessionID)
	}
	if msgID := res.Turn.ResponseMessageID; msgID != 42 {
		t.Fatalf("response message id mismatch: %d", msgID)
	}
	if n := len(res.Turn.ToolCalls); n != 1 {
		t.Fatalf("expected one tool call, got %d", n)
	}
	if _, isString := res.Turn.ToolCalls[0].Input["content"].(string); !isString {
		t.Fatalf("expected schema-normalized string argument, got %#v", res.Turn.ToolCalls[0].Input["content"])
	}
	if res.Turn.Usage.InputTokens == 0 || res.Turn.Usage.TotalTokens == 0 {
		t.Fatalf("expected usage to be populated, got %#v", res.Turn.Usage)
	}
}
// TestExecuteNonStreamWithRetrySwitchesManagedAccountBeforeFinal429 scripts two
// replies that carry only thinking content followed by one with visible
// content. It expects the first two completion calls to run on acc1 and the
// third on acc2, with a session created for acc2 and a prompt that does not
// carry the empty-output retry suffix.
func TestExecuteNonStreamWithRetrySwitchesManagedAccountBeforeFinal429(t *testing.T) {
	// Two managed accounts behind one API key. config.LoadStore presumably
	// reads this environment variable — confirm against the config package.
	t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[
{"email":"acc1@test.com","password":"pwd"},
{"email":"acc2@test.com","password":"pwd"}
]
}`)
	store := config.LoadStore()
	// The callback handed to the resolver fabricates a token from the
	// account identifier, so no real login is attempted.
	resolver := auth.NewResolver(store, account.NewPool(store), func(_ context.Context, acc config.Account) (string, error) {
		return "token-" + acc.Identifier(), nil
	})

	// Check the construction error instead of discarding it, so a broken
	// request surfaces here rather than as a nil dereference below.
	req, err := http.NewRequest(http.MethodPost, "/", nil)
	if err != nil {
		t.Fatalf("build request failed: %v", err)
	}
	req.Header.Set("Authorization", "Bearer managed-key")
	a, err := resolver.Determine(req)
	if err != nil {
		t.Fatalf("determine failed: %v", err)
	}
	defer resolver.Release(a)

	ds := &fakeDeepSeekCaller{
		sessionByAccount: true,
		responses: []*http.Response{
			sseHTTPResponse(http.StatusOK, `data: {"response_message_id":11,"p":"response/thinking_content","v":"first empty"}`),
			sseHTTPResponse(http.StatusOK, `data: {"response_message_id":12,"p":"response/thinking_content","v":"retry empty"}`),
			sseHTTPResponse(http.StatusOK, `data: {"response_message_id":21,"p":"response/content","v":"ok from second account"}`),
		},
	}
	stdReq := promptcompat.StandardRequest{
		Surface:         "test",
		ResponseModel:   "deepseek-v4-flash",
		PromptTokenText: "prompt",
		FinalPrompt:     "final prompt",
		Thinking:        true,
	}

	result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, a, stdReq, Options{RetryEnabled: true})
	if outErr != nil {
		t.Fatalf("unexpected output error after account switch retry: %#v", outErr)
	}
	if result.Turn.Text != "ok from second account" {
		t.Fatalf("text mismatch after switch retry: %q", result.Turn.Text)
	}
	if result.SessionID != "session-acc2@test.com" {
		t.Fatalf("expected switched account session, got %q", result.SessionID)
	}

	wantAccounts := []string{"acc1@test.com", "acc1@test.com", "acc2@test.com"}
	if len(ds.completionAccounts) != len(wantAccounts) {
		t.Fatalf("completion account count mismatch: got %v want %v", ds.completionAccounts, wantAccounts)
	}
	for i, want := range wantAccounts {
		if ds.completionAccounts[i] != want {
			t.Fatalf("completion account %d = %q want %q (all=%v)", i, ds.completionAccounts[i], want, ds.completionAccounts)
		}
	}

	// The fake appends a payload on every call, so index 2 exists once the
	// three-account check above has passed.
	if got := ds.payloads[2]["chat_session_id"]; got != "session-acc2@test.com" {
		t.Fatalf("switched payload session mismatch: %#v", got)
	}
	if prompt, _ := ds.payloads[2]["prompt"].(string); strings.Contains(prompt, "Previous reply had no visible output") {
		t.Fatalf("expected fresh switched-account prompt without empty-output suffix, got %q", prompt)
	}
}
// TestExecuteNonStreamWithRetryUsesParentMessageForEmptyRetry scripts a reply
// with only thinking content (message id 77) followed by one with visible
// content, and checks that the second completion payload carries
// parent_message_id 77 and that the final text comes from the second reply.
func TestExecuteNonStreamWithRetryUsesParentMessageForEmptyRetry(t *testing.T) {
	const (
		thinkingOnlyEvent = `data: {"response_message_id":77,"p":"response/thinking_content","v":"plan"}`
		visibleEvent      = `data: {"response_message_id":78,"p":"response/content","v":"ok"}`
	)
	caller := &fakeDeepSeekCaller{
		responses: []*http.Response{
			sseHTTPResponse(http.StatusOK, thinkingOnlyEvent),
			sseHTTPResponse(http.StatusOK, visibleEvent),
		},
	}
	request := promptcompat.StandardRequest{
		Surface:         "test",
		ResponseModel:   "deepseek-v4-flash",
		PromptTokenText: "prompt",
		FinalPrompt:     "final prompt",
	}

	res, outErr := ExecuteNonStreamWithRetry(context.Background(), caller, &auth.RequestAuth{}, request, Options{RetryEnabled: true})
	if outErr != nil {
		t.Fatalf("unexpected output error: %#v", outErr)
	}
	if res.Attempts != 1 {
		t.Fatalf("expected one retry, got %d", res.Attempts)
	}
	if calls := len(caller.payloads); calls != 2 {
		t.Fatalf("expected two completion calls, got %d", calls)
	}
	// The retry payload is the second recorded one.
	if parent := caller.payloads[1]["parent_message_id"]; parent != 77 {
		t.Fatalf("retry parent_message_id mismatch: %#v", parent)
	}
	if res.Turn.Text != "ok" {
		t.Fatalf("retry text mismatch: %q", res.Turn.Text)
	}
}
// TestExecuteNonStreamWithRetryConvertsReferenceMarkers sends a search-enabled
// request whose scripted reply contains a "[reference:0]" marker plus a
// matching citation, and expects the final text to contain a markdown link to
// the citation URL in its place.
func TestExecuteNonStreamWithRetryConvertsReferenceMarkers(t *testing.T) {
	const citedEvent = `data: {"p":"response/content","v":"答案[reference:0]。","citation":{"cite_index":0,"url":"https://example.com/ref"}}`
	caller := &fakeDeepSeekCaller{
		responses: []*http.Response{sseHTTPResponse(http.StatusOK, citedEvent)},
	}
	request := promptcompat.StandardRequest{
		Surface:         "test",
		ResponseModel:   "deepseek-v4-flash-search",
		PromptTokenText: "prompt",
		FinalPrompt:     "final prompt",
		Search:          true,
	}

	res, outErr := ExecuteNonStreamWithRetry(context.Background(), caller, &auth.RequestAuth{}, request, Options{})
	if outErr != nil {
		t.Fatalf("unexpected output error: %#v", outErr)
	}

	const want = "答案[0](https://example.com/ref)。"
	if got := res.Turn.Text; got != want {
		t.Fatalf("text mismatch: got %q want %q", got, want)
	}
}
// TestStartCompletionAppliesCurrentInputFileGlobally runs StartCompletion with
// a configuration that always enables the current-input file. It expects one
// upload named DS2API_HISTORY.txt, one completion payload that references the
// uploaded file ID and uses the continuation prompt, and a prepared request
// that records the file as applied.
func TestStartCompletionAppliesCurrentInputFileGlobally(t *testing.T) {
	caller := &fakeDeepSeekCaller{
		responses: []*http.Response{
			sseHTTPResponse(http.StatusOK, `data: {"p":"response/content","v":"ok"}`),
		},
	}
	request := promptcompat.StandardRequest{
		Surface:         "test_adapter",
		RequestedModel:  "deepseek-v4-flash",
		ResolvedModel:   "deepseek-v4-flash",
		ResponseModel:   "deepseek-v4-flash",
		PromptTokenText: "first user turn",
		FinalPrompt:     "first user turn",
		Messages: []any{
			map[string]any{"role": "user", "content": "first user turn"},
		},
	}
	opts := Options{CurrentInputFile: currentInputRuntimeConfig{}}

	start, outErr := StartCompletion(context.Background(), caller, &auth.RequestAuth{DeepSeekToken: "token"}, request, opts)
	if outErr != nil {
		t.Fatalf("unexpected output error: %#v", outErr)
	}

	// Upload side: exactly one file with the expected name.
	if n := len(caller.uploads); n != 1 {
		t.Fatalf("expected current input upload, got %d", n)
	}
	if name := caller.uploads[0].Filename; name != "DS2API_HISTORY.txt" {
		t.Fatalf("upload filename=%q want DS2API_HISTORY.txt", name)
	}

	// Completion side: one payload referencing the uploaded file.
	if n := len(caller.payloads); n != 1 {
		t.Fatalf("expected one completion payload, got %d", n)
	}
	sent := caller.payloads[0]
	refIDs, _ := sent["ref_file_ids"].([]any)
	if !(len(refIDs) == 1 && refIDs[0] == "file-runtime-1") {
		t.Fatalf("expected uploaded file id in ref_file_ids, got %#v", sent["ref_file_ids"])
	}
	prompt, _ := sent["prompt"].(string)
	if !strings.Contains(prompt, "Continue from the latest state in the attached DS2API_HISTORY.txt context.") {
		t.Fatalf("expected continuation prompt, got %q", prompt)
	}

	// Prepared request: flagged as applied, with the history header in the token text.
	if !(start.Request.CurrentInputFileApplied && strings.Contains(start.Request.PromptTokenText, "# DS2API_HISTORY.txt")) {
		t.Fatalf("expected prepared request to carry current input file state, got %#v", start.Request)
	}
}
// sseHTTPResponse builds an in-memory *http.Response with the given status
// code, an empty non-nil header map, and a body made of lines joined by "\n".
// The body always ends with a newline; one is appended when the joined text
// does not already end with one.
func sseHTTPResponse(status int, lines ...string) *http.Response {
	// Strip at most one trailing newline and put exactly one back, which
	// leaves already-terminated text unchanged and terminates the rest.
	payload := strings.TrimSuffix(strings.Join(lines, "\n"), "\n") + "\n"

	resp := &http.Response{
		StatusCode: status,
		Header:     http.Header{},
	}
	resp.Body = io.NopCloser(strings.NewReader(payload))
	return resp
}