test: Introduce comprehensive edge case tests for various internal packages including SSE, Claude, Auth, Account, Config, Deepseek, Admin, and Util.

This commit is contained in:
CJACK
2026-02-18 16:52:16 +08:00
parent deec72416e
commit f2b10992cc
14 changed files with 3291 additions and 7 deletions

3
.gitignore vendored
View File

@@ -81,6 +81,9 @@ ds2api-tests
htmlcov/
.pytest_cache/
.tox/
*.coverprofile
coverage*.out
cover/
# Misc
*.pyc

View File

@@ -0,0 +1,249 @@
package account
import (
"context"
"sync"
"testing"
"time"
"ds2api/internal/config"
)
// ─── Pool edge cases ─────────────────────────────────────────────────
func TestPoolEmptyNoAccounts(t *testing.T) {
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "2")
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
pool := NewPool(config.LoadStore())
if _, ok := pool.Acquire("", nil); ok {
t.Fatal("expected acquire to fail with no accounts")
}
status := pool.Status()
if total, ok := status["total"].(int); !ok || total != 0 {
t.Fatalf("unexpected total: %#v", status["total"])
}
}
func TestPoolReleaseNonExistentAccount(t *testing.T) {
pool := newPoolForTest(t, "2")
pool.Release("nonexistent@example.com") // should not panic
}
func TestPoolReleaseAlreadyReleased(t *testing.T) {
pool := newPoolForTest(t, "2")
acc, ok := pool.Acquire("", nil)
if !ok {
t.Fatal("expected acquire success")
}
pool.Release(acc.Identifier())
pool.Release(acc.Identifier()) // double release should not panic
}
func TestPoolAcquireTargetNotFound(t *testing.T) {
pool := newPoolForTest(t, "2")
if _, ok := pool.Acquire("nonexistent@example.com", nil); ok {
t.Fatal("expected acquire to fail for non-existent target")
}
}
func TestPoolAcquireWithExclusionList(t *testing.T) {
pool := newPoolForTest(t, "2")
acc, ok := pool.Acquire("", map[string]bool{"acc1@example.com": true})
if !ok {
t.Fatal("expected acquire success with exclusion")
}
if acc.Identifier() != "acc2@example.com" {
t.Fatalf("expected acc2 when acc1 excluded, got %q", acc.Identifier())
}
pool.Release(acc.Identifier())
}
func TestPoolAcquireAllExcluded(t *testing.T) {
pool := newPoolForTest(t, "2")
if _, ok := pool.Acquire("", map[string]bool{
"acc1@example.com": true,
"acc2@example.com": true,
}); ok {
t.Fatal("expected acquire to fail when all accounts excluded")
}
}
func TestPoolStatusFields(t *testing.T) {
pool := newPoolForTest(t, "2")
status := pool.Status()
// Check all expected fields are present
for _, key := range []string{"total", "available", "max_inflight_per_account", "recommended_concurrency", "available_accounts", "in_use_accounts", "waiting", "max_queue_size"} {
if _, ok := status[key]; !ok {
t.Fatalf("missing status field: %s", key)
}
}
}
func TestPoolStatusAccountDetails(t *testing.T) {
pool := newPoolForTest(t, "2")
acc, _ := pool.Acquire("acc1@example.com", nil)
status := pool.Status()
inUseAccounts, ok := status["in_use_accounts"].([]string)
if !ok {
t.Fatalf("unexpected in_use_accounts type: %T", status["in_use_accounts"])
}
found := false
for _, id := range inUseAccounts {
if id == "acc1@example.com" {
found = true
break
}
}
if !found {
t.Fatalf("expected acc1 in in_use_accounts, got %v", inUseAccounts)
}
if status["in_use"] != 1 {
t.Fatalf("expected 1 in_use, got %v", status["in_use"])
}
pool.Release(acc.Identifier())
}
func TestPoolAcquireWaitContextCancelled(t *testing.T) {
pool := newSingleAccountPoolForTest(t, "1")
// Exhaust the pool
first, ok := pool.Acquire("", nil)
if !ok {
t.Fatal("expected first acquire to succeed")
}
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
var waitOK bool
go func() {
defer wg.Done()
_, waitOK = pool.AcquireWait(ctx, "", nil)
}()
// Wait until queued
waitForWaitingCount(t, pool, 1)
// Cancel context
cancel()
wg.Wait()
if waitOK {
t.Fatal("expected acquire to fail after context cancellation")
}
pool.Release(first.Identifier())
}
func TestPoolAcquireWaitTargetAccount(t *testing.T) {
pool := newPoolForTest(t, "1")
// Exhaust acc1
acc1, ok := pool.Acquire("acc1@example.com", nil)
if !ok {
t.Fatal("expected acquire acc1 success")
}
// Acquire acc2 directly (should succeed since acc2 is free)
ctx := context.Background()
acc2, ok := pool.AcquireWait(ctx, "acc2@example.com", nil)
if !ok {
t.Fatal("expected acquire acc2 success via AcquireWait")
}
if acc2.Identifier() != "acc2@example.com" {
t.Fatalf("expected acc2, got %q", acc2.Identifier())
}
pool.Release(acc1.Identifier())
pool.Release(acc2.Identifier())
}
func TestPoolMaxQueueSizeOverride(t *testing.T) {
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "5")
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`)
pool := NewPool(config.LoadStore())
status := pool.Status()
if got, ok := status["max_queue_size"].(int); !ok || got != 5 {
t.Fatalf("expected max_queue_size=5, got %#v", status["max_queue_size"])
}
}
func TestPoolQueueSizeAliasEnv(t *testing.T) {
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "7")
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`)
pool := NewPool(config.LoadStore())
status := pool.Status()
if got, ok := status["max_queue_size"].(int); !ok || got != 7 {
t.Fatalf("expected max_queue_size=7, got %#v", status["max_queue_size"])
}
}
func TestPoolMultipleAcquireReleaseCycles(t *testing.T) {
pool := newSingleAccountPoolForTest(t, "1")
for i := 0; i < 10; i++ {
acc, ok := pool.Acquire("", nil)
if !ok {
t.Fatalf("acquire failed at cycle %d", i)
}
pool.Release(acc.Identifier())
}
}
func TestPoolConcurrentAcquireWait(t *testing.T) {
pool := newSingleAccountPoolForTest(t, "1")
first, ok := pool.Acquire("", nil)
if !ok {
t.Fatal("expected first acquire success")
}
const waiters = 3
results := make(chan bool, waiters)
for i := 0; i < waiters; i++ {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, ok := pool.AcquireWait(ctx, "", nil)
results <- ok
}()
}
// Wait for all to be queued (only 1 can queue)
time.Sleep(50 * time.Millisecond)
// Release and allow queued requests to proceed
pool.Release(first.Identifier())
successCount := 0
timeoutCount := 0
for i := 0; i < waiters; i++ {
select {
case ok := <-results:
if ok {
successCount++
// Release for next waiter
pool.Release("acc1@example.com")
} else {
timeoutCount++
}
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for results")
}
}
// At least 1 should succeed; 2 may fail due to queue limit
if successCount < 1 {
t.Fatalf("expected at least 1 success, got success=%d timeout=%d", successCount, timeoutCount)
}
}

View File

@@ -0,0 +1,348 @@
package claude
import (
"testing"
)
// ─── normalizeClaudeMessages ─────────────────────────────────────────
func TestNormalizeClaudeMessagesSimpleString(t *testing.T) {
msgs := []any{
map[string]any{"role": "user", "content": "Hello"},
}
got := normalizeClaudeMessages(msgs)
if len(got) != 1 {
t.Fatalf("expected 1 message, got %d", len(got))
}
m := got[0].(map[string]any)
if m["content"] != "Hello" {
t.Fatalf("expected 'Hello', got %v", m["content"])
}
}
func TestNormalizeClaudeMessagesArrayContent(t *testing.T) {
msgs := []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "line1"},
map[string]any{"type": "text", "text": "line2"},
},
},
}
got := normalizeClaudeMessages(msgs)
m := got[0].(map[string]any)
if m["content"] != "line1\nline2" {
t.Fatalf("expected joined text, got %q", m["content"])
}
}
func TestNormalizeClaudeMessagesToolResult(t *testing.T) {
msgs := []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "tool_result", "content": "tool output"},
},
},
}
got := normalizeClaudeMessages(msgs)
m := got[0].(map[string]any)
if m["content"] != "tool output" {
t.Fatalf("expected 'tool output', got %q", m["content"])
}
}
func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) {
msgs := []any{"not a map", 42}
got := normalizeClaudeMessages(msgs)
if len(got) != 0 {
t.Fatalf("expected 0 messages for non-map items, got %d", len(got))
}
}
func TestNormalizeClaudeMessagesEmpty(t *testing.T) {
got := normalizeClaudeMessages(nil)
if len(got) != 0 {
t.Fatalf("expected 0, got %d", len(got))
}
}
func TestNormalizeClaudeMessagesPreservesRole(t *testing.T) {
msgs := []any{
map[string]any{"role": "assistant", "content": "response"},
}
got := normalizeClaudeMessages(msgs)
m := got[0].(map[string]any)
if m["role"] != "assistant" {
t.Fatalf("expected 'assistant', got %q", m["role"])
}
}
func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) {
msgs := []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "Hello"},
map[string]any{"type": "image", "source": "data:..."},
map[string]any{"type": "text", "text": "World"},
},
},
}
got := normalizeClaudeMessages(msgs)
m := got[0].(map[string]any)
if m["content"] != "Hello\nWorld" {
t.Fatalf("expected only text parts joined, got %q", m["content"])
}
}
// ─── buildClaudeToolPrompt ───────────────────────────────────────────
func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
tools := []any{
map[string]any{
"name": "search",
"description": "Search the web",
"input_schema": map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{"type": "string"},
},
},
},
}
prompt := buildClaudeToolPrompt(tools)
if prompt == "" {
t.Fatal("expected non-empty prompt")
}
// Should contain tool name and description
if !containsStr(prompt, "search") {
t.Fatalf("expected 'search' in prompt")
}
if !containsStr(prompt, "Search the web") {
t.Fatalf("expected description in prompt")
}
if !containsStr(prompt, "tool_calls") {
t.Fatalf("expected tool_calls instruction in prompt")
}
}
func TestBuildClaudeToolPromptMultipleTools(t *testing.T) {
tools := []any{
map[string]any{"name": "tool1", "description": "desc1"},
map[string]any{"name": "tool2", "description": "desc2"},
}
prompt := buildClaudeToolPrompt(tools)
if !containsStr(prompt, "tool1") || !containsStr(prompt, "tool2") {
t.Fatalf("expected both tools in prompt")
}
}
func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) {
tools := []any{"not a map"}
prompt := buildClaudeToolPrompt(tools)
if prompt == "" {
t.Fatal("expected non-empty prompt even with invalid tools")
}
// Should still contain the intro and instruction
if !containsStr(prompt, "You are Claude") {
t.Fatalf("expected intro in prompt")
}
}
// ─── hasSystemMessage ────────────────────────────────────────────────
func TestHasSystemMessageTrue(t *testing.T) {
msgs := []any{
map[string]any{"role": "system", "content": "You are a helper"},
map[string]any{"role": "user", "content": "Hi"},
}
if !hasSystemMessage(msgs) {
t.Fatal("expected true")
}
}
func TestHasSystemMessageFalse(t *testing.T) {
msgs := []any{
map[string]any{"role": "user", "content": "Hi"},
map[string]any{"role": "assistant", "content": "Hello"},
}
if hasSystemMessage(msgs) {
t.Fatal("expected false")
}
}
func TestHasSystemMessageEmpty(t *testing.T) {
if hasSystemMessage(nil) {
t.Fatal("expected false for nil")
}
}
func TestHasSystemMessageNonMap(t *testing.T) {
msgs := []any{"not a map"}
if hasSystemMessage(msgs) {
t.Fatal("expected false for non-map")
}
}
// ─── extractClaudeToolNames ──────────────────────────────────────────
func TestExtractClaudeToolNamesSingle(t *testing.T) {
tools := []any{
map[string]any{"name": "search"},
}
names := extractClaudeToolNames(tools)
if len(names) != 1 || names[0] != "search" {
t.Fatalf("expected [search], got %v", names)
}
}
func TestExtractClaudeToolNamesMultiple(t *testing.T) {
tools := []any{
map[string]any{"name": "search"},
map[string]any{"name": "calculate"},
}
names := extractClaudeToolNames(tools)
if len(names) != 2 {
t.Fatalf("expected 2 names, got %v", names)
}
}
func TestExtractClaudeToolNamesSkipsEmptyName(t *testing.T) {
tools := []any{
map[string]any{"name": ""},
map[string]any{"name": "valid"},
}
names := extractClaudeToolNames(tools)
if len(names) != 1 || names[0] != "valid" {
t.Fatalf("expected [valid], got %v", names)
}
}
func TestExtractClaudeToolNamesSkipsNonMap(t *testing.T) {
tools := []any{"not a map", 42}
names := extractClaudeToolNames(tools)
if len(names) != 0 {
t.Fatalf("expected 0, got %v", names)
}
}
func TestExtractClaudeToolNamesNil(t *testing.T) {
names := extractClaudeToolNames(nil)
if len(names) != 0 {
t.Fatalf("expected 0, got %v", names)
}
}
// ─── toMessageMaps ───────────────────────────────────────────────────
func TestToMessageMapsNormal(t *testing.T) {
input := []any{
map[string]any{"role": "user", "content": "Hello"},
}
got := toMessageMaps(input)
if len(got) != 1 {
t.Fatalf("expected 1, got %d", len(got))
}
}
func TestToMessageMapsNonSlice(t *testing.T) {
got := toMessageMaps("not a slice")
if got != nil {
t.Fatalf("expected nil, got %v", got)
}
}
func TestToMessageMapsSkipsNonMap(t *testing.T) {
input := []any{"string", map[string]any{"role": "user"}, 42}
got := toMessageMaps(input)
if len(got) != 1 {
t.Fatalf("expected 1 map, got %d", len(got))
}
}
func TestToMessageMapsNil(t *testing.T) {
got := toMessageMaps(nil)
if got != nil {
t.Fatalf("expected nil, got %v", got)
}
}
// ─── extractMessageContent ──────────────────────────────────────────
func TestExtractMessageContentString(t *testing.T) {
if got := extractMessageContent("hello"); got != "hello" {
t.Fatalf("expected 'hello', got %q", got)
}
}
func TestExtractMessageContentArray(t *testing.T) {
input := []any{"part1", "part2"}
got := extractMessageContent(input)
if got != "part1\npart2" {
t.Fatalf("expected joined, got %q", got)
}
}
func TestExtractMessageContentOther(t *testing.T) {
got := extractMessageContent(42)
if got != "42" {
t.Fatalf("expected '42', got %q", got)
}
}
func TestExtractMessageContentNil(t *testing.T) {
got := extractMessageContent(nil)
if got != "<nil>" {
t.Fatalf("expected '<nil>', got %q", got)
}
}
// ─── cloneMap ────────────────────────────────────────────────────────
func TestCloneMapBasic(t *testing.T) {
original := map[string]any{"a": 1, "b": "hello"}
clone := cloneMap(original)
original["a"] = 999
if clone["a"] != 1 {
t.Fatalf("expected 1, got %v", clone["a"])
}
if clone["b"] != "hello" {
t.Fatalf("expected 'hello', got %v", clone["b"])
}
}
func TestCloneMapEmpty(t *testing.T) {
clone := cloneMap(map[string]any{})
if len(clone) != 0 {
t.Fatalf("expected empty, got %v", clone)
}
}
func TestCloneMapNested(t *testing.T) {
// cloneMap is shallow, so nested maps share references
inner := map[string]any{"key": "value"}
original := map[string]any{"nested": inner}
clone := cloneMap(original)
// Shallow clone means inner is shared
inner["key"] = "modified"
cloneNested := clone["nested"].(map[string]any)
if cloneNested["key"] != "modified" {
t.Fatal("expected shallow clone to share nested references")
}
}
// helper
func containsStr(s, sub string) bool {
return len(s) >= len(sub) && (s == sub || len(s) > 0 && findSubstring(s, sub))
}
func findSubstring(s, sub string) bool {
for i := 0; i <= len(s)-len(sub); i++ {
if s[i:i+len(sub)] == sub {
return true
}
}
return false
}

