mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-10 03:07:41 +08:00
271 lines
9.8 KiB
Go
271 lines
9.8 KiB
Go
package completionruntime
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
|
|
"ds2api/internal/account"
|
|
"ds2api/internal/auth"
|
|
"ds2api/internal/config"
|
|
dsclient "ds2api/internal/deepseek/client"
|
|
"ds2api/internal/promptcompat"
|
|
)
|
|
|
|
type fakeDeepSeekCaller struct {
|
|
responses []*http.Response
|
|
payloads []map[string]any
|
|
uploads []dsclient.UploadFileRequest
|
|
completionAccounts []string
|
|
sessionByAccount bool
|
|
}
|
|
|
|
// currentInputRuntimeConfig is a stateless stub whose methods return fixed
// values; tests pass it as Options.CurrentInputFile.
type currentInputRuntimeConfig struct{}

// CurrentInputFileEnabled always reports true.
func (currentInputRuntimeConfig) CurrentInputFileEnabled() bool {
	return true
}

// CurrentInputFileMinChars always returns 0.
func (currentInputRuntimeConfig) CurrentInputFileMinChars() int {
	return 0
}
|
|
|
|
func (f *fakeDeepSeekCaller) CreateSession(_ context.Context, a *auth.RequestAuth, _ int) (string, error) {
|
|
if f.sessionByAccount && a != nil && a.AccountID != "" {
|
|
return "session-" + a.AccountID, nil
|
|
}
|
|
return "session-1", nil
|
|
}
|
|
|
|
func (f *fakeDeepSeekCaller) GetPow(context.Context, *auth.RequestAuth, int) (string, error) {
|
|
return "pow", nil
|
|
}
|
|
|
|
func (f *fakeDeepSeekCaller) UploadFile(_ context.Context, _ *auth.RequestAuth, req dsclient.UploadFileRequest, _ int) (*dsclient.UploadFileResult, error) {
|
|
f.uploads = append(f.uploads, req)
|
|
return &dsclient.UploadFileResult{ID: "file-runtime-1"}, nil
|
|
}
|
|
|
|
func (f *fakeDeepSeekCaller) CallCompletion(_ context.Context, a *auth.RequestAuth, payload map[string]any, _ string, _ int) (*http.Response, error) {
|
|
f.payloads = append(f.payloads, payload)
|
|
if a != nil {
|
|
f.completionAccounts = append(f.completionAccounts, a.AccountID)
|
|
}
|
|
if len(f.responses) == 0 {
|
|
return sseHTTPResponse(http.StatusOK, `data: {"p":"response/content","v":"fallback"}`), nil
|
|
}
|
|
resp := f.responses[0]
|
|
f.responses = f.responses[1:]
|
|
return resp, nil
|
|
}
|
|
|
|
func TestExecuteNonStreamWithRetryBuildsCanonicalTurn(t *testing.T) {
|
|
ds := &fakeDeepSeekCaller{responses: []*http.Response{sseHTTPResponse(
|
|
http.StatusOK,
|
|
`data: {"response_message_id":42,"p":"response/content","v":"<tool_calls><invoke name=\"Write\"><parameter name=\"content\">{\"x\":1}</parameter></invoke></tool_calls>"}`,
|
|
)}}
|
|
stdReq := promptcompat.StandardRequest{
|
|
Surface: "test",
|
|
ResponseModel: "deepseek-v4-flash",
|
|
PromptTokenText: "prompt",
|
|
FinalPrompt: "final prompt",
|
|
ToolNames: []string{"Write"},
|
|
ToolsRaw: []any{map[string]any{
|
|
"name": "Write",
|
|
"input_schema": map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{
|
|
"content": map[string]any{"type": "string"},
|
|
},
|
|
},
|
|
}},
|
|
}
|
|
|
|
result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, &auth.RequestAuth{}, stdReq, Options{})
|
|
if outErr != nil {
|
|
t.Fatalf("unexpected output error: %#v", outErr)
|
|
}
|
|
if result.SessionID != "session-1" {
|
|
t.Fatalf("session mismatch: %q", result.SessionID)
|
|
}
|
|
if got := result.Turn.ResponseMessageID; got != 42 {
|
|
t.Fatalf("response message id mismatch: %d", got)
|
|
}
|
|
if len(result.Turn.ToolCalls) != 1 {
|
|
t.Fatalf("expected one tool call, got %d", len(result.Turn.ToolCalls))
|
|
}
|
|
if _, ok := result.Turn.ToolCalls[0].Input["content"].(string); !ok {
|
|
t.Fatalf("expected schema-normalized string argument, got %#v", result.Turn.ToolCalls[0].Input["content"])
|
|
}
|
|
if result.Turn.Usage.InputTokens == 0 || result.Turn.Usage.TotalTokens == 0 {
|
|
t.Fatalf("expected usage to be populated, got %#v", result.Turn.Usage)
|
|
}
|
|
}
|
|
|
|
func TestExecuteNonStreamWithRetrySwitchesManagedAccountBeforeFinal429(t *testing.T) {
|
|
t.Setenv("DS2API_CONFIG_JSON", `{
|
|
"keys":["managed-key"],
|
|
"accounts":[
|
|
{"email":"acc1@test.com","password":"pwd"},
|
|
{"email":"acc2@test.com","password":"pwd"}
|
|
]
|
|
}`)
|
|
store := config.LoadStore()
|
|
resolver := auth.NewResolver(store, account.NewPool(store), func(_ context.Context, acc config.Account) (string, error) {
|
|
return "token-" + acc.Identifier(), nil
|
|
})
|
|
req, _ := http.NewRequest(http.MethodPost, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer managed-key")
|
|
a, err := resolver.Determine(req)
|
|
if err != nil {
|
|
t.Fatalf("determine failed: %v", err)
|
|
}
|
|
defer resolver.Release(a)
|
|
|
|
ds := &fakeDeepSeekCaller{
|
|
sessionByAccount: true,
|
|
responses: []*http.Response{
|
|
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":11,"p":"response/thinking_content","v":"first empty"}`),
|
|
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":12,"p":"response/thinking_content","v":"retry empty"}`),
|
|
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":21,"p":"response/content","v":"ok from second account"}`),
|
|
},
|
|
}
|
|
stdReq := promptcompat.StandardRequest{
|
|
Surface: "test",
|
|
ResponseModel: "deepseek-v4-flash",
|
|
PromptTokenText: "prompt",
|
|
FinalPrompt: "final prompt",
|
|
Thinking: true,
|
|
}
|
|
|
|
result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, a, stdReq, Options{RetryEnabled: true})
|
|
if outErr != nil {
|
|
t.Fatalf("unexpected output error after account switch retry: %#v", outErr)
|
|
}
|
|
if result.Turn.Text != "ok from second account" {
|
|
t.Fatalf("text mismatch after switch retry: %q", result.Turn.Text)
|
|
}
|
|
if result.SessionID != "session-acc2@test.com" {
|
|
t.Fatalf("expected switched account session, got %q", result.SessionID)
|
|
}
|
|
wantAccounts := []string{"acc1@test.com", "acc1@test.com", "acc2@test.com"}
|
|
if len(ds.completionAccounts) != len(wantAccounts) {
|
|
t.Fatalf("completion account count mismatch: got %v want %v", ds.completionAccounts, wantAccounts)
|
|
}
|
|
for i, want := range wantAccounts {
|
|
if ds.completionAccounts[i] != want {
|
|
t.Fatalf("completion account %d = %q want %q (all=%v)", i, ds.completionAccounts[i], want, ds.completionAccounts)
|
|
}
|
|
}
|
|
if got := ds.payloads[2]["chat_session_id"]; got != "session-acc2@test.com" {
|
|
t.Fatalf("switched payload session mismatch: %#v", got)
|
|
}
|
|
if prompt, _ := ds.payloads[2]["prompt"].(string); strings.Contains(prompt, "Previous reply had no visible output") {
|
|
t.Fatalf("expected fresh switched-account prompt without empty-output suffix, got %q", prompt)
|
|
}
|
|
}
|
|
|
|
func TestExecuteNonStreamWithRetryUsesParentMessageForEmptyRetry(t *testing.T) {
|
|
ds := &fakeDeepSeekCaller{responses: []*http.Response{
|
|
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":77,"p":"response/thinking_content","v":"plan"}`),
|
|
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":78,"p":"response/content","v":"ok"}`),
|
|
}}
|
|
stdReq := promptcompat.StandardRequest{
|
|
Surface: "test",
|
|
ResponseModel: "deepseek-v4-flash",
|
|
PromptTokenText: "prompt",
|
|
FinalPrompt: "final prompt",
|
|
}
|
|
|
|
result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, &auth.RequestAuth{}, stdReq, Options{RetryEnabled: true})
|
|
if outErr != nil {
|
|
t.Fatalf("unexpected output error: %#v", outErr)
|
|
}
|
|
if result.Attempts != 1 {
|
|
t.Fatalf("expected one retry, got %d", result.Attempts)
|
|
}
|
|
if len(ds.payloads) != 2 {
|
|
t.Fatalf("expected two completion calls, got %d", len(ds.payloads))
|
|
}
|
|
if got := ds.payloads[1]["parent_message_id"]; got != 77 {
|
|
t.Fatalf("retry parent_message_id mismatch: %#v", got)
|
|
}
|
|
if result.Turn.Text != "ok" {
|
|
t.Fatalf("retry text mismatch: %q", result.Turn.Text)
|
|
}
|
|
}
|
|
|
|
func TestExecuteNonStreamWithRetryConvertsReferenceMarkers(t *testing.T) {
|
|
ds := &fakeDeepSeekCaller{responses: []*http.Response{sseHTTPResponse(
|
|
http.StatusOK,
|
|
`data: {"p":"response/content","v":"答案[reference:0]。","citation":{"cite_index":0,"url":"https://example.com/ref"}}`,
|
|
)}}
|
|
stdReq := promptcompat.StandardRequest{
|
|
Surface: "test",
|
|
ResponseModel: "deepseek-v4-flash-search",
|
|
PromptTokenText: "prompt",
|
|
FinalPrompt: "final prompt",
|
|
Search: true,
|
|
}
|
|
|
|
result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, &auth.RequestAuth{}, stdReq, Options{})
|
|
if outErr != nil {
|
|
t.Fatalf("unexpected output error: %#v", outErr)
|
|
}
|
|
want := "答案[0](https://example.com/ref)。"
|
|
if result.Turn.Text != want {
|
|
t.Fatalf("text mismatch: got %q want %q", result.Turn.Text, want)
|
|
}
|
|
}
|
|
|
|
func TestStartCompletionAppliesCurrentInputFileGlobally(t *testing.T) {
|
|
ds := &fakeDeepSeekCaller{responses: []*http.Response{sseHTTPResponse(http.StatusOK, `data: {"p":"response/content","v":"ok"}`)}}
|
|
stdReq := promptcompat.StandardRequest{
|
|
Surface: "test_adapter",
|
|
RequestedModel: "deepseek-v4-flash",
|
|
ResolvedModel: "deepseek-v4-flash",
|
|
ResponseModel: "deepseek-v4-flash",
|
|
PromptTokenText: "first user turn",
|
|
FinalPrompt: "first user turn",
|
|
Messages: []any{
|
|
map[string]any{"role": "user", "content": "first user turn"},
|
|
},
|
|
}
|
|
|
|
start, outErr := StartCompletion(context.Background(), ds, &auth.RequestAuth{DeepSeekToken: "token"}, stdReq, Options{
|
|
CurrentInputFile: currentInputRuntimeConfig{},
|
|
})
|
|
if outErr != nil {
|
|
t.Fatalf("unexpected output error: %#v", outErr)
|
|
}
|
|
if len(ds.uploads) != 1 {
|
|
t.Fatalf("expected current input upload, got %d", len(ds.uploads))
|
|
}
|
|
if got := ds.uploads[0].Filename; got != "DS2API_HISTORY.txt" {
|
|
t.Fatalf("upload filename=%q want DS2API_HISTORY.txt", got)
|
|
}
|
|
if len(ds.payloads) != 1 {
|
|
t.Fatalf("expected one completion payload, got %d", len(ds.payloads))
|
|
}
|
|
refIDs, _ := ds.payloads[0]["ref_file_ids"].([]any)
|
|
if len(refIDs) != 1 || refIDs[0] != "file-runtime-1" {
|
|
t.Fatalf("expected uploaded file id in ref_file_ids, got %#v", ds.payloads[0]["ref_file_ids"])
|
|
}
|
|
prompt, _ := ds.payloads[0]["prompt"].(string)
|
|
if !strings.Contains(prompt, "Continue from the latest state in the attached DS2API_HISTORY.txt context.") {
|
|
t.Fatalf("expected continuation prompt, got %q", prompt)
|
|
}
|
|
if !start.Request.CurrentInputFileApplied || !strings.Contains(start.Request.PromptTokenText, "# DS2API_HISTORY.txt") {
|
|
t.Fatalf("expected prepared request to carry current input file state, got %#v", start.Request)
|
|
}
|
|
}
|
|
|
|
// sseHTTPResponse builds an in-memory *http.Response with the given status
// code, an empty non-nil header, and a body made of lines joined by "\n". A
// trailing newline is appended when the joined body does not already end in
// one, so the body is never empty.
func sseHTTPResponse(status int, lines ...string) *http.Response {
	var sb strings.Builder
	for i, line := range lines {
		if i > 0 {
			sb.WriteString("\n")
		}
		sb.WriteString(line)
	}
	payload := sb.String()
	if !strings.HasSuffix(payload, "\n") {
		payload += "\n"
	}

	resp := &http.Response{
		StatusCode: status,
		Header:     http.Header{},
	}
	resp.Body = io.NopCloser(strings.NewReader(payload))
	return resp
}
|