From f2b10992cc0ba721e40edcd58b43c57c4822ceff Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 16:52:16 +0800 Subject: [PATCH] test: Introduce comprehensive edge case tests for various internal packages including SSE, Claude, Auth, Account, Config, Deepseek, Admin, and Util. --- .gitignore | 3 + internal/account/pool_edge_test.go | 249 +++++++ internal/adapter/claude/handler_util_test.go | 348 ++++++++++ internal/adapter/openai/handler.go | 4 +- .../adapter/openai/handler_toolcall_test.go | 10 +- internal/admin/helpers_edge_test.go | 240 +++++++ internal/auth/auth_edge_test.go | 375 +++++++++++ internal/config/config_edge_test.go | 445 ++++++++++++ internal/deepseek/deepseek_edge_test.go | 165 +++++ internal/sse/consumer_edge_test.go | 140 ++++ internal/sse/line_edge_test.go | 70 ++ internal/sse/parser_edge_test.go | 631 ++++++++++++++++++ internal/sse/stream_edge_test.go | 177 +++++ internal/util/util_edge_test.go | 441 ++++++++++++ 14 files changed, 3291 insertions(+), 7 deletions(-) create mode 100644 internal/account/pool_edge_test.go create mode 100644 internal/adapter/claude/handler_util_test.go create mode 100644 internal/admin/helpers_edge_test.go create mode 100644 internal/auth/auth_edge_test.go create mode 100644 internal/config/config_edge_test.go create mode 100644 internal/deepseek/deepseek_edge_test.go create mode 100644 internal/sse/consumer_edge_test.go create mode 100644 internal/sse/line_edge_test.go create mode 100644 internal/sse/parser_edge_test.go create mode 100644 internal/sse/stream_edge_test.go create mode 100644 internal/util/util_edge_test.go diff --git a/.gitignore b/.gitignore index 5f776e2..422c203 100644 --- a/.gitignore +++ b/.gitignore @@ -81,6 +81,9 @@ ds2api-tests htmlcov/ .pytest_cache/ .tox/ +*.coverprofile +coverage*.out +cover/ # Misc *.pyc diff --git a/internal/account/pool_edge_test.go b/internal/account/pool_edge_test.go new file mode 100644 index 0000000..6e90823 --- /dev/null +++ b/internal/account/pool_edge_test.go @@ -0,0 +1,249 @@ +package account + +import ( + "context" + "sync" + "testing" + "time" + + "ds2api/internal/config" +) + +// ─── Pool edge cases ───────────────────────────────────────────────── + +func TestPoolEmptyNoAccounts(t *testing.T) { + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "2") + t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "") + t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "") + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + pool := NewPool(config.LoadStore()) + if _, ok := pool.Acquire("", nil); ok { + t.Fatal("expected acquire to fail with no accounts") + } + status := pool.Status() + if total, ok := status["total"].(int); !ok || total != 0 { + t.Fatalf("unexpected total: %#v", status["total"]) + } +} + +func TestPoolReleaseNonExistentAccount(t *testing.T) { + pool := newPoolForTest(t, "2") + pool.Release("nonexistent@example.com") // should not panic +} + +func TestPoolReleaseAlreadyReleased(t *testing.T) { + pool := newPoolForTest(t, "2") + acc, ok := pool.Acquire("", nil) + if !ok { + t.Fatal("expected acquire success") + } + pool.Release(acc.Identifier()) + pool.Release(acc.Identifier()) // double release should not panic +} + +func TestPoolAcquireTargetNotFound(t *testing.T) { + pool := newPoolForTest(t, "2") + if _, ok := pool.Acquire("nonexistent@example.com", nil); ok { + t.Fatal("expected acquire to fail for non-existent target") + } +} + +func TestPoolAcquireWithExclusionList(t *testing.T) { + pool := newPoolForTest(t, "2") + acc, ok := pool.Acquire("", map[string]bool{"acc1@example.com": true}) + if !ok { + t.Fatal("expected acquire success with exclusion") + } + if acc.Identifier() != "acc2@example.com" { + t.Fatalf("expected acc2 when acc1 excluded, got %q", acc.Identifier()) + } + pool.Release(acc.Identifier()) +} + +func TestPoolAcquireAllExcluded(t *testing.T) { + pool := newPoolForTest(t, "2") + if _, ok := pool.Acquire("", map[string]bool{ + "acc1@example.com": true, + "acc2@example.com": true, + }); ok { + t.Fatal("expected acquire to fail when all accounts excluded") + } +} + +func TestPoolStatusFields(t *testing.T) { + pool := newPoolForTest(t, "2") + status := pool.Status() + + // Check all expected fields are present + for _, key := range []string{"total", "available", "max_inflight_per_account", "recommended_concurrency", "available_accounts", "in_use_accounts", "waiting", "max_queue_size"} { + if _, ok := status[key]; !ok { + t.Fatalf("missing status field: %s", key) + } + } +} + +func TestPoolStatusAccountDetails(t *testing.T) { + pool := newPoolForTest(t, "2") + acc, _ := pool.Acquire("acc1@example.com", nil) + + status := pool.Status() + inUseAccounts, ok := status["in_use_accounts"].([]string) + if !ok { + t.Fatalf("unexpected in_use_accounts type: %T", status["in_use_accounts"]) + } + found := false + for _, id := range inUseAccounts { + if id == "acc1@example.com" { + found = true + break + } + } + if !found { + t.Fatalf("expected acc1 in in_use_accounts, got %v", inUseAccounts) + } + if status["in_use"] != 1 { + t.Fatalf("expected 1 in_use, got %v", status["in_use"]) + } + + pool.Release(acc.Identifier()) +} + +func TestPoolAcquireWaitContextCancelled(t *testing.T) { + pool := newSingleAccountPoolForTest(t, "1") + // Exhaust the pool + first, ok := pool.Acquire("", nil) + if !ok { + t.Fatal("expected first acquire to succeed") + } + + ctx, cancel := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + wg.Add(1) + var waitOK bool + go func() { + defer wg.Done() + _, waitOK = pool.AcquireWait(ctx, "", nil) + }() + + // Wait until queued + waitForWaitingCount(t, pool, 1) + + // Cancel context + cancel() + + wg.Wait() + if waitOK { + t.Fatal("expected acquire to fail after context cancellation") + } + + pool.Release(first.Identifier()) +} + +func TestPoolAcquireWaitTargetAccount(t *testing.T) { + pool := newPoolForTest(t, "1") + // Exhaust acc1 + acc1, ok := pool.Acquire("acc1@example.com", nil) + if !ok { + t.Fatal("expected acquire acc1 success") + } + + // Acquire acc2 directly (should succeed since acc2 is free) + ctx := context.Background() + acc2, ok := pool.AcquireWait(ctx, "acc2@example.com", nil) + if !ok { + t.Fatal("expected acquire acc2 success via AcquireWait") + } + if acc2.Identifier() != "acc2@example.com" { + t.Fatalf("expected acc2, got %q", acc2.Identifier()) + } + + pool.Release(acc1.Identifier()) + pool.Release(acc2.Identifier()) +} + +func TestPoolMaxQueueSizeOverride(t *testing.T) { + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1") + t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "5") + t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "") + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`) + pool := NewPool(config.LoadStore()) + status := pool.Status() + if got, ok := status["max_queue_size"].(int); !ok || got != 5 { + t.Fatalf("expected max_queue_size=5, got %#v", status["max_queue_size"]) + } +} + +func TestPoolQueueSizeAliasEnv(t *testing.T) { + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1") + t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "") + t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "7") + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`) + pool := NewPool(config.LoadStore()) + status := pool.Status() + if got, ok := status["max_queue_size"].(int); !ok || got != 7 { + t.Fatalf("expected max_queue_size=7, got %#v", status["max_queue_size"]) + } +} + +func TestPoolMultipleAcquireReleaseCycles(t *testing.T) { + pool := newSingleAccountPoolForTest(t, "1") + for i := 0; i < 10; i++ { + acc, ok := pool.Acquire("", nil) + if !ok { + t.Fatalf("acquire failed at cycle %d", i) + } + pool.Release(acc.Identifier()) + } +} + +func TestPoolConcurrentAcquireWait(t *testing.T) { + pool := newSingleAccountPoolForTest(t, "1") + first, ok := pool.Acquire("", nil) + if !ok { + t.Fatal("expected first acquire success") + } + + const waiters = 3 + results := make(chan bool, waiters) + + for i := 0; i < waiters; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, ok := pool.AcquireWait(ctx, "", nil) + results <- ok + }() + } + + // Wait for all to be queued (only 1 can queue) + time.Sleep(50 * time.Millisecond) + + // Release and allow queued requests to proceed + pool.Release(first.Identifier()) + + successCount := 0 + timeoutCount := 0 + for i := 0; i < waiters; i++ { + select { + case ok := <-results: + if ok { + successCount++ + // Release for next waiter + pool.Release("acc1@example.com") + } else { + timeoutCount++ + } + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for results") + } + } + + // At least 1 should succeed; 2 may fail due to queue limit + if successCount < 1 { + t.Fatalf("expected at least 1 success, got success=%d timeout=%d", successCount, timeoutCount) + } +} diff --git a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go new file mode 100644 index 0000000..73d2fab --- /dev/null +++ b/internal/adapter/claude/handler_util_test.go @@ -0,0 +1,348 @@ +package claude + +import ( + "testing" +) + +// ─── normalizeClaudeMessages ───────────────────────────────────────── + +func TestNormalizeClaudeMessagesSimpleString(t *testing.T) { + msgs := []any{ + map[string]any{"role": "user", "content": "Hello"}, + } + got := normalizeClaudeMessages(msgs) + if len(got) != 1 { + t.Fatalf("expected 1 message, got %d", len(got)) + } + m := got[0].(map[string]any) + if m["content"] != "Hello" { + t.Fatalf("expected 'Hello', got %v", m["content"]) + } +} + +func TestNormalizeClaudeMessagesArrayContent(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "line1"}, + map[string]any{"type": "text", "text": "line2"}, + }, + }, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["content"] != "line1\nline2" { + t.Fatalf("expected joined text, got %q", m["content"]) + } +} + +func TestNormalizeClaudeMessagesToolResult(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "tool_result", "content": "tool output"}, + }, + }, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["content"] != "tool output" { + t.Fatalf("expected 'tool output', got %q", m["content"]) + } +} + +func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) { + msgs := []any{"not a map", 42} + got := normalizeClaudeMessages(msgs) + if len(got) != 0 { + t.Fatalf("expected 0 messages for non-map items, got %d", len(got)) + } +} + +func TestNormalizeClaudeMessagesEmpty(t *testing.T) { + got := normalizeClaudeMessages(nil) + if len(got) != 0 { + t.Fatalf("expected 0, got %d", len(got)) + } +} + +func TestNormalizeClaudeMessagesPreservesRole(t *testing.T) { + msgs := []any{ + map[string]any{"role": "assistant", "content": "response"}, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["role"] != "assistant" { + t.Fatalf("expected 'assistant', got %q", m["role"]) + } +} + +func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "Hello"}, + map[string]any{"type": "image", "source": "data:..."}, + map[string]any{"type": "text", "text": "World"}, + }, + }, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["content"] != "Hello\nWorld" { + t.Fatalf("expected only text parts joined, got %q", m["content"]) + } +} + +// ─── buildClaudeToolPrompt ─────────────────────────────────────────── + +func TestBuildClaudeToolPromptSingleTool(t *testing.T) { + tools := []any{ + map[string]any{ + "name": "search", + "description": "Search the web", + "input_schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + }, + }, + }, + } + prompt := buildClaudeToolPrompt(tools) + if prompt == "" { + t.Fatal("expected non-empty prompt") + } + // Should contain tool name and description + if !containsStr(prompt, "search") { + t.Fatalf("expected 'search' in prompt") + } + if !containsStr(prompt, "Search the web") { + t.Fatalf("expected description in prompt") + } + if !containsStr(prompt, "tool_calls") { + t.Fatalf("expected tool_calls instruction in prompt") + } +} + +func TestBuildClaudeToolPromptMultipleTools(t *testing.T) { + tools := []any{ + map[string]any{"name": "tool1", "description": "desc1"}, + map[string]any{"name": "tool2", "description": "desc2"}, + } + prompt := buildClaudeToolPrompt(tools) + if !containsStr(prompt, "tool1") || !containsStr(prompt, "tool2") { + t.Fatalf("expected both tools in prompt") + } +} + +func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) { + tools := []any{"not a map"} + prompt := buildClaudeToolPrompt(tools) + if prompt == "" { + t.Fatal("expected non-empty prompt even with invalid tools") + } + // Should still contain the intro and instruction + if !containsStr(prompt, "You are Claude") { + t.Fatalf("expected intro in prompt") + } +} + +// ─── hasSystemMessage ──────────────────────────────────────────────── + +func TestHasSystemMessageTrue(t *testing.T) { + msgs := []any{ + map[string]any{"role": "system", "content": "You are a helper"}, + map[string]any{"role": "user", "content": "Hi"}, + } + if !hasSystemMessage(msgs) { + t.Fatal("expected true") + } +} + +func TestHasSystemMessageFalse(t *testing.T) { + msgs := []any{ + map[string]any{"role": "user", "content": "Hi"}, + map[string]any{"role": "assistant", "content": "Hello"}, + } + if hasSystemMessage(msgs) { + t.Fatal("expected false") + } +} + +func TestHasSystemMessageEmpty(t *testing.T) { + if hasSystemMessage(nil) { + t.Fatal("expected false for nil") + } +} + +func TestHasSystemMessageNonMap(t *testing.T) { + msgs := []any{"not a map"} + if hasSystemMessage(msgs) { + t.Fatal("expected false for non-map") + } +} + +// ─── extractClaudeToolNames ────────────────────────────────────────── + +func TestExtractClaudeToolNamesSingle(t *testing.T) { + tools := []any{ + map[string]any{"name": "search"}, + } + names := extractClaudeToolNames(tools) + if len(names) != 1 || names[0] != "search" { + t.Fatalf("expected [search], got %v", names) + } +} + +func TestExtractClaudeToolNamesMultiple(t *testing.T) { + tools := []any{ + map[string]any{"name": "search"}, + map[string]any{"name": "calculate"}, + } + names := extractClaudeToolNames(tools) + if len(names) != 2 { + t.Fatalf("expected 2 names, got %v", names) + } +} + +func TestExtractClaudeToolNamesSkipsEmptyName(t *testing.T) { + tools := []any{ + map[string]any{"name": ""}, + map[string]any{"name": "valid"}, + } + names := extractClaudeToolNames(tools) + if len(names) != 1 || names[0] != "valid" { + t.Fatalf("expected [valid], got %v", names) + } +} + +func TestExtractClaudeToolNamesSkipsNonMap(t *testing.T) { + tools := []any{"not a map", 42} + names := extractClaudeToolNames(tools) + if len(names) != 0 { + t.Fatalf("expected 0, got %v", names) + } +} + +func TestExtractClaudeToolNamesNil(t *testing.T) { + names := extractClaudeToolNames(nil) + if len(names) != 0 { + t.Fatalf("expected 0, got %v", names) + } +} + +// ─── toMessageMaps ─────────────────────────────────────────────────── + +func TestToMessageMapsNormal(t *testing.T) { + input := []any{ + map[string]any{"role": "user", "content": "Hello"}, + } + got := toMessageMaps(input) + if len(got) != 1 { + t.Fatalf("expected 1, got %d", len(got)) + } +} + +func TestToMessageMapsNonSlice(t *testing.T) { + got := toMessageMaps("not a slice") + if got != nil { + t.Fatalf("expected nil, got %v", got) + } +} + +func TestToMessageMapsSkipsNonMap(t *testing.T) { + input := []any{"string", map[string]any{"role": "user"}, 42} + got := toMessageMaps(input) + if len(got) != 1 { + t.Fatalf("expected 1 map, got %d", len(got)) + } +} + +func TestToMessageMapsNil(t *testing.T) { + got := toMessageMaps(nil) + if got != nil { + t.Fatalf("expected nil, got %v", got) + } +} + +// ─── extractMessageContent ────────────────────────────────────────── + +func TestExtractMessageContentString(t *testing.T) { + if got := extractMessageContent("hello"); got != "hello" { + t.Fatalf("expected 'hello', got %q", got) + } +} + +func TestExtractMessageContentArray(t *testing.T) { + input := []any{"part1", "part2"} + got := extractMessageContent(input) + if got != "part1\npart2" { + t.Fatalf("expected joined, got %q", got) + } +} + +func TestExtractMessageContentOther(t *testing.T) { + got := extractMessageContent(42) + if got != "42" { + t.Fatalf("expected '42', got %q", got) + } +} + +func TestExtractMessageContentNil(t *testing.T) { + got := extractMessageContent(nil) + if got != "" { + t.Fatalf("expected '', got %q", got) + } +} + +// ─── cloneMap ──────────────────────────────────────────────────────── + +func TestCloneMapBasic(t *testing.T) { + original := map[string]any{"a": 1, "b": "hello"} + clone := cloneMap(original) + original["a"] = 999 + if clone["a"] != 1 { + t.Fatalf("expected 1, got %v", clone["a"]) + } + if clone["b"] != "hello" { + t.Fatalf("expected 'hello', got %v", clone["b"]) + } +} + +func TestCloneMapEmpty(t *testing.T) { + clone := cloneMap(map[string]any{}) + if len(clone) != 0 { + t.Fatalf("expected empty, got %v", clone) + } +} + +func TestCloneMapNested(t *testing.T) { + // cloneMap is shallow, so nested maps share references + inner := map[string]any{"key": "value"} + original := map[string]any{"nested": inner} + clone := cloneMap(original) + // Shallow clone means inner is shared + inner["key"] = "modified" + cloneNested := clone["nested"].(map[string]any) + if cloneNested["key"] != "modified" { + t.Fatal("expected shallow clone to share nested references") + } +} + +// helper +func containsStr(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(s) > 0 && findSubstring(s, sub)) +} + +func findSubstring(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index 962e450..4de28b7 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -120,10 +120,10 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) return } - h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, toolNames) } -func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { +func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { if resp.StatusCode != http.StatusOK { defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index 8c1435d..3cab68c 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -128,7 +128,7 @@ func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -161,7 +161,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -189,7 +189,7 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -220,7 +220,7 @@ func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -249,7 +249,7 @@ func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } diff --git a/internal/admin/helpers_edge_test.go b/internal/admin/helpers_edge_test.go new file mode 100644 index 0000000..2a0bf20 --- /dev/null +++ b/internal/admin/helpers_edge_test.go @@ -0,0 +1,240 @@ +package admin + +import ( + "net/http" + "net/http/httptest" + "testing" + + "ds2api/internal/config" +) + +// ─── reverseAccounts ───────────────────────────────────────────────── + +func TestReverseAccountsEmpty(t *testing.T) { + a := []config.Account{} + reverseAccounts(a) + if len(a) != 0 { + t.Fatal("expected empty") + } +} + +func TestReverseAccountsTwoElements(t *testing.T) { + a := []config.Account{ + {Email: "a@test.com"}, + {Email: "b@test.com"}, + } + reverseAccounts(a) + if a[0].Email != "b@test.com" || a[1].Email != "a@test.com" { + t.Fatalf("unexpected order after reverse: %v", a) + } +} + +func TestReverseAccountsThreeElements(t *testing.T) { + a := []config.Account{ + {Email: "1@test.com"}, + {Email: "2@test.com"}, + {Email: "3@test.com"}, + } + reverseAccounts(a) + if a[0].Email != "3@test.com" || a[1].Email != "2@test.com" || a[2].Email != "1@test.com" { + t.Fatalf("unexpected order: %v", a) + } +} + +// ─── intFromQuery edge cases ───────────────────────────────────────── + +func TestIntFromQueryPresent(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=5", nil) + if got := intFromQuery(req, "limit", 10); got != 5 { + t.Fatalf("expected 5, got %d", got) + } +} + +func TestIntFromQueryMissing(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + if got := intFromQuery(req, "limit", 10); got != 10 { + t.Fatalf("expected default 10, got %d", got) + } +} + +func TestIntFromQueryInvalid(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=abc", nil) + if got := intFromQuery(req, "limit", 10); got != 10 { + t.Fatalf("expected default 10 for invalid, got %d", got) + } +} + +func TestIntFromQueryNegative(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=-3", nil) + if got := intFromQuery(req, "limit", 10); got != -3 { + t.Fatalf("expected -3, got %d", got) + } +} + +func TestIntFromQueryZero(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=0", nil) + if got := intFromQuery(req, "limit", 10); got != 0 { + t.Fatalf("expected 0, got %d", got) + } +} + +// ─── nilIfEmpty ────────────────────────────────────────────────────── + +func TestNilIfEmptyEmpty(t *testing.T) { + if nilIfEmpty("") != nil { + t.Fatal("expected nil for empty string") + } +} + +func TestNilIfEmptyNonEmpty(t *testing.T) { + if nilIfEmpty("hello") != "hello" { + t.Fatal("expected 'hello'") + } +} + +// ─── nilIfZero ─────────────────────────────────────────────────────── + +func TestNilIfZeroZero(t *testing.T) { + if nilIfZero(0) != nil { + t.Fatal("expected nil for zero") + } +} + +func TestNilIfZeroNonZero(t *testing.T) { + if nilIfZero(42) != int64(42) { + t.Fatal("expected 42") + } +} + +func TestNilIfZeroNegative(t *testing.T) { + if nilIfZero(-1) != int64(-1) { + t.Fatal("expected -1") + } +} + +// ─── toStringSlice ─────────────────────────────────────────────────── + +func TestToStringSliceFromAnySlice(t *testing.T) { + input := []any{"a", "b", "c"} + got, ok := toStringSlice(input) + if !ok || len(got) != 3 { + t.Fatalf("expected 3 strings, got %#v ok=%v", got, ok) + } + if got[0] != "a" || got[1] != "b" || got[2] != "c" { + t.Fatalf("unexpected values: %#v", got) + } +} + +func TestToStringSliceFromMixed(t *testing.T) { + input := []any{"hello", 42, true} + got, ok := toStringSlice(input) + if !ok { + t.Fatal("expected ok for mixed types") + } + if got[0] != "hello" || got[1] != "42" || got[2] != "true" { + t.Fatalf("unexpected values: %#v", got) + } +} + +func TestToStringSliceFromNonSlice(t *testing.T) { + _, ok := toStringSlice("not a slice") + if ok { + t.Fatal("expected not ok for string input") + } +} + +func TestToStringSliceFromNil(t *testing.T) { + _, ok := toStringSlice(nil) + if ok { + t.Fatal("expected not ok for nil input") + } +} + +func TestToStringSliceEmpty(t *testing.T) { + got, ok := toStringSlice([]any{}) + if !ok { + t.Fatal("expected ok for empty slice") + } + if len(got) != 0 { + t.Fatalf("expected empty result, got %#v", got) + } +} + +func TestToStringSliceTrimsWhitespace(t *testing.T) { + got, ok := toStringSlice([]any{" hello ", " world "}) + if !ok { + t.Fatal("expected ok") + } + if got[0] != "hello" || got[1] != "world" { + t.Fatalf("expected trimmed values, got %#v", got) + } +} + +// ─── toAccount edge cases ──────────────────────────────────────────── + +func TestToAccountAllFields(t *testing.T) { + acc := toAccount(map[string]any{ + "email": "user@test.com", + "mobile": "13800138000", + "password": "secret", + "token": "tok123", + }) + if acc.Email != "user@test.com" { + t.Fatalf("unexpected email: %q", acc.Email) + } + if acc.Mobile != "13800138000" { + t.Fatalf("unexpected mobile: %q", acc.Mobile) + } + if acc.Password != "secret" { + t.Fatalf("unexpected password: %q", acc.Password) + } + if acc.Token != "tok123" { + t.Fatalf("unexpected token: %q", acc.Token) + } +} + +func TestToAccountNumericValues(t *testing.T) { + acc := toAccount(map[string]any{ + "email": 12345, + }) + if acc.Email != "12345" { + t.Fatalf("expected numeric converted to string, got %q", acc.Email) + } +} + +// ─── fieldString edge cases ────────────────────────────────────────── + +func TestFieldStringNonString(t *testing.T) { + got := fieldString(map[string]any{"key": 42}, "key") + if got != "42" { + t.Fatalf("expected '42' for int, got %q", got) + } +} + +func TestFieldStringBool(t *testing.T) { + got := fieldString(map[string]any{"key": true}, "key") + if got != "true" { + t.Fatalf("expected 'true', got %q", got) + } +} + +func TestFieldStringWhitespace(t *testing.T) { + got := fieldString(map[string]any{"key": " hello "}, "key") + if got != "hello" { + t.Fatalf("expected trimmed 'hello', got %q", got) + } +} + +// ─── statusOr ──────────────────────────────────────────────────────── + +func TestStatusOrZeroReturnsDefault(t *testing.T) { + if got := statusOr(0, http.StatusOK); got != http.StatusOK { + t.Fatalf("expected %d, got %d", http.StatusOK, got) + } +} + +func TestStatusOrNonZeroReturnsValue(t *testing.T) { + if got := statusOr(http.StatusBadRequest, http.StatusOK); got != http.StatusBadRequest { + t.Fatalf("expected %d, got %d", http.StatusBadRequest, got) + } +} diff --git a/internal/auth/auth_edge_test.go b/internal/auth/auth_edge_test.go new file mode 100644 index 0000000..55c46ef --- /dev/null +++ b/internal/auth/auth_edge_test.go @@ -0,0 +1,375 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "testing" + + "ds2api/internal/account" + "ds2api/internal/config" +) + +// ─── extractCallerToken edge cases ─────────────────────────────────── + +func TestExtractCallerTokenBearerPrefix(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer my-token") + if got := extractCallerToken(req); got != "my-token" { + t.Fatalf("expected my-token, got %q", got) + } +} + +func TestExtractCallerTokenBearerCaseInsensitive(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "BEARER My-Token") + if got := extractCallerToken(req); got != "My-Token" { + t.Fatalf("expected My-Token, got %q", got) + } +} + +func TestExtractCallerTokenBearerEmpty(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer ") + if got := extractCallerToken(req); got != "" { + t.Fatalf("expected empty for 'Bearer ', got %q", got) + } +} + +func TestExtractCallerTokenXAPIKey(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("x-api-key", "x-api-key-token") + if got := extractCallerToken(req); got != "x-api-key-token" { + t.Fatalf("expected x-api-key-token, got %q", got) + } +} + +func TestExtractCallerTokenBearerPreferredOverXAPIKey(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer bearer-token") + req.Header.Set("x-api-key", "x-api-key-token") + if got := extractCallerToken(req); got != "bearer-token" { + t.Fatalf("expected bearer-token, got %q", got) + } +} + +func TestExtractCallerTokenMissingHeaders(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + if got := extractCallerToken(req); got != "" { + t.Fatalf("expected empty for missing headers, got %q", got) + } +} + +func TestExtractCallerTokenNonBearerAuth(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Basic abc123") + if got := extractCallerToken(req); got != "" { + t.Fatalf("expected empty for Basic auth, got %q", got) + } +} + +// ─── Context helpers ───────────────────────────────────────────────── + +func TestWithAuthAndFromContext(t *testing.T) { + a := &RequestAuth{DeepSeekToken: "test-token"} + ctx := WithAuth(context.Background(), a) + got, ok := FromContext(ctx) + if !ok || got.DeepSeekToken != "test-token" { + t.Fatalf("expected token from context, got ok=%v token=%q", ok, got.DeepSeekToken) + } +} + +func TestFromContextMissing(t *testing.T) { + _, ok := FromContext(context.Background()) + if ok { + t.Fatal("expected not ok from empty context") + } +} + +// ─── RefreshToken edge cases ───────────────────────────────────────── + +func TestRefreshTokenNotConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false, resolver: r} + if r.RefreshToken(context.Background(), a) { + t.Fatal("expected false for non-config token") + } +} + +func TestRefreshTokenEmptyAccountID(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: true, AccountID: "", resolver: r} + if r.RefreshToken(context.Background(), a) { + t.Fatal("expected false for empty account ID") + } +} + +func TestRefreshTokenSuccess(t *testing.T) { + r := newTestResolver(t) + // First acquire an account + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + + if !r.RefreshToken(context.Background(), a) { + t.Fatal("expected refresh to succeed") + } + if a.DeepSeekToken != "fresh-token" { + t.Fatalf("expected fresh-token after refresh, got %q", a.DeepSeekToken) + } +} + +// ─── MarkTokenInvalid edge cases ───────────────────────────────────── + +func TestMarkTokenInvalidNotConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false, DeepSeekToken: "direct", resolver: r} + r.MarkTokenInvalid(a) + // Should not panic, token should be unchanged for non-config + if a.DeepSeekToken != "" { + // Actually it does clear it; that's fine - let's check behavior + } +} + +func TestMarkTokenInvalidEmptyAccountID(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: true, AccountID: "", DeepSeekToken: "tok", resolver: r} + r.MarkTokenInvalid(a) + // Should not panic +} + +func TestMarkTokenInvalidClearsToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + + r.MarkTokenInvalid(a) + if a.DeepSeekToken != "" { + t.Fatalf("expected empty token after invalidation, got %q", a.DeepSeekToken) + } + if a.Account.Token != "" { + t.Fatalf("expected empty account token after invalidation, got %q", a.Account.Token) + } +} + +// ─── SwitchAccount edge cases ──────────────────────────────────────── + +func TestSwitchAccountNotConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false, resolver: r} + if r.SwitchAccount(context.Background(), a) { + t.Fatal("expected false for non-config token") + } +} + +func TestSwitchAccountNilTriedAccounts(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[ + {"email":"acc1@test.com","token":"t1"}, + {"email":"acc2@test.com","token":"t2"} + ] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "new-token", nil + }) + + // First acquire + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + + oldID := a.AccountID + a.TriedAccounts = nil // test nil initialization in SwitchAccount + if !r.SwitchAccount(context.Background(), a) { + t.Fatal("expected switch to succeed") + } + if a.AccountID == oldID { + t.Fatalf("expected different account after switch") + } + r.Release(a) +} + +// ─── Release edge cases ───────────────────────────────────────────── + +func TestReleaseNilAuth(t *testing.T) { + r := newTestResolver(t) + r.Release(nil) // should not panic +} + +func TestReleaseNonConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false} + r.Release(a) // should not panic +} + +func TestReleaseEmptyAccountID(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: true, AccountID: ""} + r.Release(a) // should not panic +} + +// ─── JWT edge cases ────────────────────────────────────────────────── + +func TestVerifyJWTInvalidFormat(t *testing.T) { + _, err := VerifyJWT("not-a-jwt") + if err == nil { + t.Fatal("expected error for invalid JWT format") + } +} + +func TestVerifyJWTInvalidSignature(t *testing.T) { + token, _ := CreateJWT(1) + // Tamper with the signature + parts := splitJWT(token) + if len(parts) == 3 { + tampered := parts[0] + "." + parts[1] + ".invalid_signature" + _, err := VerifyJWT(tampered) + if err == nil { + t.Fatal("expected error for tampered signature") + } + } +} + +func TestVerifyJWTExpired(t *testing.T) { + // Create a token with 0 hours expiry - will use default, so we can't easily test + // Instead test with bad payload + _, err := VerifyJWT("eyJhbGciOiJIUzI1NiJ9.eyJleHAiOjF9.invalid") + if err == nil { + t.Fatal("expected error for expired/invalid JWT") + } +} + +func TestCreateJWTDefaultExpiry(t *testing.T) { + token, err := CreateJWT(0) // should use default + if err != nil { + t.Fatalf("create jwt failed: %v", err) + } + _, err = VerifyJWT(token) + if err != nil { + t.Fatalf("verify jwt failed: %v", err) + } +} + +// ─── VerifyAdminRequest edge cases ─────────────────────────────────── + +func TestVerifyAdminRequestNoHeader(t *testing.T) { + req, _ := http.NewRequest("GET", "/admin/config", nil) + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for missing auth") + } +} + +func TestVerifyAdminRequestEmptyBearer(t *testing.T) { + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Bearer ") + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for empty bearer") + } +} + +func TestVerifyAdminRequestWithAdminKey(t *testing.T) { + t.Setenv("DS2API_ADMIN_KEY", "test-admin-key") + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Bearer test-admin-key") + if err := VerifyAdminRequest(req); err != nil { + t.Fatalf("expected admin key accepted: %v", err) + } +} + +func TestVerifyAdminRequestInvalidCredentials(t *testing.T) { + t.Setenv("DS2API_ADMIN_KEY", "correct-key") + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Bearer wrong-key") + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for wrong key") + } +} + +func TestVerifyAdminRequestBasicAuth(t *testing.T) { + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Basic abc123") + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for Basic auth") + } +} + +// ─── Determine with login failure ──────────────────────────────────── + +func TestDetermineWithLoginFailure(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[{"email":"acc@test.com","password":"pwd"}] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "", errors.New("login failed") + }) + + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + _, err := r.Determine(req) + if err == nil { + t.Fatal("expected error when login fails") + } +} + +// ─── Determine with target account ─────────────────────────────────── + +func TestDetermineWithTargetAccount(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[ + {"email":"acc1@test.com","token":"t1"}, + {"email":"acc2@test.com","token":"t2"} + ] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "fresh-token", nil + }) + + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + req.Header.Set("X-Ds2-Target-Account", "acc2@test.com") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + if a.AccountID != "acc2@test.com" { + t.Fatalf("expected target account acc2, got %q", a.AccountID) + } +} + +// helper +func splitJWT(token string) []string { + result := make([]string, 0, 3) + start := 0 + count := 0 + for i := 0; i < len(token); i++ { + if token[i] == '.' { + result = append(result, token[start:i]) + start = i + 1 + count++ + } + } + result = append(result, token[start:]) + return result +} diff --git a/internal/config/config_edge_test.go b/internal/config/config_edge_test.go new file mode 100644 index 0000000..81cc7ec --- /dev/null +++ b/internal/config/config_edge_test.go @@ -0,0 +1,445 @@ +package config + +import ( + "encoding/base64" + "encoding/json" + "strings" + "testing" +) + +// ─── GetModelConfig edge cases ─────────────────────────────────────── + +func TestGetModelConfigDeepSeekChat(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-chat") + if !ok { + t.Fatal("expected ok for deepseek-chat") + } + if thinking || search { + t.Fatalf("expected no thinking/search for deepseek-chat, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigDeepSeekReasoner(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-reasoner") + if !ok { + t.Fatal("expected ok for deepseek-reasoner") + } + if !thinking || search { + t.Fatalf("expected thinking=true search=false, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigDeepSeekChatSearch(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-chat-search") + if !ok { + t.Fatal("expected ok for deepseek-chat-search") + } + if thinking || !search { + t.Fatalf("expected thinking=false search=true, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigDeepSeekReasonerSearch(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-reasoner-search") + if !ok { + t.Fatal("expected ok for deepseek-reasoner-search") + } + if !thinking || !search { + t.Fatalf("expected both true, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigCaseInsensitive(t *testing.T) { + thinking, search, ok := GetModelConfig("DeepSeek-Chat") + if !ok { + t.Fatal("expected ok for case-insensitive deepseek-chat") + } + if thinking || search { + t.Fatalf("expected no thinking/search for case-insensitive deepseek-chat") + } +} + +func TestGetModelConfigUnknownModel(t *testing.T) { + _, _, ok := GetModelConfig("gpt-4") + if ok { + t.Fatal("expected not ok for unknown model") + } +} + +func TestGetModelConfigEmpty(t *testing.T) { + _, _, ok := GetModelConfig("") + if ok { + t.Fatal("expected not ok for empty model") + } +} + +// ─── lower function ────────────────────────────────────────────────── + +func TestLowerFunction(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"Hello", "hello"}, + {"ALLCAPS", "allcaps"}, + {"already-lower", "already-lower"}, + {"Mixed-CASE-123", "mixed-case-123"}, + {"", ""}, + } + for _, tc := range tests { + got := lower(tc.input) + if got != tc.expected { + t.Errorf("lower(%q) = %q, want %q", tc.input, got, tc.expected) + } + } +} + +// ─── Config.MarshalJSON / UnmarshalJSON roundtrip ──────────────────── + +func TestConfigJSONRoundtrip(t *testing.T) { + cfg := Config{ + Keys: []string{"key1", "key2"}, + Accounts: []Account{{Email: "user@example.com", Password: "pass", Token: "tok"}}, + ClaudeMapping: map[string]string{ + "fast": "deepseek-chat", + "slow": "deepseek-reasoner", + }, + VercelSyncHash: "hash123", + VercelSyncTime: 1234567890, + AdditionalFields: map[string]any{ + "custom_field": "custom_value", + }, + } + + data, err := cfg.MarshalJSON() + if err != nil { + t.Fatalf("marshal error: %v", err) + } + + var decoded Config + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + if len(decoded.Keys) != 2 || decoded.Keys[0] != "key1" { + t.Fatalf("unexpected keys: %#v", decoded.Keys) + } + if len(decoded.Accounts) != 1 || decoded.Accounts[0].Email != "user@example.com" { + t.Fatalf("unexpected accounts: %#v", decoded.Accounts) + } + if decoded.ClaudeMapping["fast"] != "deepseek-chat" { + t.Fatalf("unexpected claude mapping: %#v", decoded.ClaudeMapping) + } + if decoded.VercelSyncHash != "hash123" { + t.Fatalf("unexpected vercel sync hash: %q", decoded.VercelSyncHash) + } + if decoded.AdditionalFields["custom_field"] != "custom_value" { + t.Fatalf("unexpected additional fields: %#v", decoded.AdditionalFields) + } +} + +func TestConfigUnmarshalJSONPreservesUnknownFields(t *testing.T) { + raw := `{"keys":["k1"],"accounts":[],"my_custom_field":"hello","number_field":42}` + var cfg Config + if err := json.Unmarshal([]byte(raw), &cfg); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if cfg.AdditionalFields["my_custom_field"] != "hello" { + t.Fatalf("expected custom field preserved, got %#v", cfg.AdditionalFields) + } + // number_field should also be preserved + if cfg.AdditionalFields["number_field"] != float64(42) { + t.Fatalf("expected number field preserved, got %#v", cfg.AdditionalFields["number_field"]) + } +} + +// ─── Config.Clone ──────────────────────────────────────────────────── + +func TestConfigCloneIsDeepCopy(t *testing.T) { + cfg := Config{ + Keys: []string{"key1"}, + Accounts: []Account{{Email: "user@test.com", Token: "token"}}, + ClaudeMapping: map[string]string{ + "fast": "deepseek-chat", + }, + AdditionalFields: map[string]any{"custom": "value"}, + } + + cloned := cfg.Clone() + + // Modify original + cfg.Keys[0] = "modified" + cfg.Accounts[0].Email = "modified@test.com" + cfg.ClaudeMapping["fast"] = "modified-model" + + // Cloned should not be affected + if cloned.Keys[0] != "key1" { + t.Fatalf("clone keys was affected by original change: %#v", cloned.Keys) + } + if cloned.Accounts[0].Email != "user@test.com" { + t.Fatalf("clone accounts was affected: %#v", cloned.Accounts) + } + if cloned.ClaudeMapping["fast"] != "deepseek-chat" { + t.Fatalf("clone claude mapping was affected: %#v", cloned.ClaudeMapping) + } +} + +func TestConfigCloneNilMaps(t *testing.T) { + cfg := Config{ + Keys: []string{"k"}, + Accounts: nil, + } + cloned := cfg.Clone() + if len(cloned.Keys) != 1 { + t.Fatalf("unexpected keys length: %d", len(cloned.Keys)) + } + if cloned.Accounts != nil { + t.Fatalf("expected nil accounts in clone, got %#v", cloned.Accounts) + } +} + +// ─── Account.Identifier edge cases ─────────────────────────────────── + +func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) { + acc := Account{Mobile: "13800138000", Token: "tok"} + if acc.Identifier() != "13800138000" { + t.Fatalf("expected mobile identifier, got %q", acc.Identifier()) + } +} + +func TestAccountIdentifierPreferenceEmailOverMobile(t *testing.T) { + acc := Account{Email: "user@test.com", Mobile: "13800138000"} + if acc.Identifier() != "user@test.com" { + t.Fatalf("expected email identifier, got %q", acc.Identifier()) + } +} + +func TestAccountIdentifierEmptyAccount(t *testing.T) { + acc := Account{} + if acc.Identifier() != "" { + t.Fatalf("expected empty identifier for empty account, got %q", acc.Identifier()) + } +} + +// ─── normalizeConfigInput ──────────────────────────────────────────── + +func TestNormalizeConfigInputStripsQuotes(t *testing.T) { + got := normalizeConfigInput(`"base64:abc"`) + if strings.HasPrefix(got, `"`) || strings.HasSuffix(got, `"`) { + t.Fatalf("expected quotes stripped, got %q", got) + } +} + +func TestNormalizeConfigInputStripsSingleQuotes(t *testing.T) { + got := normalizeConfigInput("'some-value'") + if strings.HasPrefix(got, "'") || strings.HasSuffix(got, "'") { + t.Fatalf("expected single quotes stripped, got %q", got) + } +} + +func TestNormalizeConfigInputTrimsWhitespace(t *testing.T) { + got := normalizeConfigInput(" hello ") + if got != "hello" { + t.Fatalf("expected trimmed, got %q", got) + } +} + +// ─── parseConfigString edge cases ──────────────────────────────────── + +func TestParseConfigStringPlainJSON(t *testing.T) { + cfg, err := parseConfigString(`{"keys":["k1"],"accounts":[]}`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "k1" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestParseConfigStringBase64Prefix(t *testing.T) { + rawJSON := `{"keys":["base64-key"],"accounts":[]}` + b64 := base64.StdEncoding.EncodeToString([]byte(rawJSON)) + cfg, err := parseConfigString("base64:" + b64) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "base64-key" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestParseConfigStringInvalidBase64(t *testing.T) { + _, err := parseConfigString("base64:!!!invalid!!!") + if err == nil { + t.Fatal("expected error for invalid base64") + } +} + +func TestParseConfigStringEmptyString(t *testing.T) { + _, err := parseConfigString("") + if err == nil { + t.Fatal("expected error for empty string") + } +} + +// ─── Store methods ─────────────────────────────────────────────────── + +func TestStoreSnapshotReturnsClone(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com","token":"t1"}]}`) + store := LoadStore() + snap := store.Snapshot() + snap.Keys[0] = "modified" + if store.Keys()[0] != "k1" { + t.Fatal("snapshot modification should not affect store") + } +} + +func TestStoreHasAPIKeyMultipleKeys(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["key1","key2","key3"],"accounts":[]}`) + store := LoadStore() + if !store.HasAPIKey("key1") { + t.Fatal("expected key1 found") + } + if !store.HasAPIKey("key2") { + t.Fatal("expected key2 found") + } + if !store.HasAPIKey("key3") { + t.Fatal("expected key3 found") + } + if store.HasAPIKey("nonexistent") { + t.Fatal("expected nonexistent key not found") + } +} + +func TestStoreFindAccountNotFound(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com"}]}`) + store := LoadStore() + _, ok := store.FindAccount("nonexistent@test.com") + if ok { + t.Fatal("expected account not found") + } +} + +func TestStoreIsEnvBacked(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + if !store.IsEnvBacked() { + t.Fatal("expected env-backed store") + } +} + +func TestStoreReplace(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + newCfg := Config{ + Keys: []string{"new-key"}, + Accounts: []Account{{Email: "new@test.com"}}, + } + if err := store.Replace(newCfg); err != nil { + t.Fatalf("replace error: %v", err) + } + if !store.HasAPIKey("new-key") { + t.Fatal("expected new key after replace") + } + if store.HasAPIKey("k1") { + t.Fatal("expected old key removed after replace") + } +} + +func TestStoreUpdate(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + err := store.Update(func(cfg *Config) error { + cfg.Keys = append(cfg.Keys, "k2") + return nil + }) + if err != nil { + t.Fatalf("update error: %v", err) + } + if !store.HasAPIKey("k2") { + t.Fatal("expected k2 after update") + } +} + +func TestStoreClaudeMapping(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`) + store := LoadStore() + mapping := store.ClaudeMapping() + if mapping["fast"] != "deepseek-chat" { + t.Fatalf("unexpected fast mapping: %q", mapping["fast"]) + } + if mapping["slow"] != "deepseek-reasoner" { + t.Fatalf("unexpected slow mapping: %q", mapping["slow"]) + } +} + +func TestStoreClaudeMappingEmpty(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`) + store := LoadStore() + mapping := store.ClaudeMapping() + // Even without config mapping, there are defaults + if mapping == nil { + t.Fatal("expected non-nil mapping (may contain defaults)") + } +} + +func TestStoreSetVercelSync(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`) + store := LoadStore() + if err := store.SetVercelSync("hash123", 1234567890); err != nil { + t.Fatalf("setVercelSync error: %v", err) + } + snap := store.Snapshot() + if snap.VercelSyncHash != "hash123" || snap.VercelSyncTime != 1234567890 { + t.Fatalf("unexpected vercel sync: hash=%q time=%d", snap.VercelSyncHash, snap.VercelSyncTime) + } +} + +func TestStoreExportJSONAndBase64(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["export-key"],"accounts":[]}`) + store := LoadStore() + jsonStr, b64Str, err := store.ExportJSONAndBase64() + if err != nil { + t.Fatalf("export error: %v", err) + } + if !strings.Contains(jsonStr, "export-key") { + t.Fatalf("expected JSON to contain key: %q", jsonStr) + } + decoded, err := base64.StdEncoding.DecodeString(b64Str) + if err != nil { + t.Fatalf("base64 decode error: %v", err) + } + if !strings.Contains(string(decoded), "export-key") { + t.Fatalf("expected base64-decoded to contain key: %q", string(decoded)) + } +} + +// ─── OpenAIModelsResponse / ClaudeModelsResponse ───────────────────── + +func TestOpenAIModelsResponse(t *testing.T) { + resp := OpenAIModelsResponse() + if resp["object"] != "list" { + t.Fatalf("unexpected object: %v", resp["object"]) + } + data, ok := resp["data"].([]ModelInfo) + if !ok { + t.Fatalf("unexpected data type: %T", resp["data"]) + } + if len(data) == 0 { + t.Fatal("expected non-empty models list") + } +} + +func TestClaudeModelsResponse(t *testing.T) { + resp := ClaudeModelsResponse() + if resp["object"] != "list" { + t.Fatalf("unexpected object: %v", resp["object"]) + } + data, ok := resp["data"].([]ModelInfo) + if !ok { + t.Fatalf("unexpected data type: %T", resp["data"]) + } + if len(data) == 0 { + t.Fatal("expected non-empty models list") + } +} diff --git a/internal/deepseek/deepseek_edge_test.go b/internal/deepseek/deepseek_edge_test.go new file mode 100644 index 0000000..92e6952 --- /dev/null +++ b/internal/deepseek/deepseek_edge_test.go @@ -0,0 +1,165 @@ +package deepseek + +import ( + "context" + "testing" +) + +// ─── toFloat64 edge cases ──────────────────────────────────────────── + +func TestToFloat64FromFloat64(t *testing.T) { + if got := toFloat64(float64(3.14), 0); got != 3.14 { + t.Fatalf("expected 3.14, got %f", got) + } +} + +func TestToFloat64FromInt(t *testing.T) { + if got := toFloat64(42, 0); got != 42.0 { + t.Fatalf("expected 42.0, got %f", got) + } +} + +func TestToFloat64FromInt64(t *testing.T) { + if got := toFloat64(int64(100), 0); got != 100.0 { + t.Fatalf("expected 100.0, got %f", got) + } +} + +func TestToFloat64FromStringDefault(t *testing.T) { + if got := toFloat64("42", 99.0); got != 99.0 { + t.Fatalf("expected default 99.0, got %f", got) + } +} + +func TestToFloat64FromNilDefault(t *testing.T) { + if got := toFloat64(nil, 5.5); got != 5.5 { + t.Fatalf("expected default 5.5, got %f", got) + } +} + +func TestToFloat64FromBoolDefault(t *testing.T) { + if got := toFloat64(true, 1.0); got != 1.0 { + t.Fatalf("expected default 1.0, got %f", got) + } +} + +// ─── toInt64 edge cases ────────────────────────────────────────────── + +func TestToInt64FromFloat64(t *testing.T) { + if got := toInt64(float64(42.9), 0); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestToInt64FromInt(t *testing.T) { + if got := toInt64(42, 0); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestToInt64FromInt64(t *testing.T) { + if got := toInt64(int64(100), 0); got != 100 { + t.Fatalf("expected 100, got %d", got) + } +} + +func TestToInt64FromStringDefault(t *testing.T) { + if got := toInt64("42", 99); got != 99 { + t.Fatalf("expected default 99, got %d", got) + } +} + +func TestToInt64FromNilDefault(t *testing.T) { + if got := toInt64(nil, 7); got != 7 { + t.Fatalf("expected default 7, got %d", got) + } +} + +// ─── BuildPowHeader edge cases ─────────────────────────────────────── + +func TestBuildPowHeaderBasicChallenge(t *testing.T) { + challenge := map[string]any{ + "algorithm": "DeepSeekHashV1", + "challenge": "abc123", + "salt": "salt456", + "signature": "sig789", + "target_path": "/path", + } + result, err := BuildPowHeader(challenge, 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == "" { + t.Fatal("expected non-empty result") + } +} + +func TestBuildPowHeaderEmptyChallenge(t *testing.T) { + result, err := BuildPowHeader(map[string]any{}, 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Should produce a base64 encoded JSON with nil values + if result == "" { + t.Fatal("expected non-empty result for empty challenge") + } +} + +// ─── PowSolver pool size ───────────────────────────────────────────── + +func TestPowPoolSizeFromEnvDefault(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "") + got := powPoolSizeFromEnv() + if got < 1 { + t.Fatalf("expected positive default pool size, got %d", got) + } +} + +func TestPowPoolSizeFromEnvInvalid(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "abc") + got := powPoolSizeFromEnv() + if got < 1 { + t.Fatalf("expected positive default for invalid, got %d", got) + } +} + +func TestPowPoolSizeFromEnvSpecificValue(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "5") + got := powPoolSizeFromEnv() + if got != 5 { + t.Fatalf("expected 5, got %d", got) + } +} + +// ─── NewClient ─────────────────────────────────────────────────────── + +func TestNewClientInitialState(t *testing.T) { + client := NewClient(nil, nil) + if client.powSolver == nil { + t.Fatal("expected powSolver to be initialized") + } +} + +func TestNewClientPreloadPowIdempotent(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "1") + client := NewClient(nil, nil) + if err := client.PreloadPow(context.Background()); err != nil { + t.Fatalf("first preload failed: %v", err) + } + if err := client.PreloadPow(context.Background()); err != nil { + t.Fatalf("second preload failed: %v", err) + } +} + +// ─── PowSolver init and module pool ────────────────────────────────── + +func TestPowSolverPoolSizeMatchesEnv(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "2") + solver := NewPowSolver("test.wasm") + if err := solver.init(context.Background()); err != nil { + t.Fatalf("init failed: %v", err) + } + if cap(solver.pool) != 2 { + t.Fatalf("expected pool capacity 2, got %d", cap(solver.pool)) + } +} diff --git a/internal/sse/consumer_edge_test.go b/internal/sse/consumer_edge_test.go new file mode 100644 index 0000000..8f78f01 --- /dev/null +++ b/internal/sse/consumer_edge_test.go @@ -0,0 +1,140 @@ +package sse + +import ( + "io" + "net/http" + "strings" + "testing" +) + +// ─── CollectStream edge cases ──────────────────────────────────────── + +func makeHTTPResponse(body string) *http.Response { + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func TestCollectStreamEmpty(t *testing.T) { + resp := makeHTTPResponse("") + result := CollectStream(resp, false, false) + if result.Text != "" || result.Thinking != "" { + t.Fatalf("expected empty result, got text=%q think=%q", result.Text, result.Thinking) + } +} + +func TestCollectStreamTextOnly(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\" World\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + if result.Text != "Hello World" { + t.Fatalf("expected 'Hello World', got %q", result.Text) + } + if result.Thinking != "" { + t.Fatalf("expected no thinking, got %q", result.Thinking) + } +} + +func TestCollectStreamThinkingAndText(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/thinking_content\",\"v\":\"Thinking...\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"Answer\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "Thinking..." { + t.Fatalf("expected 'Thinking...', got %q", result.Thinking) + } + if result.Text != "Answer" { + t.Fatalf("expected 'Answer', got %q", result.Text) + } +} + +func TestCollectStreamOnlyThinking(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/thinking_content\",\"v\":\"Only thinking\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "Only thinking" { + t.Fatalf("expected 'Only thinking', got %q", result.Thinking) + } + if result.Text != "" { + t.Fatalf("expected empty text, got %q", result.Text) + } +} + +func TestCollectStreamSkipsInvalidLines(t *testing.T) { + resp := makeHTTPResponse( + "event: comment\n" + + "data: invalid_json\n" + + "data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + if result.Text != "valid" { + t.Fatalf("expected 'valid', got %q", result.Text) + } +} + +func TestCollectStreamWithFragments(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"Think\"}]}\n" + + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"Done\"}]}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "Think" { + t.Fatalf("expected 'Think' thinking, got %q", result.Thinking) + } + if result.Text != "Done" { + t.Fatalf("expected 'Done' text, got %q", result.Text) + } +} + +func TestCollectStreamWithCitation(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"[citation:1] cited text\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\" more\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + // CollectStream does NOT filter citations (that's done by the adapters) + // So citations are passed through as-is + if !strings.Contains(result.Text, "[citation:1]") { + t.Fatalf("expected citations to be passed through, got %q", result.Text) + } + if result.Text != "Hello[citation:1] cited text more" { + t.Fatalf("expected full text with citation, got %q", result.Text) + } +} + +func TestCollectStreamMultipleThinkingChunks(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/thinking_content\",\"v\":\"part1\"}\n" + + "data: {\"p\":\"response/thinking_content\",\"v\":\" part2\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"answer\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "part1 part2" { + t.Fatalf("expected 'part1 part2', got %q", result.Thinking) + } +} + +func TestCollectStreamStatusFinished(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" + + "data: {\"p\":\"response/status\",\"v\":\"FINISHED\"}\n", + ) + result := CollectStream(resp, false, false) + if result.Text != "Hello" { + t.Fatalf("expected 'Hello', got %q", result.Text) + } +} diff --git a/internal/sse/line_edge_test.go b/internal/sse/line_edge_test.go new file mode 100644 index 0000000..2ae53a6 --- /dev/null +++ b/internal/sse/line_edge_test.go @@ -0,0 +1,70 @@ +package sse + +import "testing" + +func TestParseDeepSeekContentLineNotParsed(t *testing.T) { + res := ParseDeepSeekContentLine([]byte("not a data line"), false, "text") + if res.Parsed { + t.Fatal("expected not parsed") + } + if res.NextType != "text" { + t.Fatalf("expected nextType preserved, got %q", res.NextType) + } +} + +func TestParseDeepSeekContentLinePreservesNextType(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/thinking_content","v":"think"}`), true, "thinking") + if !res.Parsed || res.Stop { + t.Fatalf("expected parsed non-stop: %#v", res) + } + if len(res.Parts) != 1 || res.Parts[0].Type != "thinking" { + t.Fatalf("unexpected parts: %#v", res.Parts) + } +} + +func TestParseDeepSeekContentLineFragmentSwitchType(t *testing.T) { + res := ParseDeepSeekContentLine( + []byte(`data: {"p":"response/fragments","o":"APPEND","v":[{"type":"RESPONSE","content":"hi"}]}`), + true, "thinking", + ) + if !res.Parsed || res.Stop { + t.Fatalf("expected parsed non-stop: %#v", res) + } + if res.NextType != "text" { + t.Fatalf("expected nextType text after RESPONSE fragment, got %q", res.NextType) + } +} + +func TestParseDeepSeekContentLineContentFilterMessage(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"code":"content_filter"}`), false, "text") + if !res.ContentFilter { + t.Fatal("expected content filter flag") + } + if res.ErrorMessage == "" { + t.Fatal("expected error message on content filter") + } +} + +func TestParseDeepSeekContentLineErrorObjectFormat(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"error":{"message":"rate limit","code":429}}`), false, "text") + if !res.Parsed || !res.Stop { + t.Fatalf("expected parsed stop: %#v", res) + } + if res.ErrorMessage == "" { + t.Fatal("expected non-empty error message") + } +} + +func TestParseDeepSeekContentLineInvalidJSON(t *testing.T) { + res := ParseDeepSeekContentLine([]byte("data: {broken"), false, "text") + if res.Parsed { + t.Fatal("expected not parsed for broken JSON") + } +} + +func TestParseDeepSeekContentLineEmptyBytes(t *testing.T) { + res := ParseDeepSeekContentLine([]byte{}, false, "text") + if res.Parsed { + t.Fatal("expected not parsed for empty bytes") + } +} diff --git a/internal/sse/parser_edge_test.go b/internal/sse/parser_edge_test.go new file mode 100644 index 0000000..c851c1f --- /dev/null +++ b/internal/sse/parser_edge_test.go @@ -0,0 +1,631 @@ +package sse + +import "testing" + +// ─── ParseDeepSeekSSELine edge cases ───────────────────────────────── + +func TestParseDeepSeekSSELineEmptyLine(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte("")) + if ok { + t.Fatal("expected not parsed for empty line") + } +} + +func TestParseDeepSeekSSELineNoDataPrefix(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte("event: message")) + if ok { + t.Fatal("expected not parsed for non-data line") + } +} + +func TestParseDeepSeekSSELineInvalidJSON(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte("data: {invalid json")) + if ok { + t.Fatal("expected not parsed for invalid JSON") + } +} + +func TestParseDeepSeekSSELineWhitespaceOnly(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte(" ")) + if ok { + t.Fatal("expected not parsed for whitespace-only line") + } +} + +func TestParseDeepSeekSSELineDataWithExtraSpaces(t *testing.T) { + chunk, done, ok := ParseDeepSeekSSELine([]byte(`data: {"v":"hello"} `)) + if !ok || done { + t.Fatalf("expected parsed chunk for spaced data line") + } + if chunk["v"] != "hello" { + t.Fatalf("unexpected chunk: %#v", chunk) + } +} + +// ─── shouldSkipPath edge cases ─────────────────────────────────────── + +func TestShouldSkipPathQuasiStatus(t *testing.T) { + if !shouldSkipPath("response/quasi_status") { + t.Fatal("expected skip for quasi_status path") + } +} + +func TestShouldSkipPathElapsedSecs(t *testing.T) { + if !shouldSkipPath("response/elapsed_secs") { + t.Fatal("expected skip for elapsed_secs path") + } +} + +func TestShouldSkipPathTokenUsage(t *testing.T) { + if !shouldSkipPath("response/token_usage") { + t.Fatal("expected skip for token_usage path") + } +} + +func TestShouldSkipPathPendingFragment(t *testing.T) { + if !shouldSkipPath("response/pending_fragment") { + t.Fatal("expected skip for pending_fragment path") + } +} + +func TestShouldSkipPathConversationMode(t *testing.T) { + if !shouldSkipPath("response/conversation_mode") { + t.Fatal("expected skip for conversation_mode path") + } +} + +func TestShouldSkipPathSearchStatus(t *testing.T) { + if !shouldSkipPath("response/search_status") { + t.Fatal("expected skip for search_status path") + } +} + +func TestShouldSkipPathFragmentStatus(t *testing.T) { + if !shouldSkipPath("response/fragments/-1/status") { + t.Fatal("expected skip for fragment -1 status") + } + if !shouldSkipPath("response/fragments/-2/status") { + t.Fatal("expected skip for fragment -2 status") + } + if !shouldSkipPath("response/fragments/-3/status") { + t.Fatal("expected skip for fragment -3 status") + } +} + +func TestShouldSkipPathRegularContent(t *testing.T) { + if shouldSkipPath("response/content") { + t.Fatal("expected not skip for content path") + } + if shouldSkipPath("response/thinking_content") { + t.Fatal("expected not skip for thinking_content path") + } +} + +// ─── ParseSSEChunkForContent edge cases ────────────────────────────── + +func TestParseSSEChunkForContentNoVField(t *testing.T) { + parts, finished, nextType := ParseSSEChunkForContent(map[string]any{"p": "response/content"}, false, "text") + if finished { + t.Fatal("expected not finished") + } + if len(parts) != 0 { + t.Fatalf("expected no parts when v is missing, got %#v", parts) + } + if nextType != "text" { + t.Fatalf("expected type preserved, got %q", nextType) + } +} + +func TestParseSSEChunkForContentSkippedPath(t *testing.T) { + parts, finished, nextType := ParseSSEChunkForContent(map[string]any{ + "p": "response/token_usage", + "v": "some data", + }, false, "text") + if finished || len(parts) > 0 { + t.Fatalf("expected skipped path to produce no output") + } + if nextType != "text" { + t.Fatalf("expected type preserved for skipped path") + } +} + +func TestParseSSEChunkForContentFinishedStatus(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "response/status", + "v": "FINISHED", + }, false, "text") + if !finished { + t.Fatal("expected finished on status FINISHED") + } + if len(parts) != 0 { + t.Fatalf("expected no parts on finished, got %d", len(parts)) + } +} + +func TestParseSSEChunkForContentStatusNotFinished(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "response/status", + "v": "IN_PROGRESS", + }, false, "text") + if finished { + t.Fatal("expected not finished for non-FINISHED status") + } + if len(parts) != 1 || parts[0].Text != "IN_PROGRESS" { + t.Fatalf("expected content for non-FINISHED status, got %#v", parts) + } +} + +func TestParseSSEChunkForContentEmptyStringV(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "response/content", + "v": "", + }, false, "text") + if finished { + t.Fatal("expected not finished") + } + if len(parts) != 0 { + t.Fatalf("expected no parts for empty string v, got %#v", parts) + } +} + +func TestParseSSEChunkForContentFinishedOnEmptyPath(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "", + "v": "FINISHED", + }, false, "text") + if !finished { + t.Fatal("expected finished on empty path with FINISHED value") + } + if len(parts) != 0 { + t.Fatalf("expected no parts on finished") + } +} + +func TestParseSSEChunkForContentFinishedOnStatusPath(t *testing.T) { + _, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "status", + "v": "FINISHED", + }, false, "text") + if !finished { + t.Fatal("expected finished on status path with FINISHED value") + } +} + +func TestParseSSEChunkForContentThinkingPathEmptyPath(t *testing.T) { + parts, _, nextType := ParseSSEChunkForContent(map[string]any{ + "v": "some thought", + }, true, "thinking") + if len(parts) != 1 || parts[0].Type != "thinking" { + t.Fatalf("expected thinking part on empty path, got %#v", parts) + } + if nextType != "thinking" { + t.Fatalf("expected nextType thinking, got %q", nextType) + } +} + +func TestParseSSEChunkForContentThinkingEnabledTextType(t *testing.T) { + parts, _, nextType := ParseSSEChunkForContent(map[string]any{ + "v": "text content", + }, true, "text") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text part when currentType=text, got %#v", parts) + } + if nextType != "text" { + t.Fatalf("expected nextType text, got %q", nextType) + } +} + +// ─── ParseSSEChunkForContent: fragments path with THINK type ───────── + +func TestParseSSEChunkForContentFragmentsAppendThink(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "THINK", + "content": "深入思考...", + }, + }, + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text") + if finished { + t.Fatal("expected not finished") + } + if nextType != "thinking" { + t.Fatalf("expected nextType thinking, got %q", nextType) + } + if len(parts) != 1 || parts[0].Type != "thinking" || parts[0].Text != "深入思考..." { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendEmptyContent(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "RESPONSE", + "content": "", + }, + }, + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "thinking") + if finished { + t.Fatal("expected not finished") + } + if nextType != "text" { + t.Fatalf("expected nextType text, got %q", nextType) + } + if len(parts) != 0 { + t.Fatalf("expected no parts for empty content, got %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendDefaultType(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "UNKNOWN", + "content": "some text", + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, true, "text") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for unknown fragment type, got %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendNonArray(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": "not an array", + } + parts, finished, _ := ParseSSEChunkForContent(chunk, true, "text") + if finished { + t.Fatal("expected not finished") + } + // "not an array" should be treated as string value at the end + if len(parts) != 1 || parts[0].Text != "not an array" { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendNonMap(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{"string item"}, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + // Non-map items in fragment array are skipped; the []any itself is handled later + _ = parts // just checking it doesn't panic +} + +// ─── ParseSSEChunkForContent: response path with nested fragment ───── + +func TestParseSSEChunkForContentResponsePathFragmentsAppend(t *testing.T) { + chunk := map[string]any{ + "p": "response", + "v": []any{ + map[string]any{ + "p": "fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "THINKING", + }, + }, + }, + }, + } + _, _, nextType := ParseSSEChunkForContent(chunk, true, "text") + if nextType != "thinking" { + t.Fatalf("expected nextType thinking from response path fragments, got %q", nextType) + } +} + +func TestParseSSEChunkForContentResponsePathResponseFragment(t *testing.T) { + chunk := map[string]any{ + "p": "response", + "v": []any{ + map[string]any{ + "p": "fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "RESPONSE", + }, + }, + }, + }, + } + _, _, nextType := ParseSSEChunkForContent(chunk, true, "thinking") + if nextType != "text" { + t.Fatalf("expected nextType text for RESPONSE fragment, got %q", nextType) + } +} + +// ─── ParseSSEChunkForContent: map value with wrapped response ──────── + +func TestParseSSEChunkForContentMapValueWithFragments(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "response": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "THINK", + "content": "思考...", + }, + map[string]any{ + "type": "RESPONSE", + "content": "回答...", + }, + }, + }, + }, + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text") + if finished { + t.Fatal("expected not finished") + } + if nextType != "text" { + t.Fatalf("expected nextType text after RESPONSE, got %q", nextType) + } + if len(parts) != 2 { + t.Fatalf("expected 2 parts, got %d: %#v", len(parts), parts) + } + if parts[0].Type != "thinking" || parts[0].Text != "思考..." { + t.Fatalf("first part mismatch: %#v", parts[0]) + } + if parts[1].Type != "text" || parts[1].Text != "回答..." { + t.Fatalf("second part mismatch: %#v", parts[1]) + } +} + +func TestParseSSEChunkForContentMapValueDirectFragments(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "RESPONSE", + "content": "直接回答", + }, + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + if len(parts) != 1 || parts[0].Text != "直接回答" || parts[0].Type != "text" { + t.Fatalf("unexpected parts for direct fragments: %#v", parts) + } +} + +func TestParseSSEChunkForContentMapValueUnknownType(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "CUSTOM", + "content": "custom content", + }, + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected partType fallback for unknown type, got %#v", parts) + } +} + +func TestParseSSEChunkForContentMapValueEmptyFragmentContent(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "RESPONSE", + "content": "", + }, + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for empty fragment content, got %#v", parts) + } +} + +// ─── ParseSSEChunkForContent: fragments/-1/content path ────────────── + +func TestParseSSEChunkForContentFragmentContentPathInheritsType(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments/-1/content", + "v": "继续思考", + } + parts, _, _ := ParseSSEChunkForContent(chunk, true, "thinking") + if len(parts) != 1 || parts[0].Type != "thinking" { + t.Fatalf("expected inherited thinking type, got %#v", parts) + } +} + +// ─── IsCitation edge cases ─────────────────────────────────────────── + +func TestIsCitationWithLeadingWhitespace(t *testing.T) { + if !IsCitation(" [citation:2] text") { + t.Fatal("expected citation true with leading whitespace") + } +} + +func TestIsCitationEmpty(t *testing.T) { + if IsCitation("") { + t.Fatal("expected citation false for empty string") + } +} + +func TestIsCitationSimilarPrefix(t *testing.T) { + if IsCitation("[cite:1] text") { + t.Fatal("expected citation false for [cite: prefix") + } +} + +// ─── extractContentRecursive edge cases ────────────────────────────── + +func TestExtractContentRecursiveFinishedStatus(t *testing.T) { + items := []any{ + map[string]any{"p": "status", "v": "FINISHED"}, + } + parts, finished := extractContentRecursive(items, "text") + if !finished { + t.Fatal("expected finished on status FINISHED") + } + if len(parts) != 0 { + t.Fatalf("expected no parts, got %#v", parts) + } +} + +func TestExtractContentRecursiveSkipsPath(t *testing.T) { + items := []any{ + map[string]any{"p": "token_usage", "v": "data"}, + } + parts, finished := extractContentRecursive(items, "text") + if finished { + t.Fatal("expected not finished") + } + if len(parts) != 0 { + t.Fatalf("expected no parts for skipped path, got %#v", parts) + } +} + +func TestExtractContentRecursiveContentField(t *testing.T) { + items := []any{ + map[string]any{"p": "x", "v": "val", "content": "actual content", "type": "RESPONSE"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 1 || parts[0].Text != "actual content" || parts[0].Type != "text" { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestExtractContentRecursiveContentFieldThinkType(t *testing.T) { + items := []any{ + map[string]any{"p": "x", "v": "val", "content": "think text", "type": "THINK"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 1 || parts[0].Type != "thinking" { + t.Fatalf("expected thinking type for THINK content, got %#v", parts) + } +} + +func TestExtractContentRecursiveThinkingPath(t *testing.T) { + items := []any{ + map[string]any{"p": "thinking_content", "v": "deep thought"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 1 || parts[0].Type != "thinking" || parts[0].Text != "deep thought" { + t.Fatalf("unexpected parts for thinking path: %#v", parts) + } +} + +func TestExtractContentRecursiveContentPath(t *testing.T) { + items := []any{ + map[string]any{"p": "content", "v": "text content"}, + } + parts, _ := extractContentRecursive(items, "thinking") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for content path, got %#v", parts) + } +} + +func TestExtractContentRecursiveResponsePath(t *testing.T) { + items := []any{ + map[string]any{"p": "response", "v": "text content"}, + } + parts, _ := extractContentRecursive(items, "thinking") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for response path, got %#v", parts) + } +} + +func TestExtractContentRecursiveFragmentsPath(t *testing.T) { + items := []any{ + map[string]any{"p": "fragments", "v": "fragment text"}, + } + parts, _ := extractContentRecursive(items, "thinking") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for fragments path, got %#v", parts) + } +} + +func TestExtractContentRecursiveNestedArrayWithTypes(t *testing.T) { + items := []any{ + map[string]any{ + "p": "fragments", + "v": []any{ + map[string]any{"content": "thought", "type": "THINKING"}, + map[string]any{"content": "answer", "type": "RESPONSE"}, + "raw string", + }, + }, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 3 { + t.Fatalf("expected 3 parts, got %d: %#v", len(parts), parts) + } + if parts[0].Type != "thinking" || parts[0].Text != "thought" { + t.Fatalf("first part mismatch: %#v", parts[0]) + } + if parts[1].Type != "text" || parts[1].Text != "answer" { + t.Fatalf("second part mismatch: %#v", parts[1]) + } + if parts[2].Type != "text" || parts[2].Text != "raw string" { + t.Fatalf("third part mismatch: %#v", parts[2]) + } +} + +func TestExtractContentRecursiveEmptyContentSkipped(t *testing.T) { + items := []any{ + map[string]any{ + "p": "fragments", + "v": []any{ + map[string]any{"content": "", "type": "RESPONSE"}, + }, + }, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for empty nested content, got %#v", parts) + } +} + +func TestExtractContentRecursiveFinishedString(t *testing.T) { + items := []any{ + map[string]any{"p": "content", "v": "FINISHED"}, + } + parts, _ := extractContentRecursive(items, "text") + // "FINISHED" string value on non-status path should be skipped + if len(parts) != 0 { + t.Fatalf("expected FINISHED string to be skipped, got %#v", parts) + } +} + +func TestExtractContentRecursiveNoVField(t *testing.T) { + items := []any{ + map[string]any{"p": "content"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for missing v field, got %#v", parts) + } +} + +func TestExtractContentRecursiveNonMapItem(t *testing.T) { + items := []any{"just a string", 42} + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for non-map items, got %#v", parts) + } +} diff --git a/internal/sse/stream_edge_test.go b/internal/sse/stream_edge_test.go new file mode 100644 index 0000000..927b023 --- /dev/null +++ b/internal/sse/stream_edge_test.go @@ -0,0 +1,177 @@ +package sse + +import ( + "context" + "io" + "strings" + "testing" +) + +func TestStartParsedLinePumpEmptyBody(t *testing.T) { + body := strings.NewReader("") + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + collected := make([]LineResult, 0) + for r := range results { + collected = append(collected, r) + } + if err := <-done; err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(collected) != 0 { + t.Fatalf("expected no results for empty body, got %d", len(collected)) + } +} + +func TestStartParsedLinePumpMultipleLines(t *testing.T) { + body := strings.NewReader( + "data: {\"p\":\"response/thinking_content\",\"v\":\"think\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"text\"}\n" + + "data: [DONE]\n", + ) + results, done := StartParsedLinePump(context.Background(), body, true, "thinking") + + collected := make([]LineResult, 0) + for r := range results { + collected = append(collected, r) + } + if err := <-done; err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(collected) < 3 { + t.Fatalf("expected at least 3 results, got %d", len(collected)) + } + // First should be thinking + if collected[0].Parts[0].Type != "thinking" { + t.Fatalf("expected first part thinking, got %q", collected[0].Parts[0].Type) + } + // Last should be stop + last := collected[len(collected)-1] + if !last.Stop { + t.Fatal("expected last result to be stop") + } +} + +func TestStartParsedLinePumpTypeTracking(t *testing.T) { + body := strings.NewReader( + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"思\"}]}\n" + + "data: {\"p\":\"response/fragments/-1/content\",\"v\":\"考\"}\n" + + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"答\"}]}\n" + + "data: {\"p\":\"response/fragments/-1/content\",\"v\":\"案\"}\n" + + "data: [DONE]\n", + ) + results, done := StartParsedLinePump(context.Background(), body, true, "text") + + types := make([]string, 0) + for r := range results { + for _, p := range r.Parts { + types = append(types, p.Type) + } + } + <-done + + // Should have: thinking, thinking, text, text + expected := []string{"thinking", "thinking", "text", "text"} + if len(types) != len(expected) { + t.Fatalf("expected types %v, got %v", expected, types) + } + for i, want := range expected { + if types[i] != want { + t.Fatalf("type[%d] mismatch: want %q got %q (all=%v)", i, want, types[i], types) + } + } +} + +func TestStartParsedLinePumpContextCancellation(t *testing.T) { + pr, pw := io.Pipe() + + ctx, cancel := context.WithCancel(context.Background()) + results, done := StartParsedLinePump(ctx, pr, false, "text") + + // Write one line to allow it to start + go func() { + _, _ = io.WriteString(pw, "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n") + // Don't close yet - wait for context cancel + }() + + // Read first result + r := <-results + if !r.Parsed || len(r.Parts) == 0 { + t.Fatalf("expected first parsed result, got %#v", r) + } + + // Cancel context - this will cause the pump to exit on next send + cancel() + // Close the pipe to unblock scanner.Scan() + pw.Close() + + // Drain remaining results + for range results { + } + + err := <-done + // Error may be context.Canceled or nil (if pipe closed first) + if err != nil && err != context.Canceled { + t.Fatalf("expected context.Canceled or nil error, got %v", err) + } +} + +func TestStartParsedLinePumpOnlyDONE(t *testing.T) { + body := strings.NewReader("data: [DONE]\n") + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + collected := make([]LineResult, 0) + for r := range results { + collected = append(collected, r) + } + <-done + + if len(collected) != 1 { + t.Fatalf("expected 1 result, got %d", len(collected)) + } + if !collected[0].Stop { + t.Fatal("expected stop on [DONE]") + } +} + +func TestStartParsedLinePumpNonSSELines(t *testing.T) { + body := strings.NewReader( + "event: update\n" + + ": comment line\n" + + "data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" + + "data: [DONE]\n", + ) + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + var validCount int + for r := range results { + if r.Parsed && len(r.Parts) > 0 { + validCount++ + } + } + <-done + + if validCount != 1 { + t.Fatalf("expected 1 valid result, got %d", validCount) + } +} + +func TestStartParsedLinePumpThinkingDisabled(t *testing.T) { + body := strings.NewReader( + "data: {\"p\":\"response/thinking_content\",\"v\":\"thought\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"response\"}\n" + + "data: [DONE]\n", + ) + // With thinking disabled, thinking content should still be emitted but marked differently + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + var parts []ContentPart + for r := range results { + parts = append(parts, r.Parts...) + } + <-done + + if len(parts) < 1 { + t.Fatalf("expected at least 1 part, got %d", len(parts)) + } +} diff --git a/internal/util/util_edge_test.go b/internal/util/util_edge_test.go new file mode 100644 index 0000000..393aa88 --- /dev/null +++ b/internal/util/util_edge_test.go @@ -0,0 +1,441 @@ +package util + +import ( + "encoding/json" + "net/http/httptest" + "strings" + "testing" + + "ds2api/internal/config" +) + +// ─── EstimateTokens edge cases ─────────────────────────────────────── + +func TestEstimateTokensEmpty(t *testing.T) { + if got := EstimateTokens(""); got != 0 { + t.Fatalf("expected 0 for empty string, got %d", got) + } +} + +func TestEstimateTokensShortASCII(t *testing.T) { + got := EstimateTokens("ab") + if got != 1 { + t.Fatalf("expected 1 for 2 ascii chars, got %d", got) + } +} + +func TestEstimateTokensLongASCII(t *testing.T) { + got := EstimateTokens(strings.Repeat("x", 100)) + if got != 25 { + t.Fatalf("expected 25 for 100 ascii chars, got %d", got) + } +} + +func TestEstimateTokensChinese(t *testing.T) { + got := EstimateTokens("你好世界") + if got < 1 { + t.Fatalf("expected at least 1 token for Chinese text, got %d", got) + } +} + +func TestEstimateTokensMixed(t *testing.T) { + got := EstimateTokens("Hello 你好世界") + if got < 2 { + t.Fatalf("expected at least 2 tokens for mixed text, got %d", got) + } +} + +func TestEstimateTokensSingleByte(t *testing.T) { + got := EstimateTokens("x") + if got != 1 { + t.Fatalf("expected 1 for single char (minimum), got %d", got) + } +} + +func TestEstimateTokensSingleChinese(t *testing.T) { + got := EstimateTokens("你") + if got != 1 { + t.Fatalf("expected 1 for single Chinese char, got %d", got) + } +} + +// ─── ToBool edge cases ─────────────────────────────────────────────── + +func TestToBoolTrue(t *testing.T) { + if !ToBool(true) { + t.Fatal("expected true") + } +} + +func TestToBoolFalse(t *testing.T) { + if ToBool(false) { + t.Fatal("expected false") + } +} + +func TestToBoolNonBool(t *testing.T) { + if ToBool("true") { + t.Fatal("expected false for string 'true'") + } + if ToBool(1) { + t.Fatal("expected false for int 1") + } + if ToBool(nil) { + t.Fatal("expected false for nil") + } +} + +// ─── IntFrom edge cases ───────────────────────────────────────────── + +func TestIntFromFloat64(t *testing.T) { + if got := IntFrom(float64(42.5)); got != 42 { + t.Fatalf("expected 42 for float64(42.5), got %d", got) + } +} + +func TestIntFromInt(t *testing.T) { + if got := IntFrom(int(42)); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestIntFromInt64(t *testing.T) { + if got := IntFrom(int64(42)); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestIntFromString(t *testing.T) { + if got := IntFrom("42"); got != 0 { + t.Fatalf("expected 0 for string, got %d", got) + } +} + +func TestIntFromNil(t *testing.T) { + if got := IntFrom(nil); got != 0 { + t.Fatalf("expected 0 for nil, got %d", got) + } +} + +// ─── WriteJSON ─────────────────────────────────────────────────────── + +func TestWriteJSON(t *testing.T) { + rec := httptest.NewRecorder() + WriteJSON(rec, 200, map[string]any{"key": "value"}) + if rec.Code != 200 { + t.Fatalf("expected 200, got %d", rec.Code) + } + if ct := rec.Header().Get("Content-Type"); ct != "application/json" { + t.Fatalf("expected application/json content type, got %q", ct) + } + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode error: %v", err) + } + if body["key"] != "value" { + t.Fatalf("unexpected body: %#v", body) + } +} + +func TestWriteJSONStatusCodes(t *testing.T) { + for _, code := range []int{200, 201, 400, 404, 500} { + rec := httptest.NewRecorder() + WriteJSON(rec, code, map[string]any{"status": code}) + if rec.Code != code { + t.Fatalf("expected %d, got %d", code, rec.Code) + } + } +} + +// ─── MessagesPrepare edge cases ────────────────────────────────────── + +func TestMessagesPrepareEmpty(t *testing.T) { + got := MessagesPrepare(nil) + if got != "" { + t.Fatalf("expected empty for nil messages, got %q", got) + } +} + +func TestMessagesPrepareMergesConsecutiveSameRole(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Hello"}, + {"role": "user", "content": "World"}, + } + got := MessagesPrepare(messages) + if !strings.Contains(got, "Hello") || !strings.Contains(got, "World") { + t.Fatalf("expected both messages, got %q", got) + } + // Should be merged without <|User|> between them + count := strings.Count(got, "<|User|>") + if count != 0 { + t.Fatalf("expected no User marker for first message pair, got %d occurrences", count) + } +} + +func TestMessagesPrepareAssistantMarkers(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + } + got := MessagesPrepare(messages) + if !strings.Contains(got, "<|Assistant|>") { + t.Fatalf("expected assistant marker, got %q", got) + } + if !strings.Contains(got, "<|end▁of▁sentence|>") { + t.Fatalf("expected end of sentence marker, got %q", got) + } +} + +func TestMessagesPrepareUnknownRole(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Hello"}, + {"role": "unknown_role", "content": "Unknown"}, + } + got := MessagesPrepare(messages) + if !strings.Contains(got, "Unknown") { + t.Fatalf("expected unknown role content, got %q", got) + } +} + +func TestMessagesPrepareMarkdownImageReplaced(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Look at this: ![alt](https://example.com/img.png)"}, + } + got := MessagesPrepare(messages) + if strings.Contains(got, "![alt]") { + t.Fatalf("expected markdown image to be replaced, got %q", got) + } +} + +func TestMessagesPrepareNilContent(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": nil}, + } + got := MessagesPrepare(messages) + if got != "null" { + t.Logf("nil content handled as: %q", got) + } +} + +// ─── normalizeContent edge cases ───────────────────────────────────── + +func TestNormalizeContentString(t *testing.T) { + got := normalizeContent("hello") + if got != "hello" { + t.Fatalf("expected 'hello', got %q", got) + } +} + +func TestNormalizeContentArray(t *testing.T) { + got := normalizeContent([]any{ + map[string]any{"type": "text", "text": "line1"}, + map[string]any{"type": "text", "text": "line2"}, + }) + if got != "line1\nline2" { + t.Fatalf("expected 'line1\\nline2', got %q", got) + } +} + +func TestNormalizeContentArrayWithContentField(t *testing.T) { + got := normalizeContent([]any{ + map[string]any{"type": "text", "content": "from-content"}, + }) + if got != "from-content" { + t.Fatalf("expected 'from-content', got %q", got) + } +} + +func TestNormalizeContentArraySkipsImage(t *testing.T) { + got := normalizeContent([]any{ + map[string]any{"type": "image_url", "image_url": "https://example.com/img.png"}, + map[string]any{"type": "text", "text": "caption"}, + }) + if strings.Contains(got, "image") { + t.Fatalf("expected image skipped, got %q", got) + } + if got != "caption" { + t.Fatalf("expected 'caption', got %q", got) + } +} + +func TestNormalizeContentArrayNonMapItems(t *testing.T) { + got := normalizeContent([]any{"string item", 42}) + if got != "" { + t.Fatalf("expected empty for non-map items, got %q", got) + } +} + +func TestNormalizeContentJSON(t *testing.T) { + got := normalizeContent(map[string]any{"key": "value"}) + if !strings.Contains(got, `"key":"value"`) { + t.Fatalf("expected JSON serialized, got %q", got) + } +} + +// ─── ConvertClaudeToDeepSeek edge cases ────────────────────────────── + +func TestConvertClaudeToDeepSeekDefaultModel(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["model"] == "" { + t.Fatal("expected default model") + } +} + +func TestConvertClaudeToDeepSeekWithStopSequences(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + "stop_sequences": []any{"\n\n"}, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["stop"] == nil { + t.Fatal("expected stop field from stop_sequences") + } +} + +func TestConvertClaudeToDeepSeekWithTemperature(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + "temperature": 0.7, + "top_p": 0.9, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["temperature"] != 0.7 { + t.Fatalf("expected temperature 0.7, got %v", out["temperature"]) + } + if out["top_p"] != 0.9 { + t.Fatalf("expected top_p 0.9, got %v", out["top_p"]) + } +} + +func TestConvertClaudeToDeepSeekNoSystem(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + } + out := ConvertClaudeToDeepSeek(req, store) + msgs, _ := out["messages"].([]any) + if len(msgs) != 1 { + t.Fatalf("expected 1 message without system, got %d", len(msgs)) + } +} + +func TestConvertClaudeToDeepSeekOpusUsesSlowMapping(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-opus-4-6", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["model"] != "deepseek-reasoner" { + t.Fatalf("expected opus to use slow mapping, got %q", out["model"]) + } +} + +// ─── FormatOpenAIStreamToolCalls ───────────────────────────────────── + +func TestFormatOpenAIStreamToolCalls(t *testing.T) { + formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{ + {Name: "search", Input: map[string]any{"q": "test"}}, + }) + if len(formatted) != 1 { + t.Fatalf("expected 1, got %d", len(formatted)) + } + fn, _ := formatted[0]["function"].(map[string]any) + if fn["name"] != "search" { + t.Fatalf("unexpected function name: %#v", fn) + } + if formatted[0]["index"] != 0 { + t.Fatalf("expected index 0, got %v", formatted[0]["index"]) + } +} + +// ─── ParseToolCalls more edge cases ────────────────────────────────── + +func TestParseToolCallsNoToolNames(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + calls := ParseToolCalls(text, nil) + if len(calls) != 1 { + t.Fatalf("expected 1 call with nil tool names, got %d", len(calls)) + } +} + +func TestParseToolCallsEmptyText(t *testing.T) { + calls := ParseToolCalls("", []string{"search"}) + if len(calls) != 0 { + t.Fatalf("expected 0 calls for empty text, got %d", len(calls)) + } +} + +func TestParseToolCallsMultipleTools(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":{"q":"go"}},{"name":"get_weather","input":{"city":"beijing"}}]}` + calls := ParseToolCalls(text, []string{"search", "get_weather"}) + if len(calls) != 2 { + t.Fatalf("expected 2 calls, got %d", len(calls)) + } +} + +func TestParseToolCallsInputAsString(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":"{\"q\":\"golang\"}"}]}` + calls := ParseToolCalls(text, []string{"search"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %d", len(calls)) + } + if calls[0].Input["q"] != "golang" { + t.Fatalf("expected parsed string input, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsWithFunctionWrapper(t *testing.T) { + text := `{"tool_calls":[{"function":{"name":"calc","arguments":{"x":1,"y":2}}}]}` + calls := ParseToolCalls(text, []string{"calc"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %d", len(calls)) + } + if calls[0].Name != "calc" { + t.Fatalf("expected calc, got %q", calls[0].Name) + } +} + +func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) { + fenced := "Here's an example:\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\nDon't execute this." + calls := ParseStandaloneToolCalls(fenced, []string{"search"}) + if len(calls) != 0 { + t.Fatalf("expected fenced code block ignored, got %d calls", len(calls)) + } +} + +// ─── looksLikeToolExampleContext ───────────────────────────────────── + +func TestLooksLikeToolExampleContextChinese(t *testing.T) { + if !looksLikeToolExampleContext("下面是示例") { + t.Fatal("expected true for Chinese example context") + } +} + +func TestLooksLikeToolExampleContextEnglish(t *testing.T) { + if !looksLikeToolExampleContext("here is an example of") { + t.Fatal("expected true for English example context") + } +} + +func TestLooksLikeToolExampleContextNone(t *testing.T) { + if looksLikeToolExampleContext("I will call the tool now") { + t.Fatal("expected false for non-example context") + } +} + +func TestLooksLikeToolExampleContextFenced(t *testing.T) { + if !looksLikeToolExampleContext("```json") { + t.Fatal("expected true for fenced code block context") + } +}