View File

@@ -120,10 +120,10 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
return
}
h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, toolNames)
}
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)

View File

@@ -128,7 +128,7 @@ func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, false, []string{"search"})
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
@@ -161,7 +161,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, false, []string{"search"})
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
@@ -189,7 +189,7 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, false, []string{"search"})
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
@@ -220,7 +220,7 @@ func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) {
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, false, []string{"search"})
h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
@@ -249,7 +249,7 @@ func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) {
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, false, []string{"search"})
h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}

View File

@@ -0,0 +1,240 @@
package admin
import (
"net/http"
"net/http/httptest"
"testing"
"ds2api/internal/config"
)
// ─── reverseAccounts ─────────────────────────────────────────────────
func TestReverseAccountsEmpty(t *testing.T) {
a := []config.Account{}
reverseAccounts(a)
if len(a) != 0 {
t.Fatal("expected empty")
}
}
func TestReverseAccountsTwoElements(t *testing.T) {
a := []config.Account{
{Email: "a@test.com"},
{Email: "b@test.com"},
}
reverseAccounts(a)
if a[0].Email != "b@test.com" || a[1].Email != "a@test.com" {
t.Fatalf("unexpected order after reverse: %v", a)
}
}
func TestReverseAccountsThreeElements(t *testing.T) {
a := []config.Account{
{Email: "1@test.com"},
{Email: "2@test.com"},
{Email: "3@test.com"},
}
reverseAccounts(a)
if a[0].Email != "3@test.com" || a[1].Email != "2@test.com" || a[2].Email != "1@test.com" {
t.Fatalf("unexpected order: %v", a)
}
}
// ─── intFromQuery edge cases ─────────────────────────────────────────
func TestIntFromQueryPresent(t *testing.T) {
req := httptest.NewRequest("GET", "/?limit=5", nil)
if got := intFromQuery(req, "limit", 10); got != 5 {
t.Fatalf("expected 5, got %d", got)
}
}
func TestIntFromQueryMissing(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
if got := intFromQuery(req, "limit", 10); got != 10 {
t.Fatalf("expected default 10, got %d", got)
}
}
func TestIntFromQueryInvalid(t *testing.T) {
req := httptest.NewRequest("GET", "/?limit=abc", nil)
if got := intFromQuery(req, "limit", 10); got != 10 {
t.Fatalf("expected default 10 for invalid, got %d", got)
}
}
func TestIntFromQueryNegative(t *testing.T) {
req := httptest.NewRequest("GET", "/?limit=-3", nil)
if got := intFromQuery(req, "limit", 10); got != -3 {
t.Fatalf("expected -3, got %d", got)
}
}
func TestIntFromQueryZero(t *testing.T) {
req := httptest.NewRequest("GET", "/?limit=0", nil)
if got := intFromQuery(req, "limit", 10); got != 0 {
t.Fatalf("expected 0, got %d", got)
}
}
// ─── nilIfEmpty ──────────────────────────────────────────────────────
func TestNilIfEmptyEmpty(t *testing.T) {
if nilIfEmpty("") != nil {
t.Fatal("expected nil for empty string")
}
}
func TestNilIfEmptyNonEmpty(t *testing.T) {
if nilIfEmpty("hello") != "hello" {
t.Fatal("expected 'hello'")
}
}
// ─── nilIfZero ───────────────────────────────────────────────────────
func TestNilIfZeroZero(t *testing.T) {
if nilIfZero(0) != nil {
t.Fatal("expected nil for zero")
}
}
func TestNilIfZeroNonZero(t *testing.T) {
if nilIfZero(42) != int64(42) {
t.Fatal("expected 42")
}
}
func TestNilIfZeroNegative(t *testing.T) {
if nilIfZero(-1) != int64(-1) {
t.Fatal("expected -1")
}
}
// ─── toStringSlice ───────────────────────────────────────────────────
func TestToStringSliceFromAnySlice(t *testing.T) {
input := []any{"a", "b", "c"}
got, ok := toStringSlice(input)
if !ok || len(got) != 3 {
t.Fatalf("expected 3 strings, got %#v ok=%v", got, ok)
}
if got[0] != "a" || got[1] != "b" || got[2] != "c" {
t.Fatalf("unexpected values: %#v", got)
}
}
func TestToStringSliceFromMixed(t *testing.T) {
input := []any{"hello", 42, true}
got, ok := toStringSlice(input)
if !ok {
t.Fatal("expected ok for mixed types")
}
if got[0] != "hello" || got[1] != "42" || got[2] != "true" {
t.Fatalf("unexpected values: %#v", got)
}
}
func TestToStringSliceFromNonSlice(t *testing.T) {
_, ok := toStringSlice("not a slice")
if ok {
t.Fatal("expected not ok for string input")
}
}
func TestToStringSliceFromNil(t *testing.T) {
_, ok := toStringSlice(nil)
if ok {
t.Fatal("expected not ok for nil input")
}
}
func TestToStringSliceEmpty(t *testing.T) {
got, ok := toStringSlice([]any{})
if !ok {
t.Fatal("expected ok for empty slice")
}
if len(got) != 0 {
t.Fatalf("expected empty result, got %#v", got)
}
}
func TestToStringSliceTrimsWhitespace(t *testing.T) {
got, ok := toStringSlice([]any{" hello ", " world "})
if !ok {
t.Fatal("expected ok")
}
if got[0] != "hello" || got[1] != "world" {
t.Fatalf("expected trimmed values, got %#v", got)
}
}
// ─── toAccount edge cases ────────────────────────────────────────────
func TestToAccountAllFields(t *testing.T) {
acc := toAccount(map[string]any{
"email": "user@test.com",
"mobile": "13800138000",
"password": "secret",
"token": "tok123",
})
if acc.Email != "user@test.com" {
t.Fatalf("unexpected email: %q", acc.Email)
}
if acc.Mobile != "13800138000" {
t.Fatalf("unexpected mobile: %q", acc.Mobile)
}
if acc.Password != "secret" {
t.Fatalf("unexpected password: %q", acc.Password)
}
if acc.Token != "tok123" {
t.Fatalf("unexpected token: %q", acc.Token)
}
}
func TestToAccountNumericValues(t *testing.T) {
acc := toAccount(map[string]any{
"email": 12345,
})
if acc.Email != "12345" {
t.Fatalf("expected numeric converted to string, got %q", acc.Email)
}
}
// ─── fieldString edge cases ──────────────────────────────────────────
func TestFieldStringNonString(t *testing.T) {
got := fieldString(map[string]any{"key": 42}, "key")
if got != "42" {
t.Fatalf("expected '42' for int, got %q", got)
}
}
func TestFieldStringBool(t *testing.T) {
got := fieldString(map[string]any{"key": true}, "key")
if got != "true" {
t.Fatalf("expected 'true', got %q", got)
}
}
func TestFieldStringWhitespace(t *testing.T) {
got := fieldString(map[string]any{"key": " hello "}, "key")
if got != "hello" {
t.Fatalf("expected trimmed 'hello', got %q", got)
}
}
// ─── statusOr ────────────────────────────────────────────────────────
func TestStatusOrZeroReturnsDefault(t *testing.T) {
if got := statusOr(0, http.StatusOK); got != http.StatusOK {
t.Fatalf("expected %d, got %d", http.StatusOK, got)
}
}
func TestStatusOrNonZeroReturnsValue(t *testing.T) {
if got := statusOr(http.StatusBadRequest, http.StatusOK); got != http.StatusBadRequest {
t.Fatalf("expected %d, got %d", http.StatusBadRequest, got)
}
}

