mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 00:45:29 +08:00
121 lines
3.8 KiB
Go
121 lines
3.8 KiB
Go
package completionruntime
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
|
|
"ds2api/internal/auth"
|
|
"ds2api/internal/promptcompat"
|
|
)
|
|
|
|
type fakeDeepSeekCaller struct {
|
|
responses []*http.Response
|
|
payloads []map[string]any
|
|
}
|
|
|
|
func (f *fakeDeepSeekCaller) CreateSession(context.Context, *auth.RequestAuth, int) (string, error) {
|
|
return "session-1", nil
|
|
}
|
|
|
|
func (f *fakeDeepSeekCaller) GetPow(context.Context, *auth.RequestAuth, int) (string, error) {
|
|
return "pow", nil
|
|
}
|
|
|
|
func (f *fakeDeepSeekCaller) CallCompletion(_ context.Context, _ *auth.RequestAuth, payload map[string]any, _ string, _ int) (*http.Response, error) {
|
|
f.payloads = append(f.payloads, payload)
|
|
if len(f.responses) == 0 {
|
|
return sseHTTPResponse(http.StatusOK, `data: {"p":"response/content","v":"fallback"}`), nil
|
|
}
|
|
resp := f.responses[0]
|
|
f.responses = f.responses[1:]
|
|
return resp, nil
|
|
}
|
|
|
|
func TestExecuteNonStreamWithRetryBuildsCanonicalTurn(t *testing.T) {
|
|
ds := &fakeDeepSeekCaller{responses: []*http.Response{sseHTTPResponse(
|
|
http.StatusOK,
|
|
`data: {"response_message_id":42,"p":"response/content","v":"<tool_calls><invoke name=\"Write\"><parameter name=\"content\">{\"x\":1}</parameter></invoke></tool_calls>"}`,
|
|
)}}
|
|
stdReq := promptcompat.StandardRequest{
|
|
Surface: "test",
|
|
ResponseModel: "deepseek-v4-flash",
|
|
PromptTokenText: "prompt",
|
|
FinalPrompt: "final prompt",
|
|
ToolNames: []string{"Write"},
|
|
ToolsRaw: []any{map[string]any{
|
|
"name": "Write",
|
|
"input_schema": map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{
|
|
"content": map[string]any{"type": "string"},
|
|
},
|
|
},
|
|
}},
|
|
}
|
|
|
|
result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, &auth.RequestAuth{}, stdReq, Options{})
|
|
if outErr != nil {
|
|
t.Fatalf("unexpected output error: %#v", outErr)
|
|
}
|
|
if result.SessionID != "session-1" {
|
|
t.Fatalf("session mismatch: %q", result.SessionID)
|
|
}
|
|
if got := result.Turn.ResponseMessageID; got != 42 {
|
|
t.Fatalf("response message id mismatch: %d", got)
|
|
}
|
|
if len(result.Turn.ToolCalls) != 1 {
|
|
t.Fatalf("expected one tool call, got %d", len(result.Turn.ToolCalls))
|
|
}
|
|
if _, ok := result.Turn.ToolCalls[0].Input["content"].(string); !ok {
|
|
t.Fatalf("expected schema-normalized string argument, got %#v", result.Turn.ToolCalls[0].Input["content"])
|
|
}
|
|
if result.Turn.Usage.InputTokens == 0 || result.Turn.Usage.TotalTokens == 0 {
|
|
t.Fatalf("expected usage to be populated, got %#v", result.Turn.Usage)
|
|
}
|
|
}
|
|
|
|
func TestExecuteNonStreamWithRetryUsesParentMessageForEmptyRetry(t *testing.T) {
|
|
ds := &fakeDeepSeekCaller{responses: []*http.Response{
|
|
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":77,"p":"response/status","v":"FINISHED"}`),
|
|
sseHTTPResponse(http.StatusOK, `data: {"response_message_id":78,"p":"response/content","v":"ok"}`),
|
|
}}
|
|
stdReq := promptcompat.StandardRequest{
|
|
Surface: "test",
|
|
ResponseModel: "deepseek-v4-flash",
|
|
PromptTokenText: "prompt",
|
|
FinalPrompt: "final prompt",
|
|
}
|
|
|
|
result, outErr := ExecuteNonStreamWithRetry(context.Background(), ds, &auth.RequestAuth{}, stdReq, Options{RetryEnabled: true})
|
|
if outErr != nil {
|
|
t.Fatalf("unexpected output error: %#v", outErr)
|
|
}
|
|
if result.Attempts != 1 {
|
|
t.Fatalf("expected one retry, got %d", result.Attempts)
|
|
}
|
|
if len(ds.payloads) != 2 {
|
|
t.Fatalf("expected two completion calls, got %d", len(ds.payloads))
|
|
}
|
|
if got := ds.payloads[1]["parent_message_id"]; got != 77 {
|
|
t.Fatalf("retry parent_message_id mismatch: %#v", got)
|
|
}
|
|
if result.Turn.Text != "ok" {
|
|
t.Fatalf("retry text mismatch: %q", result.Turn.Text)
|
|
}
|
|
}
|
|
|
|
func sseHTTPResponse(status int, lines ...string) *http.Response {
|
|
body := strings.Join(lines, "\n")
|
|
if !strings.HasSuffix(body, "\n") {
|
|
body += "\n"
|
|
}
|
|
return &http.Response{
|
|
StatusCode: status,
|
|
Header: make(http.Header),
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
}
|
|
}
|