View File

@@ -0,0 +1,375 @@
package auth
import (
"context"
"errors"
"net/http"
"testing"
"ds2api/internal/account"
"ds2api/internal/config"
)
// ─── extractCallerToken edge cases ───────────────────────────────────
func TestExtractCallerTokenBearerPrefix(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Bearer my-token")
if got := extractCallerToken(req); got != "my-token" {
t.Fatalf("expected my-token, got %q", got)
}
}
func TestExtractCallerTokenBearerCaseInsensitive(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "BEARER My-Token")
if got := extractCallerToken(req); got != "My-Token" {
t.Fatalf("expected My-Token, got %q", got)
}
}
func TestExtractCallerTokenBearerEmpty(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Bearer ")
if got := extractCallerToken(req); got != "" {
t.Fatalf("expected empty for 'Bearer ', got %q", got)
}
}
func TestExtractCallerTokenXAPIKey(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("x-api-key", "x-api-key-token")
if got := extractCallerToken(req); got != "x-api-key-token" {
t.Fatalf("expected x-api-key-token, got %q", got)
}
}
func TestExtractCallerTokenBearerPreferredOverXAPIKey(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Bearer bearer-token")
req.Header.Set("x-api-key", "x-api-key-token")
if got := extractCallerToken(req); got != "bearer-token" {
t.Fatalf("expected bearer-token, got %q", got)
}
}
func TestExtractCallerTokenMissingHeaders(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
if got := extractCallerToken(req); got != "" {
t.Fatalf("expected empty for missing headers, got %q", got)
}
}
func TestExtractCallerTokenNonBearerAuth(t *testing.T) {
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Basic abc123")
if got := extractCallerToken(req); got != "" {
t.Fatalf("expected empty for Basic auth, got %q", got)
}
}
// ─── Context helpers ─────────────────────────────────────────────────
func TestWithAuthAndFromContext(t *testing.T) {
a := &RequestAuth{DeepSeekToken: "test-token"}
ctx := WithAuth(context.Background(), a)
got, ok := FromContext(ctx)
if !ok || got.DeepSeekToken != "test-token" {
t.Fatalf("expected token from context, got ok=%v token=%q", ok, got.DeepSeekToken)
}
}
func TestFromContextMissing(t *testing.T) {
_, ok := FromContext(context.Background())
if ok {
t.Fatal("expected not ok from empty context")
}
}
// ─── RefreshToken edge cases ─────────────────────────────────────────
func TestRefreshTokenNotConfigToken(t *testing.T) {
r := newTestResolver(t)
a := &RequestAuth{UseConfigToken: false, resolver: r}
if r.RefreshToken(context.Background(), a) {
t.Fatal("expected false for non-config token")
}
}
func TestRefreshTokenEmptyAccountID(t *testing.T) {
r := newTestResolver(t)
a := &RequestAuth{UseConfigToken: true, AccountID: "", resolver: r}
if r.RefreshToken(context.Background(), a) {
t.Fatal("expected false for empty account ID")
}
}
func TestRefreshTokenSuccess(t *testing.T) {
r := newTestResolver(t)
// First acquire an account
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Bearer managed-key")
a, err := r.Determine(req)
if err != nil {
t.Fatalf("determine failed: %v", err)
}
defer r.Release(a)
if !r.RefreshToken(context.Background(), a) {
t.Fatal("expected refresh to succeed")
}
if a.DeepSeekToken != "fresh-token" {
t.Fatalf("expected fresh-token after refresh, got %q", a.DeepSeekToken)
}
}
// ─── MarkTokenInvalid edge cases ─────────────────────────────────────
func TestMarkTokenInvalidNotConfigToken(t *testing.T) {
r := newTestResolver(t)
a := &RequestAuth{UseConfigToken: false, DeepSeekToken: "direct", resolver: r}
r.MarkTokenInvalid(a)
// Should not panic, token should be unchanged for non-config
if a.DeepSeekToken != "" {
// Actually it does clear it; that's fine - let's check behavior
}
}
func TestMarkTokenInvalidEmptyAccountID(t *testing.T) {
r := newTestResolver(t)
a := &RequestAuth{UseConfigToken: true, AccountID: "", DeepSeekToken: "tok", resolver: r}
r.MarkTokenInvalid(a)
// Should not panic
}
func TestMarkTokenInvalidClearsToken(t *testing.T) {
r := newTestResolver(t)
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Bearer managed-key")
a, err := r.Determine(req)
if err != nil {
t.Fatalf("determine failed: %v", err)
}
defer r.Release(a)
r.MarkTokenInvalid(a)
if a.DeepSeekToken != "" {
t.Fatalf("expected empty token after invalidation, got %q", a.DeepSeekToken)
}
if a.Account.Token != "" {
t.Fatalf("expected empty account token after invalidation, got %q", a.Account.Token)
}
}
// ─── SwitchAccount edge cases ────────────────────────────────────────
func TestSwitchAccountNotConfigToken(t *testing.T) {
r := newTestResolver(t)
a := &RequestAuth{UseConfigToken: false, resolver: r}
if r.SwitchAccount(context.Background(), a) {
t.Fatal("expected false for non-config token")
}
}
func TestSwitchAccountNilTriedAccounts(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[
{"email":"acc1@test.com","token":"t1"},
{"email":"acc2@test.com","token":"t2"}
]
}`)
store := config.LoadStore()
pool := account.NewPool(store)
r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
return "new-token", nil
})
// First acquire
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Bearer managed-key")
a, err := r.Determine(req)
if err != nil {
t.Fatalf("determine failed: %v", err)
}
oldID := a.AccountID
a.TriedAccounts = nil // test nil initialization in SwitchAccount
if !r.SwitchAccount(context.Background(), a) {
t.Fatal("expected switch to succeed")
}
if a.AccountID == oldID {
t.Fatalf("expected different account after switch")
}
r.Release(a)
}
// ─── Release edge cases ─────────────────────────────────────────────
func TestReleaseNilAuth(t *testing.T) {
r := newTestResolver(t)
r.Release(nil) // should not panic
}
func TestReleaseNonConfigToken(t *testing.T) {
r := newTestResolver(t)
a := &RequestAuth{UseConfigToken: false}
r.Release(a) // should not panic
}
func TestReleaseEmptyAccountID(t *testing.T) {
r := newTestResolver(t)
a := &RequestAuth{UseConfigToken: true, AccountID: ""}
r.Release(a) // should not panic
}
// ─── JWT edge cases ──────────────────────────────────────────────────
func TestVerifyJWTInvalidFormat(t *testing.T) {
_, err := VerifyJWT("not-a-jwt")
if err == nil {
t.Fatal("expected error for invalid JWT format")
}
}
func TestVerifyJWTInvalidSignature(t *testing.T) {
token, _ := CreateJWT(1)
// Tamper with the signature
parts := splitJWT(token)
if len(parts) == 3 {
tampered := parts[0] + "." + parts[1] + ".invalid_signature"
_, err := VerifyJWT(tampered)
if err == nil {
t.Fatal("expected error for tampered signature")
}
}
}
func TestVerifyJWTExpired(t *testing.T) {
// Create a token with 0 hours expiry - will use default, so we can't easily test
// Instead test with bad payload
_, err := VerifyJWT("eyJhbGciOiJIUzI1NiJ9.eyJleHAiOjF9.invalid")
if err == nil {
t.Fatal("expected error for expired/invalid JWT")
}
}
func TestCreateJWTDefaultExpiry(t *testing.T) {
token, err := CreateJWT(0) // should use default
if err != nil {
t.Fatalf("create jwt failed: %v", err)
}
_, err = VerifyJWT(token)
if err != nil {
t.Fatalf("verify jwt failed: %v", err)
}
}
// ─── VerifyAdminRequest edge cases ───────────────────────────────────
func TestVerifyAdminRequestNoHeader(t *testing.T) {
req, _ := http.NewRequest("GET", "/admin/config", nil)
if err := VerifyAdminRequest(req); err == nil {
t.Fatal("expected error for missing auth")
}
}
func TestVerifyAdminRequestEmptyBearer(t *testing.T) {
req, _ := http.NewRequest("GET", "/admin/config", nil)
req.Header.Set("Authorization", "Bearer ")
if err := VerifyAdminRequest(req); err == nil {
t.Fatal("expected error for empty bearer")
}
}
func TestVerifyAdminRequestWithAdminKey(t *testing.T) {
t.Setenv("DS2API_ADMIN_KEY", "test-admin-key")
req, _ := http.NewRequest("GET", "/admin/config", nil)
req.Header.Set("Authorization", "Bearer test-admin-key")
if err := VerifyAdminRequest(req); err != nil {
t.Fatalf("expected admin key accepted: %v", err)
}
}
func TestVerifyAdminRequestInvalidCredentials(t *testing.T) {
t.Setenv("DS2API_ADMIN_KEY", "correct-key")
req, _ := http.NewRequest("GET", "/admin/config", nil)
req.Header.Set("Authorization", "Bearer wrong-key")
if err := VerifyAdminRequest(req); err == nil {
t.Fatal("expected error for wrong key")
}
}
func TestVerifyAdminRequestBasicAuth(t *testing.T) {
req, _ := http.NewRequest("GET", "/admin/config", nil)
req.Header.Set("Authorization", "Basic abc123")
if err := VerifyAdminRequest(req); err == nil {
t.Fatal("expected error for Basic auth")
}
}
// ─── Determine with login failure ────────────────────────────────────
func TestDetermineWithLoginFailure(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[{"email":"acc@test.com","password":"pwd"}]
}`)
store := config.LoadStore()
pool := account.NewPool(store)
r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
return "", errors.New("login failed")
})
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Bearer managed-key")
_, err := r.Determine(req)
if err == nil {
t.Fatal("expected error when login fails")
}
}
// ─── Determine with target account ───────────────────────────────────
func TestDetermineWithTargetAccount(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[
{"email":"acc1@test.com","token":"t1"},
{"email":"acc2@test.com","token":"t2"}
]
}`)
store := config.LoadStore()
pool := account.NewPool(store)
r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
return "fresh-token", nil
})
req, _ := http.NewRequest("POST", "/", nil)
req.Header.Set("Authorization", "Bearer managed-key")
req.Header.Set("X-Ds2-Target-Account", "acc2@test.com")
a, err := r.Determine(req)
if err != nil {
t.Fatalf("determine failed: %v", err)
}
defer r.Release(a)
if a.AccountID != "acc2@test.com" {
t.Fatalf("expected target account acc2, got %q", a.AccountID)
}
}
// helper
func splitJWT(token string) []string {
result := make([]string, 0, 3)
start := 0
count := 0
for i := 0; i < len(token); i++ {
if token[i] == '.' {
result = append(result, token[start:i])
start = i + 1
count++
}
}
result = append(result, token[start:])
return result
}

View File

@@ -0,0 +1,445 @@
package config
import (
"encoding/base64"
"encoding/json"
"strings"
"testing"
)
// ─── GetModelConfig edge cases ───────────────────────────────────────
func TestGetModelConfigDeepSeekChat(t *testing.T) {
thinking, search, ok := GetModelConfig("deepseek-chat")
if !ok {
t.Fatal("expected ok for deepseek-chat")
}
if thinking || search {
t.Fatalf("expected no thinking/search for deepseek-chat, got thinking=%v search=%v", thinking, search)
}
}
func TestGetModelConfigDeepSeekReasoner(t *testing.T) {
thinking, search, ok := GetModelConfig("deepseek-reasoner")
if !ok {
t.Fatal("expected ok for deepseek-reasoner")
}
if !thinking || search {
t.Fatalf("expected thinking=true search=false, got thinking=%v search=%v", thinking, search)
}
}
func TestGetModelConfigDeepSeekChatSearch(t *testing.T) {
thinking, search, ok := GetModelConfig("deepseek-chat-search")
if !ok {
t.Fatal("expected ok for deepseek-chat-search")
}
if thinking || !search {
t.Fatalf("expected thinking=false search=true, got thinking=%v search=%v", thinking, search)
}
}
func TestGetModelConfigDeepSeekReasonerSearch(t *testing.T) {
thinking, search, ok := GetModelConfig("deepseek-reasoner-search")
if !ok {
t.Fatal("expected ok for deepseek-reasoner-search")
}
if !thinking || !search {
t.Fatalf("expected both true, got thinking=%v search=%v", thinking, search)
}
}
func TestGetModelConfigCaseInsensitive(t *testing.T) {
thinking, search, ok := GetModelConfig("DeepSeek-Chat")
if !ok {
t.Fatal("expected ok for case-insensitive deepseek-chat")
}
if thinking || search {
t.Fatalf("expected no thinking/search for case-insensitive deepseek-chat")
}
}
func TestGetModelConfigUnknownModel(t *testing.T) {
_, _, ok := GetModelConfig("gpt-4")
if ok {
t.Fatal("expected not ok for unknown model")
}
}
func TestGetModelConfigEmpty(t *testing.T) {
_, _, ok := GetModelConfig("")
if ok {
t.Fatal("expected not ok for empty model")
}
}
// ─── lower function ──────────────────────────────────────────────────
func TestLowerFunction(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"Hello", "hello"},
{"ALLCAPS", "allcaps"},
{"already-lower", "already-lower"},
{"Mixed-CASE-123", "mixed-case-123"},
{"", ""},
}
for _, tc := range tests {
got := lower(tc.input)
if got != tc.expected {
t.Errorf("lower(%q) = %q, want %q", tc.input, got, tc.expected)
}
}
}
// ─── Config.MarshalJSON / UnmarshalJSON roundtrip ────────────────────
func TestConfigJSONRoundtrip(t *testing.T) {
cfg := Config{
Keys: []string{"key1", "key2"},
Accounts: []Account{{Email: "user@example.com", Password: "pass", Token: "tok"}},
ClaudeMapping: map[string]string{
"fast": "deepseek-chat",
"slow": "deepseek-reasoner",
},
VercelSyncHash: "hash123",
VercelSyncTime: 1234567890,
AdditionalFields: map[string]any{
"custom_field": "custom_value",
},
}
data, err := cfg.MarshalJSON()
if err != nil {
t.Fatalf("marshal error: %v", err)
}
var decoded Config
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if len(decoded.Keys) != 2 || decoded.Keys[0] != "key1" {
t.Fatalf("unexpected keys: %#v", decoded.Keys)
}
if len(decoded.Accounts) != 1 || decoded.Accounts[0].Email != "user@example.com" {
t.Fatalf("unexpected accounts: %#v", decoded.Accounts)
}
if decoded.ClaudeMapping["fast"] != "deepseek-chat" {
t.Fatalf("unexpected claude mapping: %#v", decoded.ClaudeMapping)
}
if decoded.VercelSyncHash != "hash123" {
t.Fatalf("unexpected vercel sync hash: %q", decoded.VercelSyncHash)
}
if decoded.AdditionalFields["custom_field"] != "custom_value" {
t.Fatalf("unexpected additional fields: %#v", decoded.AdditionalFields)
}
}
func TestConfigUnmarshalJSONPreservesUnknownFields(t *testing.T) {
raw := `{"keys":["k1"],"accounts":[],"my_custom_field":"hello","number_field":42}`
var cfg Config
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if cfg.AdditionalFields["my_custom_field"] != "hello" {
t.Fatalf("expected custom field preserved, got %#v", cfg.AdditionalFields)
}
// number_field should also be preserved
if cfg.AdditionalFields["number_field"] != float64(42) {
t.Fatalf("expected number field preserved, got %#v", cfg.AdditionalFields["number_field"])
}
}
// ─── Config.Clone ────────────────────────────────────────────────────
func TestConfigCloneIsDeepCopy(t *testing.T) {
cfg := Config{
Keys: []string{"key1"},
Accounts: []Account{{Email: "user@test.com", Token: "token"}},
ClaudeMapping: map[string]string{
"fast": "deepseek-chat",
},
AdditionalFields: map[string]any{"custom": "value"},
}
cloned := cfg.Clone()
// Modify original
cfg.Keys[0] = "modified"
cfg.Accounts[0].Email = "modified@test.com"
cfg.ClaudeMapping["fast"] = "modified-model"
// Cloned should not be affected
if cloned.Keys[0] != "key1" {
t.Fatalf("clone keys was affected by original change: %#v", cloned.Keys)
}
if cloned.Accounts[0].Email != "user@test.com" {
t.Fatalf("clone accounts was affected: %#v", cloned.Accounts)
}
if cloned.ClaudeMapping["fast"] != "deepseek-chat" {
t.Fatalf("clone claude mapping was affected: %#v", cloned.ClaudeMapping)
}
}
func TestConfigCloneNilMaps(t *testing.T) {
cfg := Config{
Keys: []string{"k"},
Accounts: nil,
}
cloned := cfg.Clone()
if len(cloned.Keys) != 1 {
t.Fatalf("unexpected keys length: %d", len(cloned.Keys))
}
if cloned.Accounts != nil {
t.Fatalf("expected nil accounts in clone, got %#v", cloned.Accounts)
}
}
// ─── Account.Identifier edge cases ───────────────────────────────────
func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) {
acc := Account{Mobile: "13800138000", Token: "tok"}
if acc.Identifier() != "13800138000" {
t.Fatalf("expected mobile identifier, got %q", acc.Identifier())
}
}
func TestAccountIdentifierPreferenceEmailOverMobile(t *testing.T) {
acc := Account{Email: "user@test.com", Mobile: "13800138000"}
if acc.Identifier() != "user@test.com" {
t.Fatalf("expected email identifier, got %q", acc.Identifier())
}
}
func TestAccountIdentifierEmptyAccount(t *testing.T) {
acc := Account{}
if acc.Identifier() != "" {
t.Fatalf("expected empty identifier for empty account, got %q", acc.Identifier())
}
}
// ─── normalizeConfigInput ────────────────────────────────────────────
func TestNormalizeConfigInputStripsQuotes(t *testing.T) {
got := normalizeConfigInput(`"base64:abc"`)
if strings.HasPrefix(got, `"`) || strings.HasSuffix(got, `"`) {
t.Fatalf("expected quotes stripped, got %q", got)
}
}
func TestNormalizeConfigInputStripsSingleQuotes(t *testing.T) {
got := normalizeConfigInput("'some-value'")
if strings.HasPrefix(got, "'") || strings.HasSuffix(got, "'") {
t.Fatalf("expected single quotes stripped, got %q", got)
}
}
func TestNormalizeConfigInputTrimsWhitespace(t *testing.T) {
got := normalizeConfigInput(" hello ")
if got != "hello" {
t.Fatalf("expected trimmed, got %q", got)
}
}
// ─── parseConfigString edge cases ────────────────────────────────────
func TestParseConfigStringPlainJSON(t *testing.T) {
cfg, err := parseConfigString(`{"keys":["k1"],"accounts":[]}`)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(cfg.Keys) != 1 || cfg.Keys[0] != "k1" {
t.Fatalf("unexpected keys: %#v", cfg.Keys)
}
}
func TestParseConfigStringBase64Prefix(t *testing.T) {
rawJSON := `{"keys":["base64-key"],"accounts":[]}`
b64 := base64.StdEncoding.EncodeToString([]byte(rawJSON))
cfg, err := parseConfigString("base64:" + b64)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(cfg.Keys) != 1 || cfg.Keys[0] != "base64-key" {
t.Fatalf("unexpected keys: %#v", cfg.Keys)
}
}
func TestParseConfigStringInvalidBase64(t *testing.T) {
_, err := parseConfigString("base64:!!!invalid!!!")
if err == nil {
t.Fatal("expected error for invalid base64")
}
}
func TestParseConfigStringEmptyString(t *testing.T) {
_, err := parseConfigString("")
if err == nil {
t.Fatal("expected error for empty string")
}
}
// ─── Store methods ───────────────────────────────────────────────────
func TestStoreSnapshotReturnsClone(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com","token":"t1"}]}`)
store := LoadStore()
snap := store.Snapshot()
snap.Keys[0] = "modified"
if store.Keys()[0] != "k1" {
t.Fatal("snapshot modification should not affect store")
}
}
func TestStoreHasAPIKeyMultipleKeys(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["key1","key2","key3"],"accounts":[]}`)
store := LoadStore()
if !store.HasAPIKey("key1") {
t.Fatal("expected key1 found")
}
if !store.HasAPIKey("key2") {
t.Fatal("expected key2 found")
}
if !store.HasAPIKey("key3") {
t.Fatal("expected key3 found")
}
if store.HasAPIKey("nonexistent") {
t.Fatal("expected nonexistent key not found")
}
}
func TestStoreFindAccountNotFound(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com"}]}`)
store := LoadStore()
_, ok := store.FindAccount("nonexistent@test.com")
if ok {
t.Fatal("expected account not found")
}
}
func TestStoreIsEnvBacked(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
store := LoadStore()
if !store.IsEnvBacked() {
t.Fatal("expected env-backed store")
}
}
func TestStoreReplace(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
store := LoadStore()
newCfg := Config{
Keys: []string{"new-key"},
Accounts: []Account{{Email: "new@test.com"}},
}
if err := store.Replace(newCfg); err != nil {
t.Fatalf("replace error: %v", err)
}
if !store.HasAPIKey("new-key") {
t.Fatal("expected new key after replace")
}
if store.HasAPIKey("k1") {
t.Fatal("expected old key removed after replace")
}
}
func TestStoreUpdate(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
store := LoadStore()
err := store.Update(func(cfg *Config) error {
cfg.Keys = append(cfg.Keys, "k2")
return nil
})
if err != nil {
t.Fatalf("update error: %v", err)
}
if !store.HasAPIKey("k2") {
t.Fatal("expected k2 after update")
}
}
func TestStoreClaudeMapping(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`)
store := LoadStore()
mapping := store.ClaudeMapping()
if mapping["fast"] != "deepseek-chat" {
t.Fatalf("unexpected fast mapping: %q", mapping["fast"])
}
if mapping["slow"] != "deepseek-reasoner" {
t.Fatalf("unexpected slow mapping: %q", mapping["slow"])
}
}
func TestStoreClaudeMappingEmpty(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`)
store := LoadStore()
mapping := store.ClaudeMapping()
// Even without config mapping, there are defaults
if mapping == nil {
t.Fatal("expected non-nil mapping (may contain defaults)")
}
}
func TestStoreSetVercelSync(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`)
store := LoadStore()
if err := store.SetVercelSync("hash123", 1234567890); err != nil {
t.Fatalf("setVercelSync error: %v", err)
}
snap := store.Snapshot()
if snap.VercelSyncHash != "hash123" || snap.VercelSyncTime != 1234567890 {
t.Fatalf("unexpected vercel sync: hash=%q time=%d", snap.VercelSyncHash, snap.VercelSyncTime)
}
}
func TestStoreExportJSONAndBase64(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["export-key"],"accounts":[]}`)
store := LoadStore()
jsonStr, b64Str, err := store.ExportJSONAndBase64()
if err != nil {
t.Fatalf("export error: %v", err)
}
if !strings.Contains(jsonStr, "export-key") {
t.Fatalf("expected JSON to contain key: %q", jsonStr)
}
decoded, err := base64.StdEncoding.DecodeString(b64Str)
if err != nil {
t.Fatalf("base64 decode error: %v", err)
}
if !strings.Contains(string(decoded), "export-key") {
t.Fatalf("expected base64-decoded to contain key: %q", string(decoded))
}
}
// ─── OpenAIModelsResponse / ClaudeModelsResponse ─────────────────────
func TestOpenAIModelsResponse(t *testing.T) {
resp := OpenAIModelsResponse()
if resp["object"] != "list" {
t.Fatalf("unexpected object: %v", resp["object"])
}
data, ok := resp["data"].([]ModelInfo)
if !ok {
t.Fatalf("unexpected data type: %T", resp["data"])
}
if len(data) == 0 {
t.Fatal("expected non-empty models list")
}
}
func TestClaudeModelsResponse(t *testing.T) {
resp := ClaudeModelsResponse()
if resp["object"] != "list" {
t.Fatalf("unexpected object: %v", resp["object"])
}
data, ok := resp["data"].([]ModelInfo)
if !ok {
t.Fatalf("unexpected data type: %T", resp["data"])
}
if len(data) == 0 {
t.Fatal("expected non-empty models list")
}
}

View File

@@ -0,0 +1,165 @@
package deepseek
import (
"context"
"testing"
)
// ─── toFloat64 edge cases ────────────────────────────────────────────
func TestToFloat64FromFloat64(t *testing.T) {
if got := toFloat64(float64(3.14), 0); got != 3.14 {
t.Fatalf("expected 3.14, got %f", got)
}
}
func TestToFloat64FromInt(t *testing.T) {
if got := toFloat64(42, 0); got != 42.0 {
t.Fatalf("expected 42.0, got %f", got)
}
}
func TestToFloat64FromInt64(t *testing.T) {
if got := toFloat64(int64(100), 0); got != 100.0 {
t.Fatalf("expected 100.0, got %f", got)
}
}
func TestToFloat64FromStringDefault(t *testing.T) {
if got := toFloat64("42", 99.0); got != 99.0 {
t.Fatalf("expected default 99.0, got %f", got)
}
}
func TestToFloat64FromNilDefault(t *testing.T) {
if got := toFloat64(nil, 5.5); got != 5.5 {
t.Fatalf("expected default 5.5, got %f", got)
}
}
func TestToFloat64FromBoolDefault(t *testing.T) {
if got := toFloat64(true, 1.0); got != 1.0 {
t.Fatalf("expected default 1.0, got %f", got)
}
}
// ─── toInt64 edge cases ──────────────────────────────────────────────
func TestToInt64FromFloat64(t *testing.T) {
if got := toInt64(float64(42.9), 0); got != 42 {
t.Fatalf("expected 42, got %d", got)
}
}
func TestToInt64FromInt(t *testing.T) {
if got := toInt64(42, 0); got != 42 {
t.Fatalf("expected 42, got %d", got)
}
}
func TestToInt64FromInt64(t *testing.T) {
if got := toInt64(int64(100), 0); got != 100 {
t.Fatalf("expected 100, got %d", got)
}
}
func TestToInt64FromStringDefault(t *testing.T) {
if got := toInt64("42", 99); got != 99 {
t.Fatalf("expected default 99, got %d", got)
}
}
func TestToInt64FromNilDefault(t *testing.T) {
if got := toInt64(nil, 7); got != 7 {
t.Fatalf("expected default 7, got %d", got)
}
}
// ─── BuildPowHeader edge cases ───────────────────────────────────────
func TestBuildPowHeaderBasicChallenge(t *testing.T) {
challenge := map[string]any{
"algorithm": "DeepSeekHashV1",
"challenge": "abc123",
"salt": "salt456",
"signature": "sig789",
"target_path": "/path",
}
result, err := BuildPowHeader(challenge, 42)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result == "" {
t.Fatal("expected non-empty result")
}
}
func TestBuildPowHeaderEmptyChallenge(t *testing.T) {
result, err := BuildPowHeader(map[string]any{}, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Should produce a base64 encoded JSON with nil values
if result == "" {
t.Fatal("expected non-empty result for empty challenge")
}
}
// ─── PowSolver pool size ─────────────────────────────────────────────
func TestPowPoolSizeFromEnvDefault(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "")
got := powPoolSizeFromEnv()
if got < 1 {
t.Fatalf("expected positive default pool size, got %d", got)
}
}
func TestPowPoolSizeFromEnvInvalid(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "abc")
got := powPoolSizeFromEnv()
if got < 1 {
t.Fatalf("expected positive default for invalid, got %d", got)
}
}
func TestPowPoolSizeFromEnvSpecificValue(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "5")
got := powPoolSizeFromEnv()
if got != 5 {
t.Fatalf("expected 5, got %d", got)
}
}
// ─── NewClient ───────────────────────────────────────────────────────
func TestNewClientInitialState(t *testing.T) {
client := NewClient(nil, nil)
if client.powSolver == nil {
t.Fatal("expected powSolver to be initialized")
}
}
func TestNewClientPreloadPowIdempotent(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "1")
client := NewClient(nil, nil)
if err := client.PreloadPow(context.Background()); err != nil {
t.Fatalf("first preload failed: %v", err)
}
if err := client.PreloadPow(context.Background()); err != nil {
t.Fatalf("second preload failed: %v", err)
}
}
// ─── PowSolver init and module pool ──────────────────────────────────
func TestPowSolverPoolSizeMatchesEnv(t *testing.T) {
t.Setenv("DS2API_POW_POOL_SIZE", "2")
solver := NewPowSolver("test.wasm")
if err := solver.init(context.Background()); err != nil {
t.Fatalf("init failed: %v", err)
}
if cap(solver.pool) != 2 {
t.Fatalf("expected pool capacity 2, got %d", cap(solver.pool))
}
}

View File

@@ -0,0 +1,140 @@
package sse
import (
"io"
"net/http"
"strings"
"testing"
)
// ─── CollectStream edge cases ────────────────────────────────────────
func makeHTTPResponse(body string) *http.Response {
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func TestCollectStreamEmpty(t *testing.T) {
resp := makeHTTPResponse("")
result := CollectStream(resp, false, false)
if result.Text != "" || result.Thinking != "" {
t.Fatalf("expected empty result, got text=%q think=%q", result.Text, result.Thinking)
}
}
func TestCollectStreamTextOnly(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\" World\"}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, false, false)
if result.Text != "Hello World" {
t.Fatalf("expected 'Hello World', got %q", result.Text)
}
if result.Thinking != "" {
t.Fatalf("expected no thinking, got %q", result.Thinking)
}
}
func TestCollectStreamThinkingAndText(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/thinking_content\",\"v\":\"Thinking...\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\"Answer\"}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, true, true)
if result.Thinking != "Thinking..." {
t.Fatalf("expected 'Thinking...', got %q", result.Thinking)
}
if result.Text != "Answer" {
t.Fatalf("expected 'Answer', got %q", result.Text)
}
}
func TestCollectStreamOnlyThinking(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/thinking_content\",\"v\":\"Only thinking\"}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, true, true)
if result.Thinking != "Only thinking" {
t.Fatalf("expected 'Only thinking', got %q", result.Thinking)
}
if result.Text != "" {
t.Fatalf("expected empty text, got %q", result.Text)
}
}
func TestCollectStreamSkipsInvalidLines(t *testing.T) {
resp := makeHTTPResponse(
"event: comment\n" +
"data: invalid_json\n" +
"data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, false, false)
if result.Text != "valid" {
t.Fatalf("expected 'valid', got %q", result.Text)
}
}
func TestCollectStreamWithFragments(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"Think\"}]}\n" +
"data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"Done\"}]}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, true, true)
if result.Thinking != "Think" {
t.Fatalf("expected 'Think' thinking, got %q", result.Thinking)
}
if result.Text != "Done" {
t.Fatalf("expected 'Done' text, got %q", result.Text)
}
}
func TestCollectStreamWithCitation(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\"[citation:1] cited text\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\" more\"}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, false, false)
// CollectStream does NOT filter citations (that's done by the adapters)
// So citations are passed through as-is
if !strings.Contains(result.Text, "[citation:1]") {
t.Fatalf("expected citations to be passed through, got %q", result.Text)
}
if result.Text != "Hello[citation:1] cited text more" {
t.Fatalf("expected full text with citation, got %q", result.Text)
}
}
func TestCollectStreamMultipleThinkingChunks(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/thinking_content\",\"v\":\"part1\"}\n" +
"data: {\"p\":\"response/thinking_content\",\"v\":\" part2\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\"answer\"}\n" +
"data: [DONE]\n",
)
result := CollectStream(resp, true, true)
if result.Thinking != "part1 part2" {
t.Fatalf("expected 'part1 part2', got %q", result.Thinking)
}
}
func TestCollectStreamStatusFinished(t *testing.T) {
resp := makeHTTPResponse(
"data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" +
"data: {\"p\":\"response/status\",\"v\":\"FINISHED\"}\n",
)
result := CollectStream(resp, false, false)
if result.Text != "Hello" {
t.Fatalf("expected 'Hello', got %q", result.Text)
}
}

View File

@@ -0,0 +1,70 @@
package sse
import "testing"
func TestParseDeepSeekContentLineNotParsed(t *testing.T) {
res := ParseDeepSeekContentLine([]byte("not a data line"), false, "text")
if res.Parsed {
t.Fatal("expected not parsed")
}
if res.NextType != "text" {
t.Fatalf("expected nextType preserved, got %q", res.NextType)
}
}
func TestParseDeepSeekContentLinePreservesNextType(t *testing.T) {
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/thinking_content","v":"think"}`), true, "thinking")
if !res.Parsed || res.Stop {
t.Fatalf("expected parsed non-stop: %#v", res)
}
if len(res.Parts) != 1 || res.Parts[0].Type != "thinking" {
t.Fatalf("unexpected parts: %#v", res.Parts)
}
}
func TestParseDeepSeekContentLineFragmentSwitchType(t *testing.T) {
res := ParseDeepSeekContentLine(
[]byte(`data: {"p":"response/fragments","o":"APPEND","v":[{"type":"RESPONSE","content":"hi"}]}`),
true, "thinking",
)
if !res.Parsed || res.Stop {
t.Fatalf("expected parsed non-stop: %#v", res)
}
if res.NextType != "text" {
t.Fatalf("expected nextType text after RESPONSE fragment, got %q", res.NextType)
}
}
func TestParseDeepSeekContentLineContentFilterMessage(t *testing.T) {
res := ParseDeepSeekContentLine([]byte(`data: {"code":"content_filter"}`), false, "text")
if !res.ContentFilter {
t.Fatal("expected content filter flag")
}
if res.ErrorMessage == "" {
t.Fatal("expected error message on content filter")
}
}
func TestParseDeepSeekContentLineErrorObjectFormat(t *testing.T) {
res := ParseDeepSeekContentLine([]byte(`data: {"error":{"message":"rate limit","code":429}}`), false, "text")
if !res.Parsed || !res.Stop {
t.Fatalf("expected parsed stop: %#v", res)
}
if res.ErrorMessage == "" {
t.Fatal("expected non-empty error message")
}
}
func TestParseDeepSeekContentLineInvalidJSON(t *testing.T) {
res := ParseDeepSeekContentLine([]byte("data: {broken"), false, "text")
if res.Parsed {
t.Fatal("expected not parsed for broken JSON")
}
}
func TestParseDeepSeekContentLineEmptyBytes(t *testing.T) {
res := ParseDeepSeekContentLine([]byte{}, false, "text")
if res.Parsed {
t.Fatal("expected not parsed for empty bytes")
}
}

View File

@@ -0,0 +1,631 @@
package sse
import "testing"
// ─── ParseDeepSeekSSELine edge cases ─────────────────────────────────
func TestParseDeepSeekSSELineEmptyLine(t *testing.T) {
_, _, ok := ParseDeepSeekSSELine([]byte(""))
if ok {
t.Fatal("expected not parsed for empty line")
}
}
func TestParseDeepSeekSSELineNoDataPrefix(t *testing.T) {
_, _, ok := ParseDeepSeekSSELine([]byte("event: message"))
if ok {
t.Fatal("expected not parsed for non-data line")
}
}
func TestParseDeepSeekSSELineInvalidJSON(t *testing.T) {
_, _, ok := ParseDeepSeekSSELine([]byte("data: {invalid json"))
if ok {
t.Fatal("expected not parsed for invalid JSON")
}
}
func TestParseDeepSeekSSELineWhitespaceOnly(t *testing.T) {
_, _, ok := ParseDeepSeekSSELine([]byte(" "))
if ok {
t.Fatal("expected not parsed for whitespace-only line")
}
}
func TestParseDeepSeekSSELineDataWithExtraSpaces(t *testing.T) {
chunk, done, ok := ParseDeepSeekSSELine([]byte(`data: {"v":"hello"} `))
if !ok || done {
t.Fatalf("expected parsed chunk for spaced data line")
}
if chunk["v"] != "hello" {
t.Fatalf("unexpected chunk: %#v", chunk)
}
}
// ─── shouldSkipPath edge cases ───────────────────────────────────────
func TestShouldSkipPathQuasiStatus(t *testing.T) {
if !shouldSkipPath("response/quasi_status") {
t.Fatal("expected skip for quasi_status path")
}
}
func TestShouldSkipPathElapsedSecs(t *testing.T) {
if !shouldSkipPath("response/elapsed_secs") {
t.Fatal("expected skip for elapsed_secs path")
}
}
func TestShouldSkipPathTokenUsage(t *testing.T) {
if !shouldSkipPath("response/token_usage") {
t.Fatal("expected skip for token_usage path")
}
}
func TestShouldSkipPathPendingFragment(t *testing.T) {
if !shouldSkipPath("response/pending_fragment") {
t.Fatal("expected skip for pending_fragment path")
}
}
func TestShouldSkipPathConversationMode(t *testing.T) {
if !shouldSkipPath("response/conversation_mode") {
t.Fatal("expected skip for conversation_mode path")
}
}
func TestShouldSkipPathSearchStatus(t *testing.T) {
if !shouldSkipPath("response/search_status") {
t.Fatal("expected skip for search_status path")
}
}
func TestShouldSkipPathFragmentStatus(t *testing.T) {
if !shouldSkipPath("response/fragments/-1/status") {
t.Fatal("expected skip for fragment -1 status")
}
if !shouldSkipPath("response/fragments/-2/status") {
t.Fatal("expected skip for fragment -2 status")
}
if !shouldSkipPath("response/fragments/-3/status") {
t.Fatal("expected skip for fragment -3 status")
}
}
func TestShouldSkipPathRegularContent(t *testing.T) {
if shouldSkipPath("response/content") {
t.Fatal("expected not skip for content path")
}
if shouldSkipPath("response/thinking_content") {
t.Fatal("expected not skip for thinking_content path")
}
}
// ─── ParseSSEChunkForContent edge cases ──────────────────────────────
func TestParseSSEChunkForContentNoVField(t *testing.T) {
parts, finished, nextType := ParseSSEChunkForContent(map[string]any{"p": "response/content"}, false, "text")
if finished {
t.Fatal("expected not finished")
}
if len(parts) != 0 {
t.Fatalf("expected no parts when v is missing, got %#v", parts)
}
if nextType != "text" {
t.Fatalf("expected type preserved, got %q", nextType)
}
}
func TestParseSSEChunkForContentSkippedPath(t *testing.T) {
parts, finished, nextType := ParseSSEChunkForContent(map[string]any{
"p": "response/token_usage",
"v": "some data",
}, false, "text")
if finished || len(parts) > 0 {
t.Fatalf("expected skipped path to produce no output")
}
if nextType != "text" {
t.Fatalf("expected type preserved for skipped path")
}
}
func TestParseSSEChunkForContentFinishedStatus(t *testing.T) {
parts, finished, _ := ParseSSEChunkForContent(map[string]any{
"p": "response/status",
"v": "FINISHED",
}, false, "text")
if !finished {
t.Fatal("expected finished on status FINISHED")
}
if len(parts) != 0 {
t.Fatalf("expected no parts on finished, got %d", len(parts))
}
}
func TestParseSSEChunkForContentStatusNotFinished(t *testing.T) {
parts, finished, _ := ParseSSEChunkForContent(map[string]any{
"p": "response/status",
"v": "IN_PROGRESS",
}, false, "text")
if finished {
t.Fatal("expected not finished for non-FINISHED status")
}
if len(parts) != 1 || parts[0].Text != "IN_PROGRESS" {
t.Fatalf("expected content for non-FINISHED status, got %#v", parts)
}
}
func TestParseSSEChunkForContentEmptyStringV(t *testing.T) {
parts, finished, _ := ParseSSEChunkForContent(map[string]any{
"p": "response/content",
"v": "",
}, false, "text")
if finished {
t.Fatal("expected not finished")
}
if len(parts) != 0 {
t.Fatalf("expected no parts for empty string v, got %#v", parts)
}
}
func TestParseSSEChunkForContentFinishedOnEmptyPath(t *testing.T) {
parts, finished, _ := ParseSSEChunkForContent(map[string]any{
"p": "",
"v": "FINISHED",
}, false, "text")
if !finished {
t.Fatal("expected finished on empty path with FINISHED value")
}
if len(parts) != 0 {
t.Fatalf("expected no parts on finished")
}
}
func TestParseSSEChunkForContentFinishedOnStatusPath(t *testing.T) {
_, finished, _ := ParseSSEChunkForContent(map[string]any{
"p": "status",
"v": "FINISHED",
}, false, "text")
if !finished {
t.Fatal("expected finished on status path with FINISHED value")
}
}
func TestParseSSEChunkForContentThinkingPathEmptyPath(t *testing.T) {
parts, _, nextType := ParseSSEChunkForContent(map[string]any{
"v": "some thought",
}, true, "thinking")
if len(parts) != 1 || parts[0].Type != "thinking" {
t.Fatalf("expected thinking part on empty path, got %#v", parts)
}
if nextType != "thinking" {
t.Fatalf("expected nextType thinking, got %q", nextType)
}
}
func TestParseSSEChunkForContentThinkingEnabledTextType(t *testing.T) {
parts, _, nextType := ParseSSEChunkForContent(map[string]any{
"v": "text content",
}, true, "text")
if len(parts) != 1 || parts[0].Type != "text" {
t.Fatalf("expected text part when currentType=text, got %#v", parts)
}
if nextType != "text" {
t.Fatalf("expected nextType text, got %q", nextType)
}
}
// ─── ParseSSEChunkForContent: fragments path with THINK type ─────────
func TestParseSSEChunkForContentFragmentsAppendThink(t *testing.T) {
chunk := map[string]any{
"p": "response/fragments",
"o": "APPEND",
"v": []any{
map[string]any{
"type": "THINK",
"content": "深入思考...",
},
},
}
parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text")
if finished {
t.Fatal("expected not finished")
}
if nextType != "thinking" {
t.Fatalf("expected nextType thinking, got %q", nextType)
}
if len(parts) != 1 || parts[0].Type != "thinking" || parts[0].Text != "深入思考..." {
t.Fatalf("unexpected parts: %#v", parts)
}
}
func TestParseSSEChunkForContentFragmentsAppendEmptyContent(t *testing.T) {
chunk := map[string]any{
"p": "response/fragments",
"o": "APPEND",
"v": []any{
map[string]any{
"type": "RESPONSE",
"content": "",
},
},
}
parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "thinking")
if finished {
t.Fatal("expected not finished")
}
if nextType != "text" {
t.Fatalf("expected nextType text, got %q", nextType)
}
if len(parts) != 0 {
t.Fatalf("expected no parts for empty content, got %#v", parts)
}
}
func TestParseSSEChunkForContentFragmentsAppendDefaultType(t *testing.T) {
chunk := map[string]any{
"p": "response/fragments",
"o": "APPEND",
"v": []any{
map[string]any{
"type": "UNKNOWN",
"content": "some text",
},
},
}
parts, _, _ := ParseSSEChunkForContent(chunk, true, "text")
if len(parts) != 1 || parts[0].Type != "text" {
t.Fatalf("expected text type for unknown fragment type, got %#v", parts)
}
}
func TestParseSSEChunkForContentFragmentsAppendNonArray(t *testing.T) {
chunk := map[string]any{
"p": "response/fragments",
"o": "APPEND",
"v": "not an array",
}
parts, finished, _ := ParseSSEChunkForContent(chunk, true, "text")
if finished {
t.Fatal("expected not finished")
}
// "not an array" should be treated as string value at the end
if len(parts) != 1 || parts[0].Text != "not an array" {
t.Fatalf("unexpected parts: %#v", parts)
}
}
func TestParseSSEChunkForContentFragmentsAppendNonMap(t *testing.T) {
chunk := map[string]any{
"p": "response/fragments",
"o": "APPEND",
"v": []any{"string item"},
}
parts, _, _ := ParseSSEChunkForContent(chunk, false, "text")
// Non-map items in fragment array are skipped; the []any itself is handled later
_ = parts // just checking it doesn't panic
}
// ─── ParseSSEChunkForContent: response path with nested fragment ─────
func TestParseSSEChunkForContentResponsePathFragmentsAppend(t *testing.T) {
chunk := map[string]any{
"p": "response",
"v": []any{
map[string]any{
"p": "fragments",
"o": "APPEND",
"v": []any{
map[string]any{
"type": "THINKING",
},
},
},
},
}
_, _, nextType := ParseSSEChunkForContent(chunk, true, "text")
if nextType != "thinking" {
t.Fatalf("expected nextType thinking from response path fragments, got %q", nextType)
}
}
func TestParseSSEChunkForContentResponsePathResponseFragment(t *testing.T) {
chunk := map[string]any{
"p": "response",
"v": []any{
map[string]any{
"p": "fragments",
"o": "APPEND",
"v": []any{
map[string]any{
"type": "RESPONSE",
},
},
},
},
}
_, _, nextType := ParseSSEChunkForContent(chunk, true, "thinking")
if nextType != "text" {
t.Fatalf("expected nextType text for RESPONSE fragment, got %q", nextType)
}
}
// ─── ParseSSEChunkForContent: map value with wrapped response ────────
func TestParseSSEChunkForContentMapValueWithFragments(t *testing.T) {
chunk := map[string]any{
"v": map[string]any{
"response": map[string]any{
"fragments": []any{
map[string]any{
"type": "THINK",
"content": "思考...",
},
map[string]any{
"type": "RESPONSE",
"content": "回答...",
},
},
},
},
}
parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text")
if finished {
t.Fatal("expected not finished")
}
if nextType != "text" {
t.Fatalf("expected nextType text after RESPONSE, got %q", nextType)
}
if len(parts) != 2 {
t.Fatalf("expected 2 parts, got %d: %#v", len(parts), parts)
}
if parts[0].Type != "thinking" || parts[0].Text != "思考..." {
t.Fatalf("first part mismatch: %#v", parts[0])
}
if parts[1].Type != "text" || parts[1].Text != "回答..." {
t.Fatalf("second part mismatch: %#v", parts[1])
}
}
func TestParseSSEChunkForContentMapValueDirectFragments(t *testing.T) {
chunk := map[string]any{
"v": map[string]any{
"fragments": []any{
map[string]any{
"type": "RESPONSE",
"content": "直接回答",
},
},
},
}
parts, _, _ := ParseSSEChunkForContent(chunk, false, "text")
if len(parts) != 1 || parts[0].Text != "直接回答" || parts[0].Type != "text" {
t.Fatalf("unexpected parts for direct fragments: %#v", parts)
}
}
func TestParseSSEChunkForContentMapValueUnknownType(t *testing.T) {
chunk := map[string]any{
"v": map[string]any{
"fragments": []any{
map[string]any{
"type": "CUSTOM",
"content": "custom content",
},
},
},
}
parts, _, _ := ParseSSEChunkForContent(chunk, false, "text")
if len(parts) != 1 || parts[0].Type != "text" {
t.Fatalf("expected partType fallback for unknown type, got %#v", parts)
}
}
func TestParseSSEChunkForContentMapValueEmptyFragmentContent(t *testing.T) {
chunk := map[string]any{
"v": map[string]any{
"fragments": []any{
map[string]any{
"type": "RESPONSE",
"content": "",
},
},
},
}
parts, _, _ := ParseSSEChunkForContent(chunk, false, "text")
if len(parts) != 0 {
t.Fatalf("expected no parts for empty fragment content, got %#v", parts)
}
}
// ─── ParseSSEChunkForContent: fragments/-1/content path ──────────────
func TestParseSSEChunkForContentFragmentContentPathInheritsType(t *testing.T) {
chunk := map[string]any{
"p": "response/fragments/-1/content",
"v": "继续思考",
}
parts, _, _ := ParseSSEChunkForContent(chunk, true, "thinking")
if len(parts) != 1 || parts[0].Type != "thinking" {
t.Fatalf("expected inherited thinking type, got %#v", parts)
}
}
// ─── IsCitation edge cases ───────────────────────────────────────────
func TestIsCitationWithLeadingWhitespace(t *testing.T) {
if !IsCitation(" [citation:2] text") {
t.Fatal("expected citation true with leading whitespace")
}
}
func TestIsCitationEmpty(t *testing.T) {
if IsCitation("") {
t.Fatal("expected citation false for empty string")
}
}
func TestIsCitationSimilarPrefix(t *testing.T) {
if IsCitation("[cite:1] text") {
t.Fatal("expected citation false for [cite: prefix")
}
}
// ─── extractContentRecursive edge cases ──────────────────────────────
func TestExtractContentRecursiveFinishedStatus(t *testing.T) {
items := []any{
map[string]any{"p": "status", "v": "FINISHED"},
}
parts, finished := extractContentRecursive(items, "text")
if !finished {
t.Fatal("expected finished on status FINISHED")
}
if len(parts) != 0 {
t.Fatalf("expected no parts, got %#v", parts)
}
}
func TestExtractContentRecursiveSkipsPath(t *testing.T) {
items := []any{
map[string]any{"p": "token_usage", "v": "data"},
}
parts, finished := extractContentRecursive(items, "text")
if finished {
t.Fatal("expected not finished")
}
if len(parts) != 0 {
t.Fatalf("expected no parts for skipped path, got %#v", parts)
}
}
func TestExtractContentRecursiveContentField(t *testing.T) {
items := []any{
map[string]any{"p": "x", "v": "val", "content": "actual content", "type": "RESPONSE"},
}
parts, _ := extractContentRecursive(items, "text")
if len(parts) != 1 || parts[0].Text != "actual content" || parts[0].Type != "text" {
t.Fatalf("unexpected parts: %#v", parts)
}
}
func TestExtractContentRecursiveContentFieldThinkType(t *testing.T) {
items := []any{
map[string]any{"p": "x", "v": "val", "content": "think text", "type": "THINK"},
}
parts, _ := extractContentRecursive(items, "text")
if len(parts) != 1 || parts[0].Type != "thinking" {
t.Fatalf("expected thinking type for THINK content, got %#v", parts)
}
}
func TestExtractContentRecursiveThinkingPath(t *testing.T) {
items := []any{
map[string]any{"p": "thinking_content", "v": "deep thought"},
}
parts, _ := extractContentRecursive(items, "text")
if len(parts) != 1 || parts[0].Type != "thinking" || parts[0].Text != "deep thought" {
t.Fatalf("unexpected parts for thinking path: %#v", parts)
}
}
func TestExtractContentRecursiveContentPath(t *testing.T) {
items := []any{
map[string]any{"p": "content", "v": "text content"},
}
parts, _ := extractContentRecursive(items, "thinking")
if len(parts) != 1 || parts[0].Type != "text" {
t.Fatalf("expected text type for content path, got %#v", parts)
}
}
func TestExtractContentRecursiveResponsePath(t *testing.T) {
items := []any{
map[string]any{"p": "response", "v": "text content"},
}
parts, _ := extractContentRecursive(items, "thinking")
if len(parts) != 1 || parts[0].Type != "text" {
t.Fatalf("expected text type for response path, got %#v", parts)
}
}
func TestExtractContentRecursiveFragmentsPath(t *testing.T) {
items := []any{
map[string]any{"p": "fragments", "v": "fragment text"},
}
parts, _ := extractContentRecursive(items, "thinking")
if len(parts) != 1 || parts[0].Type != "text" {
t.Fatalf("expected text type for fragments path, got %#v", parts)
}
}
func TestExtractContentRecursiveNestedArrayWithTypes(t *testing.T) {
items := []any{
map[string]any{
"p": "fragments",
"v": []any{
map[string]any{"content": "thought", "type": "THINKING"},
map[string]any{"content": "answer", "type": "RESPONSE"},
"raw string",
},
},
}
parts, _ := extractContentRecursive(items, "text")
if len(parts) != 3 {
t.Fatalf("expected 3 parts, got %d: %#v", len(parts), parts)
}
if parts[0].Type != "thinking" || parts[0].Text != "thought" {
t.Fatalf("first part mismatch: %#v", parts[0])
}
if parts[1].Type != "text" || parts[1].Text != "answer" {
t.Fatalf("second part mismatch: %#v", parts[1])
}
if parts[2].Type != "text" || parts[2].Text != "raw string" {
t.Fatalf("third part mismatch: %#v", parts[2])
}
}
func TestExtractContentRecursiveEmptyContentSkipped(t *testing.T) {
items := []any{
map[string]any{
"p": "fragments",
"v": []any{
map[string]any{"content": "", "type": "RESPONSE"},
},
},
}
parts, _ := extractContentRecursive(items, "text")
if len(parts) != 0 {
t.Fatalf("expected no parts for empty nested content, got %#v", parts)
}
}
func TestExtractContentRecursiveFinishedString(t *testing.T) {
items := []any{
map[string]any{"p": "content", "v": "FINISHED"},
}
parts, _ := extractContentRecursive(items, "text")
// "FINISHED" string value on non-status path should be skipped
if len(parts) != 0 {
t.Fatalf("expected FINISHED string to be skipped, got %#v", parts)
}
}
func TestExtractContentRecursiveNoVField(t *testing.T) {
items := []any{
map[string]any{"p": "content"},
}
parts, _ := extractContentRecursive(items, "text")
if len(parts) != 0 {
t.Fatalf("expected no parts for missing v field, got %#v", parts)
}
}
func TestExtractContentRecursiveNonMapItem(t *testing.T) {
items := []any{"just a string", 42}
parts, _ := extractContentRecursive(items, "text")
if len(parts) != 0 {
t.Fatalf("expected no parts for non-map items, got %#v", parts)
}
}

View File

@@ -0,0 +1,177 @@
package sse
import (
"context"
"io"
"strings"
"testing"
)
func TestStartParsedLinePumpEmptyBody(t *testing.T) {
body := strings.NewReader("")
results, done := StartParsedLinePump(context.Background(), body, false, "text")
collected := make([]LineResult, 0)
for r := range results {
collected = append(collected, r)
}
if err := <-done; err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(collected) != 0 {
t.Fatalf("expected no results for empty body, got %d", len(collected))
}
}
func TestStartParsedLinePumpMultipleLines(t *testing.T) {
body := strings.NewReader(
"data: {\"p\":\"response/thinking_content\",\"v\":\"think\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\"text\"}\n" +
"data: [DONE]\n",
)
results, done := StartParsedLinePump(context.Background(), body, true, "thinking")
collected := make([]LineResult, 0)
for r := range results {
collected = append(collected, r)
}
if err := <-done; err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(collected) < 3 {
t.Fatalf("expected at least 3 results, got %d", len(collected))
}
// First should be thinking
if collected[0].Parts[0].Type != "thinking" {
t.Fatalf("expected first part thinking, got %q", collected[0].Parts[0].Type)
}
// Last should be stop
last := collected[len(collected)-1]
if !last.Stop {
t.Fatal("expected last result to be stop")
}
}
func TestStartParsedLinePumpTypeTracking(t *testing.T) {
body := strings.NewReader(
"data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"思\"}]}\n" +
"data: {\"p\":\"response/fragments/-1/content\",\"v\":\"考\"}\n" +
"data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"答\"}]}\n" +
"data: {\"p\":\"response/fragments/-1/content\",\"v\":\"案\"}\n" +
"data: [DONE]\n",
)
results, done := StartParsedLinePump(context.Background(), body, true, "text")
types := make([]string, 0)
for r := range results {
for _, p := range r.Parts {
types = append(types, p.Type)
}
}
<-done
// Should have: thinking, thinking, text, text
expected := []string{"thinking", "thinking", "text", "text"}
if len(types) != len(expected) {
t.Fatalf("expected types %v, got %v", expected, types)
}
for i, want := range expected {
if types[i] != want {
t.Fatalf("type[%d] mismatch: want %q got %q (all=%v)", i, want, types[i], types)
}
}
}
func TestStartParsedLinePumpContextCancellation(t *testing.T) {
pr, pw := io.Pipe()
ctx, cancel := context.WithCancel(context.Background())
results, done := StartParsedLinePump(ctx, pr, false, "text")
// Write one line to allow it to start
go func() {
_, _ = io.WriteString(pw, "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n")
// Don't close yet - wait for context cancel
}()
// Read first result
r := <-results
if !r.Parsed || len(r.Parts) == 0 {
t.Fatalf("expected first parsed result, got %#v", r)
}
// Cancel context - this will cause the pump to exit on next send
cancel()
// Close the pipe to unblock scanner.Scan()
pw.Close()
// Drain remaining results
for range results {
}
err := <-done
// Error may be context.Canceled or nil (if pipe closed first)
if err != nil && err != context.Canceled {
t.Fatalf("expected context.Canceled or nil error, got %v", err)
}
}
func TestStartParsedLinePumpOnlyDONE(t *testing.T) {
body := strings.NewReader("data: [DONE]\n")
results, done := StartParsedLinePump(context.Background(), body, false, "text")
collected := make([]LineResult, 0)
for r := range results {
collected = append(collected, r)
}
<-done
if len(collected) != 1 {
t.Fatalf("expected 1 result, got %d", len(collected))
}
if !collected[0].Stop {
t.Fatal("expected stop on [DONE]")
}
}
func TestStartParsedLinePumpNonSSELines(t *testing.T) {
body := strings.NewReader(
"event: update\n" +
": comment line\n" +
"data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" +
"data: [DONE]\n",
)
results, done := StartParsedLinePump(context.Background(), body, false, "text")
var validCount int
for r := range results {
if r.Parsed && len(r.Parts) > 0 {
validCount++
}
}
<-done
if validCount != 1 {
t.Fatalf("expected 1 valid result, got %d", validCount)
}
}
func TestStartParsedLinePumpThinkingDisabled(t *testing.T) {
body := strings.NewReader(
"data: {\"p\":\"response/thinking_content\",\"v\":\"thought\"}\n" +
"data: {\"p\":\"response/content\",\"v\":\"response\"}\n" +
"data: [DONE]\n",
)
// With thinking disabled, thinking content should still be emitted but marked differently
results, done := StartParsedLinePump(context.Background(), body, false, "text")
var parts []ContentPart
for r := range results {
parts = append(parts, r.Parts...)
}
<-done
if len(parts) < 1 {
t.Fatalf("expected at least 1 part, got %d", len(parts))
}
}

View File

@@ -0,0 +1,441 @@
package util
import (
"encoding/json"
"net/http/httptest"
"strings"
"testing"
"ds2api/internal/config"
)
// ─── EstimateTokens edge cases ───────────────────────────────────────
func TestEstimateTokensEmpty(t *testing.T) {
if got := EstimateTokens(""); got != 0 {
t.Fatalf("expected 0 for empty string, got %d", got)
}
}
func TestEstimateTokensShortASCII(t *testing.T) {
got := EstimateTokens("ab")
if got != 1 {
t.Fatalf("expected 1 for 2 ascii chars, got %d", got)
}
}
func TestEstimateTokensLongASCII(t *testing.T) {
got := EstimateTokens(strings.Repeat("x", 100))
if got != 25 {
t.Fatalf("expected 25 for 100 ascii chars, got %d", got)
}
}
func TestEstimateTokensChinese(t *testing.T) {
got := EstimateTokens("你好世界")
if got < 1 {
t.Fatalf("expected at least 1 token for Chinese text, got %d", got)
}
}
func TestEstimateTokensMixed(t *testing.T) {
got := EstimateTokens("Hello 你好世界")
if got < 2 {
t.Fatalf("expected at least 2 tokens for mixed text, got %d", got)
}
}
func TestEstimateTokensSingleByte(t *testing.T) {
got := EstimateTokens("x")
if got != 1 {
t.Fatalf("expected 1 for single char (minimum), got %d", got)
}
}
func TestEstimateTokensSingleChinese(t *testing.T) {
got := EstimateTokens("你")
if got != 1 {
t.Fatalf("expected 1 for single Chinese char, got %d", got)
}
}
// ─── ToBool edge cases ───────────────────────────────────────────────
func TestToBoolTrue(t *testing.T) {
if !ToBool(true) {
t.Fatal("expected true")
}
}
func TestToBoolFalse(t *testing.T) {
if ToBool(false) {
t.Fatal("expected false")
}
}
func TestToBoolNonBool(t *testing.T) {
if ToBool("true") {
t.Fatal("expected false for string 'true'")
}
if ToBool(1) {
t.Fatal("expected false for int 1")
}
if ToBool(nil) {
t.Fatal("expected false for nil")
}
}
// ─── IntFrom edge cases ─────────────────────────────────────────────
func TestIntFromFloat64(t *testing.T) {
if got := IntFrom(float64(42.5)); got != 42 {
t.Fatalf("expected 42 for float64(42.5), got %d", got)
}
}
func TestIntFromInt(t *testing.T) {
if got := IntFrom(int(42)); got != 42 {
t.Fatalf("expected 42, got %d", got)
}
}
func TestIntFromInt64(t *testing.T) {
if got := IntFrom(int64(42)); got != 42 {
t.Fatalf("expected 42, got %d", got)
}
}
func TestIntFromString(t *testing.T) {
if got := IntFrom("42"); got != 0 {
t.Fatalf("expected 0 for string, got %d", got)
}
}
func TestIntFromNil(t *testing.T) {
if got := IntFrom(nil); got != 0 {
t.Fatalf("expected 0 for nil, got %d", got)
}
}
// ─── WriteJSON ───────────────────────────────────────────────────────
func TestWriteJSON(t *testing.T) {
rec := httptest.NewRecorder()
WriteJSON(rec, 200, map[string]any{"key": "value"})
if rec.Code != 200 {
t.Fatalf("expected 200, got %d", rec.Code)
}
if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
t.Fatalf("expected application/json content type, got %q", ct)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("decode error: %v", err)
}
if body["key"] != "value" {
t.Fatalf("unexpected body: %#v", body)
}
}
func TestWriteJSONStatusCodes(t *testing.T) {
for _, code := range []int{200, 201, 400, 404, 500} {
rec := httptest.NewRecorder()
WriteJSON(rec, code, map[string]any{"status": code})
if rec.Code != code {
t.Fatalf("expected %d, got %d", code, rec.Code)
}
}
}
// ─── MessagesPrepare edge cases ──────────────────────────────────────
func TestMessagesPrepareEmpty(t *testing.T) {
got := MessagesPrepare(nil)
if got != "" {
t.Fatalf("expected empty for nil messages, got %q", got)
}
}
func TestMessagesPrepareMergesConsecutiveSameRole(t *testing.T) {
messages := []map[string]any{
{"role": "user", "content": "Hello"},
{"role": "user", "content": "World"},
}
got := MessagesPrepare(messages)
if !strings.Contains(got, "Hello") || !strings.Contains(got, "World") {
t.Fatalf("expected both messages, got %q", got)
}
// Should be merged without <User> between them
count := strings.Count(got, "<User>")
if count != 0 {
t.Fatalf("expected no User marker for first message pair, got %d occurrences", count)
}
}
func TestMessagesPrepareAssistantMarkers(t *testing.T) {
messages := []map[string]any{
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello!"},
}
got := MessagesPrepare(messages)
if !strings.Contains(got, "<Assistant>") {
t.Fatalf("expected assistant marker, got %q", got)
}
if !strings.Contains(got, "<end▁of▁sentence>") {
t.Fatalf("expected end of sentence marker, got %q", got)
}
}
func TestMessagesPrepareUnknownRole(t *testing.T) {
messages := []map[string]any{
{"role": "user", "content": "Hello"},
{"role": "unknown_role", "content": "Unknown"},
}
got := MessagesPrepare(messages)
if !strings.Contains(got, "Unknown") {
t.Fatalf("expected unknown role content, got %q", got)
}
}
func TestMessagesPrepareMarkdownImageReplaced(t *testing.T) {
messages := []map[string]any{
{"role": "user", "content": "Look at this: ![alt](https://example.com/img.png)"},
}
got := MessagesPrepare(messages)
if strings.Contains(got, "![alt]") {
t.Fatalf("expected markdown image to be replaced, got %q", got)
}
}
func TestMessagesPrepareNilContent(t *testing.T) {
messages := []map[string]any{
{"role": "user", "content": nil},
}
got := MessagesPrepare(messages)
if got != "null" {
t.Logf("nil content handled as: %q", got)
}
}
// ─── normalizeContent edge cases ─────────────────────────────────────
func TestNormalizeContentString(t *testing.T) {
got := normalizeContent("hello")
if got != "hello" {
t.Fatalf("expected 'hello', got %q", got)
}
}
func TestNormalizeContentArray(t *testing.T) {
got := normalizeContent([]any{
map[string]any{"type": "text", "text": "line1"},
map[string]any{"type": "text", "text": "line2"},
})
if got != "line1\nline2" {
t.Fatalf("expected 'line1\\nline2', got %q", got)
}
}
func TestNormalizeContentArrayWithContentField(t *testing.T) {
got := normalizeContent([]any{
map[string]any{"type": "text", "content": "from-content"},
})
if got != "from-content" {
t.Fatalf("expected 'from-content', got %q", got)
}
}
func TestNormalizeContentArraySkipsImage(t *testing.T) {
got := normalizeContent([]any{
map[string]any{"type": "image_url", "image_url": "https://example.com/img.png"},
map[string]any{"type": "text", "text": "caption"},
})
if strings.Contains(got, "image") {
t.Fatalf("expected image skipped, got %q", got)
}
if got != "caption" {
t.Fatalf("expected 'caption', got %q", got)
}
}
func TestNormalizeContentArrayNonMapItems(t *testing.T) {
got := normalizeContent([]any{"string item", 42})
if got != "" {
t.Fatalf("expected empty for non-map items, got %q", got)
}
}
func TestNormalizeContentJSON(t *testing.T) {
got := normalizeContent(map[string]any{"key": "value"})
if !strings.Contains(got, `"key":"value"`) {
t.Fatalf("expected JSON serialized, got %q", got)
}
}
// ─── ConvertClaudeToDeepSeek edge cases ──────────────────────────────
func TestConvertClaudeToDeepSeekDefaultModel(t *testing.T) {
store := config.LoadStore()
req := map[string]any{
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
}
out := ConvertClaudeToDeepSeek(req, store)
if out["model"] == "" {
t.Fatal("expected default model")
}
}
func TestConvertClaudeToDeepSeekWithStopSequences(t *testing.T) {
store := config.LoadStore()
req := map[string]any{
"model": "claude-sonnet-4-5",
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
"stop_sequences": []any{"\n\n"},
}
out := ConvertClaudeToDeepSeek(req, store)
if out["stop"] == nil {
t.Fatal("expected stop field from stop_sequences")
}
}
func TestConvertClaudeToDeepSeekWithTemperature(t *testing.T) {
store := config.LoadStore()
req := map[string]any{
"model": "claude-sonnet-4-5",
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
"temperature": 0.7,
"top_p": 0.9,
}
out := ConvertClaudeToDeepSeek(req, store)
if out["temperature"] != 0.7 {
t.Fatalf("expected temperature 0.7, got %v", out["temperature"])
}
if out["top_p"] != 0.9 {
t.Fatalf("expected top_p 0.9, got %v", out["top_p"])
}
}
func TestConvertClaudeToDeepSeekNoSystem(t *testing.T) {
store := config.LoadStore()
req := map[string]any{
"model": "claude-sonnet-4-5",
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
}
out := ConvertClaudeToDeepSeek(req, store)
msgs, _ := out["messages"].([]any)
if len(msgs) != 1 {
t.Fatalf("expected 1 message without system, got %d", len(msgs))
}
}
func TestConvertClaudeToDeepSeekOpusUsesSlowMapping(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`)
store := config.LoadStore()
req := map[string]any{
"model": "claude-opus-4-6",
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
}
out := ConvertClaudeToDeepSeek(req, store)
if out["model"] != "deepseek-reasoner" {
t.Fatalf("expected opus to use slow mapping, got %q", out["model"])
}
}
// ─── FormatOpenAIStreamToolCalls ─────────────────────────────────────
func TestFormatOpenAIStreamToolCalls(t *testing.T) {
formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{
{Name: "search", Input: map[string]any{"q": "test"}},
})
if len(formatted) != 1 {
t.Fatalf("expected 1, got %d", len(formatted))
}
fn, _ := formatted[0]["function"].(map[string]any)
if fn["name"] != "search" {
t.Fatalf("unexpected function name: %#v", fn)
}
if formatted[0]["index"] != 0 {
t.Fatalf("expected index 0, got %v", formatted[0]["index"])
}
}
// ─── ParseToolCalls more edge cases ──────────────────────────────────
func TestParseToolCallsNoToolNames(t *testing.T) {
text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
calls := ParseToolCalls(text, nil)
if len(calls) != 1 {
t.Fatalf("expected 1 call with nil tool names, got %d", len(calls))
}
}
func TestParseToolCallsEmptyText(t *testing.T) {
calls := ParseToolCalls("", []string{"search"})
if len(calls) != 0 {
t.Fatalf("expected 0 calls for empty text, got %d", len(calls))
}
}
func TestParseToolCallsMultipleTools(t *testing.T) {
text := `{"tool_calls":[{"name":"search","input":{"q":"go"}},{"name":"get_weather","input":{"city":"beijing"}}]}`
calls := ParseToolCalls(text, []string{"search", "get_weather"})
if len(calls) != 2 {
t.Fatalf("expected 2 calls, got %d", len(calls))
}
}
func TestParseToolCallsInputAsString(t *testing.T) {
text := `{"tool_calls":[{"name":"search","input":"{\"q\":\"golang\"}"}]}`
calls := ParseToolCalls(text, []string{"search"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if calls[0].Input["q"] != "golang" {
t.Fatalf("expected parsed string input, got %#v", calls[0].Input)
}
}
func TestParseToolCallsWithFunctionWrapper(t *testing.T) {
text := `{"tool_calls":[{"function":{"name":"calc","arguments":{"x":1,"y":2}}}]}`
calls := ParseToolCalls(text, []string{"calc"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if calls[0].Name != "calc" {
t.Fatalf("expected calc, got %q", calls[0].Name)
}
}
func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) {
fenced := "Here's an example:\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\nDon't execute this."
calls := ParseStandaloneToolCalls(fenced, []string{"search"})
if len(calls) != 0 {
t.Fatalf("expected fenced code block ignored, got %d calls", len(calls))
}
}
// ─── looksLikeToolExampleContext ─────────────────────────────────────
func TestLooksLikeToolExampleContextChinese(t *testing.T) {
if !looksLikeToolExampleContext("下面是示例") {
t.Fatal("expected true for Chinese example context")
}
}
func TestLooksLikeToolExampleContextEnglish(t *testing.T) {
if !looksLikeToolExampleContext("here is an example of") {
t.Fatal("expected true for English example context")
}
}
func TestLooksLikeToolExampleContextNone(t *testing.T) {
if looksLikeToolExampleContext("I will call the tool now") {
t.Fatal("expected false for non-example context")
}
}
func TestLooksLikeToolExampleContextFenced(t *testing.T) {
if !looksLikeToolExampleContext("```json") {
t.Fatal("expected true for fenced code block context")
}
}