From fe8a6bd3cd8ec6b08eef9c84f5737402210fcd95 Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Wed, 22 Apr 2026 16:22:04 +0000 Subject: [PATCH] refactor: improve chat history persistence reliability with metadata-only migration, error handling, and optimized file updates --- internal/adapter/openai/chat_history.go | 142 +++++++-------- internal/adapter/openai/chat_history_test.go | 99 ++++++++++ internal/chathistory/store.go | 180 ++++++++++++------- internal/chathistory/store_test.go | 176 ++++++++++++++++++ internal/sse/consumer.go | 9 +- internal/sse/consumer_edge_test.go | 34 ++++ 6 files changed, 490 insertions(+), 150 deletions(-) diff --git a/internal/adapter/openai/chat_history.go b/internal/adapter/openai/chat_history.go index b6435de..e71ede5 100644 --- a/internal/adapter/openai/chat_history.go +++ b/internal/adapter/openai/chat_history.go @@ -1,6 +1,7 @@ package openai import ( + "errors" "net/http" "strings" "time" @@ -45,10 +46,6 @@ func startChatHistory(store *chathistory.Store, r *http.Request, a *auth.Request Messages: extractAllMessages(stdReq.Messages), FinalPrompt: stdReq.FinalPrompt, }) - if err != nil { - config.Logger.Warn("[chat_history] start failed", "error", err) - return nil - } startParams := chathistory.StartParams{ CallerID: strings.TrimSpace(a.CallerID), AccountID: strings.TrimSpace(a.AccountID), @@ -58,7 +55,7 @@ func startChatHistory(store *chathistory.Store, r *http.Request, a *auth.Request Messages: extractAllMessages(stdReq.Messages), FinalPrompt: stdReq.FinalPrompt, } - return &chatHistorySession{ + session := &chatHistorySession{ store: store, entryID: entry.ID, startedAt: time.Now(), @@ -66,6 +63,14 @@ func startChatHistory(store *chathistory.Store, r *http.Request, a *auth.Request finalPrompt: stdReq.FinalPrompt, startParams: startParams, } + if err != nil { + if entry.ID == "" { + config.Logger.Warn("[chat_history] start failed", "error", err) + return nil + } + config.Logger.Warn("[chat_history] start persisted in memory after write failure", "error", err) + } + return session } func shouldCaptureChatHistory(r *http.Request) bool { @@ -124,33 +129,20 @@ func (s *chatHistorySession) progress(thinking, content string) { return } s.lastPersist = now - if _, err := s.store.Update(s.entryID, chathistory.UpdateParams{ + s.persistUpdate(chathistory.UpdateParams{ Status: "streaming", ReasoningContent: thinking, Content: content, StatusCode: http.StatusOK, ElapsedMs: time.Since(s.startedAt).Milliseconds(), - }); err != nil { - if !s.retryMissingEntry() { - s.disableOnMissing(err) - return - } - _, retryErr := s.store.Update(s.entryID, chathistory.UpdateParams{ - Status: "streaming", - ReasoningContent: thinking, - Content: content, - StatusCode: http.StatusOK, - ElapsedMs: time.Since(s.startedAt).Milliseconds(), - }) - s.disableOnMissing(retryErr) - } + }) } func (s *chatHistorySession) success(statusCode int, thinking, content, finishReason string, usage map[string]any) { if s == nil || s.store == nil || s.disabled { return } - if _, err := s.store.Update(s.entryID, chathistory.UpdateParams{ + s.persistUpdate(chathistory.UpdateParams{ Status: "success", ReasoningContent: thinking, Content: content, @@ -159,30 +151,14 @@ func (s *chatHistorySession) success(statusCode int, thinking, content, finishRe FinishReason: finishReason, Usage: usage, Completed: true, - }); err != nil { - if !s.retryMissingEntry() { - s.disableOnMissing(err) - return - } - _, retryErr := s.store.Update(s.entryID, chathistory.UpdateParams{ - Status: "success", - ReasoningContent: thinking, - Content: content, - StatusCode: statusCode, - ElapsedMs: time.Since(s.startedAt).Milliseconds(), - FinishReason: finishReason, - Usage: usage, - Completed: true, - }) - s.disableOnMissing(retryErr) - } + }) } func (s *chatHistorySession) error(statusCode int, message, finishReason, thinking, content string) { if s == nil || s.store == nil || s.disabled { return } - if _, err := s.store.Update(s.entryID, chathistory.UpdateParams{ + s.persistUpdate(chathistory.UpdateParams{ Status: "error", ReasoningContent: thinking, Content: content, @@ -191,30 +167,14 @@ func (s *chatHistorySession) error(statusCode int, message, finishReason, thinki ElapsedMs: time.Since(s.startedAt).Milliseconds(), FinishReason: finishReason, Completed: true, - }); err != nil { - if !s.retryMissingEntry() { - s.disableOnMissing(err) - return - } - _, retryErr := s.store.Update(s.entryID, chathistory.UpdateParams{ - Status: "error", - ReasoningContent: thinking, - Content: content, - Error: message, - StatusCode: statusCode, - ElapsedMs: time.Since(s.startedAt).Milliseconds(), - FinishReason: finishReason, - Completed: true, - }) - s.disableOnMissing(retryErr) - } + }) } func (s *chatHistorySession) stopped(thinking, content, finishReason string) { if s == nil || s.store == nil || s.disabled { return } - if _, err := s.store.Update(s.entryID, chathistory.UpdateParams{ + s.persistUpdate(chathistory.UpdateParams{ Status: "stopped", ReasoningContent: thinking, Content: content, @@ -223,23 +183,7 @@ func (s *chatHistorySession) stopped(thinking, content, finishReason string) { FinishReason: finishReason, Usage: openaifmt.BuildChatUsage(s.finalPrompt, thinking, content), Completed: true, - }); err != nil { - if !s.retryMissingEntry() { - s.disableOnMissing(err) - return - } - _, retryErr := s.store.Update(s.entryID, chathistory.UpdateParams{ - Status: "stopped", - ReasoningContent: thinking, - Content: content, - StatusCode: http.StatusOK, - ElapsedMs: time.Since(s.startedAt).Milliseconds(), - FinishReason: finishReason, - Usage: openaifmt.BuildChatUsage(s.finalPrompt, thinking, content), - Completed: true, - }) - s.disableOnMissing(retryErr) - } + }) } func (s *chatHistorySession) retryMissingEntry() bool { @@ -247,22 +191,60 @@ func (s *chatHistorySession) retryMissingEntry() bool { return false } entry, err := s.store.Start(s.startParams) - if err != nil { - s.disableOnMissing(err) + if errors.Is(err, chathistory.ErrDisabled) { + s.disabled = true + return false + } + if entry.ID == "" { + if err != nil { + config.Logger.Warn("[chat_history] recreate missing entry failed", "error", err) + } return false } s.entryID = entry.ID + if err != nil { + config.Logger.Warn("[chat_history] recreate missing entry persisted in memory after write failure", "error", err) + } return true } -func (s *chatHistorySession) disableOnMissing(err error) { +func (s *chatHistorySession) persistUpdate(params chathistory.UpdateParams) { + if s == nil || s.store == nil || s.disabled { + return + } + if _, err := s.store.Update(s.entryID, params); err != nil { + s.handlePersistError(params, err) + } +} + +func (s *chatHistorySession) handlePersistError(params chathistory.UpdateParams, err error) { if err == nil || s == nil { return } - if strings.Contains(strings.ToLower(err.Error()), "not found") { + if errors.Is(err, chathistory.ErrDisabled) { s.disabled = true return } - config.Logger.Warn("[chat_history] update disabled", "error", err) - s.disabled = true + if isChatHistoryMissingError(err) { + if s.retryMissingEntry() { + if _, retryErr := s.store.Update(s.entryID, params); retryErr != nil { + if errors.Is(retryErr, chathistory.ErrDisabled) || isChatHistoryMissingError(retryErr) { + s.disabled = true + return + } + config.Logger.Warn("[chat_history] retry after missing entry failed", "error", retryErr) + } + return + } + s.disabled = true + return + } + config.Logger.Warn("[chat_history] update failed", "error", err) +} + +func isChatHistoryMissingError(err error) bool { + if err == nil { + return false + } + return strings.Contains(strings.ToLower(err.Error()), "not found") } diff --git a/internal/adapter/openai/chat_history_test.go b/internal/adapter/openai/chat_history_test.go index 4500e1c..5e5d6a0 100644 --- a/internal/adapter/openai/chat_history_test.go +++ b/internal/adapter/openai/chat_history_test.go @@ -4,12 +4,16 @@ import ( "context" "net/http" "net/http/httptest" + "os" "path/filepath" "strings" + "sync" "testing" "time" + "ds2api/internal/auth" "ds2api/internal/chathistory" + "ds2api/internal/util" ) func newTestChatHistoryStore(t *testing.T) *chathistory.Store { @@ -21,6 +25,35 @@ func newTestChatHistoryStore(t *testing.T) *chathistory.Store { return store } +func blockChatHistoryDetailDir(t *testing.T, detailDir string) func() { + t.Helper() + blockedDir := detailDir + ".blocked" + if err := os.RemoveAll(blockedDir); err != nil { + t.Fatalf("remove blocked detail dir failed: %v", err) + } + if err := os.Rename(detailDir, blockedDir); err != nil { + t.Fatalf("move detail dir aside failed: %v", err) + } + if err := os.RemoveAll(detailDir); err != nil { + t.Fatalf("remove blocked detail path failed: %v", err) + } + if err := os.WriteFile(detailDir, []byte("blocked"), 0o644); err != nil { + t.Fatalf("write blocked detail path failed: %v", err) + } + var once sync.Once + return func() { + t.Helper() + once.Do(func() { + if err := os.RemoveAll(detailDir); err != nil { + t.Fatalf("remove blocking detail path failed: %v", err) + } + if err := os.Rename(blockedDir, detailDir); err != nil { + t.Fatalf("restore detail dir failed: %v", err) + } + }) + } +} + func TestChatCompletionsNonStreamPersistsHistory(t *testing.T) { historyStore := newTestChatHistoryStore(t) h := &Handler{ @@ -69,6 +102,72 @@ func TestChatCompletionsNonStreamPersistsHistory(t *testing.T) { } } +func TestStartChatHistoryRecoversFromTransientWriteFailure(t *testing.T) { + historyStore := newTestChatHistoryStore(t) + restore := blockChatHistoryDetailDir(t, historyStore.DetailDir()) + t.Cleanup(restore) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + a := &auth.RequestAuth{ + CallerID: "caller:test", + AccountID: "acct:test", + } + stdReq := util.StandardRequest{ + ResponseModel: "deepseek-chat", + Stream: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + FinalPrompt: "hello", + } + + session := startChatHistory(historyStore, req, a, stdReq) + if session == nil { + t.Fatalf("expected session even when initial persistence fails") + } + if session.disabled { + t.Fatalf("expected session to remain active after transient start failure") + } + if session.entryID == "" { + t.Fatalf("expected session entry id to be retained") + } + if err := historyStore.Err(); err != nil { + t.Fatalf("transient start failure should not latch store error: %v", err) + } + + session.lastPersist = time.Now().Add(-time.Second) + session.progress("thinking", "partial") + if session.disabled { + t.Fatalf("expected session to remain active after transient update failure") + } + if session.entryID == "" { + t.Fatalf("expected session entry id to remain set after update failure") + } + if err := historyStore.Err(); err != nil { + t.Fatalf("transient update failure should not latch store error: %v", err) + } + + restore() + + session.success(http.StatusOK, "thinking", "final answer", "stop", map[string]any{"total_tokens": 7}) + snapshot, err := historyStore.Snapshot() + if err != nil { + t.Fatalf("snapshot failed after restore: %v", err) + } + if len(snapshot.Items) != 1 { + t.Fatalf("expected one persisted item after restore, got %#v", snapshot.Items) + } + full, err := historyStore.Get(session.entryID) + if err != nil { + t.Fatalf("get restored entry failed: %v", err) + } + if full.Status != "success" || full.Content != "final answer" { + t.Fatalf("expected restored entry to persist final success, got %#v", full) + } +} + func TestHandleStreamContextCancelledMarksHistoryStopped(t *testing.T) { historyStore := newTestChatHistoryStore(t) entry, err := historyStore.Start(chathistory.StartParams{ diff --git a/internal/chathistory/store.go b/internal/chathistory/store.go index 716b953..711b001 100644 --- a/internal/chathistory/store.go +++ b/internal/chathistory/store.go @@ -118,12 +118,18 @@ type legacyFile struct { Items []Entry `json:"items"` } +type legacyProbe struct { + Items []map[string]json.RawMessage `json:"items"` +} + type Store struct { mu sync.Mutex path string detailDir string state File details map[string]Entry + dirty map[string]struct{} + deleted map[string]struct{} err error } @@ -138,6 +144,8 @@ func New(path string) *Store { Items: []SummaryEntry{}, }, details: map[string]Entry{}, + dirty: map[string]struct{}{}, + deleted: map[string]struct{}{}, } s.mu.Lock() defer s.mu.Unlock() @@ -237,9 +245,10 @@ func (s *Store) Start(params StartParams) (Entry, error) { FinalPrompt: strings.TrimSpace(params.FinalPrompt), } s.details[entry.ID] = entry + s.markDetailDirtyLocked(entry.ID) s.rebuildIndexLocked() if err := s.saveLocked(); err != nil { - return Entry{}, err + return cloneEntry(entry), err } return cloneEntry(entry), nil } @@ -280,6 +289,7 @@ func (s *Store) Update(id string, params UpdateParams) (Entry, error) { item.CompletedAt = now } s.details[target] = item + s.markDetailDirtyLocked(target) s.rebuildIndexLocked() if err := s.saveLocked(); err != nil { return Entry{}, err @@ -303,6 +313,7 @@ func (s *Store) Delete(id string) error { if _, ok := s.details[target]; !ok { return errors.New("chat history entry not found") } + s.markDetailDeletedLocked(target) delete(s.details, target) s.nextRevisionLocked() s.rebuildIndexLocked() @@ -321,6 +332,9 @@ func (s *Store) Clear() error { if s.err != nil { return s.err } + for id := range s.details { + s.markDetailDeletedLocked(id) + } s.details = map[string]Entry{} s.nextRevisionLocked() s.rebuildIndexLocked() @@ -374,7 +388,7 @@ func (s *Store) loadLocked() error { if legacyErr != nil { return legacyErr } - if legacyOK && !hasDetailFiles(s.detailDir) { + if legacyOK { s.loadLegacyLocked(legacy) return s.saveLocked() } @@ -409,6 +423,8 @@ func (s *Store) loadLegacyLocked(legacy legacyFile) { s.state.Limit = DefaultLimit } s.details = map[string]Entry{} + s.dirty = map[string]struct{}{} + s.deleted = map[string]struct{}{} maxRevision := int64(0) for _, item := range legacy.Items { if strings.TrimSpace(item.ID) == "" { @@ -426,6 +442,7 @@ func (s *Store) loadLegacyLocked(legacy legacyFile) { maxRevision = item.Revision } s.details[item.ID] = item + s.markDetailDirtyLocked(item.ID) } s.state.Revision = maxRevision s.rebuildIndexLocked() @@ -439,41 +456,40 @@ func (s *Store) saveLocked() error { s.rebuildIndexLocked() if err := os.MkdirAll(s.detailDir, 0o755); err != nil { - s.err = err - return err + return fmt.Errorf("create chat history detail dir: %w", err) } - activeFiles := make(map[string]struct{}, len(s.details)) - for id, item := range s.details { + for _, id := range sortedDetailIDs(s.deleted) { + path := filepath.Join(s.detailDir, id+".json") + if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("remove stale chat history detail: %w", err) + } + } + for _, id := range sortedDetailIDs(s.dirty) { + item, ok := s.details[id] + if !ok { + continue + } path := filepath.Join(s.detailDir, id+".json") - activeFiles[path] = struct{}{} payload, err := json.MarshalIndent(detailEnvelope{ Version: FileVersion, Item: item, }, "", " ") if err != nil { - s.err = err - return err + return fmt.Errorf("encode chat history detail: %w", err) } if err := writeFileAtomic(path, append(payload, '\n')); err != nil { - s.err = err return err } } - if err := cleanupDetailDir(s.detailDir, activeFiles); err != nil { - s.err = err - return err - } payload, err := json.MarshalIndent(s.state, "", " ") if err != nil { - s.err = err - return err + return fmt.Errorf("encode chat history index: %w", err) } if err := writeFileAtomic(s.path, append(payload, '\n')); err != nil { - s.err = err return err } - s.err = nil + s.clearPendingDetailChangesLocked() return nil } @@ -502,6 +518,7 @@ func (s *Store) rebuildIndexLocked() { } for id := range s.details { if _, ok := keep[id]; !ok { + s.markDetailDeletedLocked(id) delete(s.details, id) } } @@ -569,22 +586,6 @@ func readDetailFile(path string) (Entry, error) { return cloneEntry(env.Item), nil } -func hasDetailFiles(dir string) bool { - entries, err := os.ReadDir(dir) - if err != nil { - return false - } - for _, entry := range entries { - if entry.IsDir() { - continue - } - if strings.HasSuffix(strings.ToLower(entry.Name()), ".json") { - return true - } - } - return false -} - func parseLegacy(raw []byte) (legacyFile, bool, error) { var legacy legacyFile if err := json.Unmarshal(raw, &legacy); err != nil { @@ -593,32 +594,15 @@ func parseLegacy(raw []byte) (legacyFile, bool, error) { if len(legacy.Items) == 0 { return legacy, false, nil } - for _, item := range legacy.Items { - if item.Content != "" || item.ReasoningContent != "" || item.FinalPrompt != "" || len(item.Messages) > 0 { - return legacy, true, nil + var probe legacyProbe + if err := json.Unmarshal(raw, &probe); err == nil { + for _, item := range probe.Items { + if _, ok := item["detail_revision"]; ok { + return legacy, false, nil + } } } - return legacy, false, nil -} - -func cleanupDetailDir(dir string, active map[string]struct{}) error { - entries, err := os.ReadDir(dir) - if err != nil { - return fmt.Errorf("list chat history detail dir: %w", err) - } - for _, entry := range entries { - if entry.IsDir() { - continue - } - path := filepath.Join(dir, entry.Name()) - if _, ok := active[path]; ok { - continue - } - if err := os.Remove(path); err != nil { - return fmt.Errorf("remove stale chat history detail: %w", err) - } - } - return nil + return legacy, true, nil } func writeFileAtomic(path string, body []byte) error { @@ -636,25 +620,38 @@ func writeFileAtomic(path string, body []byte) error { return fmt.Errorf("create temp chat history: %w", err) } tmpPath := tmpFile.Name() - cleanup := func() { - _ = os.Remove(tmpPath) + cleanup := func() error { + if err := os.Remove(tmpPath); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("remove temp chat history: %w", err) + } + return nil + } + withCleanup := func(primary error, closeErr error) error { + errs := []error{primary} + if closeErr != nil { + errs = append(errs, fmt.Errorf("close temp chat history: %w", closeErr)) + } + if cleanupErr := cleanup(); cleanupErr != nil { + errs = append(errs, cleanupErr) + } + return errors.Join(errs...) } if _, err := tmpFile.Write(body); err != nil { - _ = tmpFile.Close() - cleanup() - return fmt.Errorf("write temp chat history: %w", err) + return withCleanup(fmt.Errorf("write temp chat history: %w", err), tmpFile.Close()) } if err := tmpFile.Sync(); err != nil { - _ = tmpFile.Close() - cleanup() - return fmt.Errorf("sync temp chat history: %w", err) + return withCleanup(fmt.Errorf("sync temp chat history: %w", err), tmpFile.Close()) } if err := tmpFile.Close(); err != nil { - cleanup() + if cleanupErr := cleanup(); cleanupErr != nil { + return errors.Join(fmt.Errorf("close temp chat history: %w", err), cleanupErr) + } return fmt.Errorf("close temp chat history: %w", err) } if err := os.Rename(tmpPath, path); err != nil { - cleanup() + if cleanupErr := cleanup(); cleanupErr != nil { + return errors.Join(fmt.Errorf("promote temp chat history: %w", err), cleanupErr) + } return fmt.Errorf("promote temp chat history: %w", err) } return nil @@ -673,6 +670,53 @@ func isAllowedLimit(limit int) bool { return ok } +func (s *Store) markDetailDirtyLocked(id string) { + id = strings.TrimSpace(id) + if id == "" { + return + } + if s.dirty == nil { + s.dirty = map[string]struct{}{} + } + if s.deleted == nil { + s.deleted = map[string]struct{}{} + } + s.dirty[id] = struct{}{} + delete(s.deleted, id) +} + +func (s *Store) markDetailDeletedLocked(id string) { + id = strings.TrimSpace(id) + if id == "" { + return + } + if s.dirty == nil { + s.dirty = map[string]struct{}{} + } + if s.deleted == nil { + s.deleted = map[string]struct{}{} + } + s.deleted[id] = struct{}{} + delete(s.dirty, id) +} + +func (s *Store) clearPendingDetailChangesLocked() { + s.dirty = map[string]struct{}{} + s.deleted = map[string]struct{}{} +} + +func sortedDetailIDs(ids map[string]struct{}) []string { + if len(ids) == 0 { + return nil + } + out := make([]string, 0, len(ids)) + for id := range ids { + out = append(out, id) + } + sort.Strings(out) + return out +} + func cloneFile(in File) File { out := File{ Version: in.Version, diff --git a/internal/chathistory/store_test.go b/internal/chathistory/store_test.go index d88d32f..a5830bd 100644 --- a/internal/chathistory/store_test.go +++ b/internal/chathistory/store_test.go @@ -1,6 +1,7 @@ package chathistory import ( + "bytes" "encoding/json" "os" "path/filepath" @@ -8,6 +9,35 @@ import ( "testing" ) +func blockDetailDir(t *testing.T, detailDir string) func() { + t.Helper() + blockedDir := detailDir + ".blocked" + if err := os.RemoveAll(blockedDir); err != nil { + t.Fatalf("remove blocked detail dir failed: %v", err) + } + if err := os.Rename(detailDir, blockedDir); err != nil { + t.Fatalf("move detail dir aside failed: %v", err) + } + if err := os.RemoveAll(detailDir); err != nil { + t.Fatalf("remove blocked detail path failed: %v", err) + } + if err := os.WriteFile(detailDir, []byte("blocked"), 0o644); err != nil { + t.Fatalf("write blocked detail path failed: %v", err) + } + var once sync.Once + return func() { + t.Helper() + once.Do(func() { + if err := os.RemoveAll(detailDir); err != nil { + t.Fatalf("remove blocking detail path failed: %v", err) + } + if err := os.Rename(blockedDir, detailDir); err != nil { + t.Fatalf("restore detail dir failed: %v", err) + } + }) + } +} + func TestStoreCreatesAndPersistsEntries(t *testing.T) { path := filepath.Join(t.TempDir(), "chat_history.json") store := New(path) @@ -254,3 +284,149 @@ func TestStoreAutoMigratesLegacyMonolith(t *testing.T) { t.Fatalf("expected migrated detail content preserved, got %#v", full) } } + +func TestStoreAutoMigratesMetadataOnlyLegacyMonolith(t *testing.T) { + path := filepath.Join(t.TempDir(), "chat_history.json") + legacy := legacyFile{ + Version: 1, + Limit: 20, + Items: []Entry{{ + ID: "chat_metadata_only", + Revision: 0, + CreatedAt: 1, + UpdatedAt: 2, + Status: "error", + CallerID: "caller:test", + AccountID: "acct:test", + Model: "deepseek-chat", + Stream: true, + UserInput: "hello", + Error: "boom", + StatusCode: 500, + ElapsedMs: 12, + FinishReason: "error", + }}, + } + body, _ := json.MarshalIndent(legacy, "", " ") + if err := os.WriteFile(path, body, 0o644); err != nil { + t.Fatalf("write legacy file failed: %v", err) + } + + store := New(path) + if err := store.Err(); err != nil { + t.Fatalf("expected legacy metadata-only migration success, got %v", err) + } + snapshot, err := store.Snapshot() + if err != nil { + t.Fatalf("snapshot failed: %v", err) + } + if len(snapshot.Items) != 1 { + t.Fatalf("expected one migrated summary, got %#v", snapshot.Items) + } + full, err := store.Get("chat_metadata_only") + if err != nil { + t.Fatalf("get migrated detail failed: %v", err) + } + if full.Error != "boom" || full.UserInput != "hello" { + t.Fatalf("expected metadata-only legacy fields preserved, got %#v", full) + } + if _, err := os.Stat(filepath.Join(store.DetailDir(), "chat_metadata_only.json")); err != nil { + t.Fatalf("expected migrated detail file to exist: %v", err) + } +} + +func TestStoreTransientPersistenceFailureDoesNotLatch(t *testing.T) { + path := filepath.Join(t.TempDir(), "chat_history.json") + store := New(path) + + first, err := store.Start(StartParams{UserInput: "first"}) + if err != nil { + t.Fatalf("start first failed: %v", err) + } + restore := blockDetailDir(t, store.DetailDir()) + t.Cleanup(restore) + + blocked, err := store.Start(StartParams{UserInput: "blocked"}) + if err == nil { + t.Fatalf("expected start failure while detail dir is blocked") + } + if blocked.ID == "" { + t.Fatalf("expected in-memory entry from failed start") + } + if err := store.Err(); err != nil { + t.Fatalf("transient start failure should not latch store error: %v", err) + } + if _, err := store.Update(first.ID, UpdateParams{Status: "success", Content: "one", Completed: true}); err == nil { + t.Fatalf("expected update failure while detail dir is blocked") + } + if err := store.Err(); err != nil { + t.Fatalf("transient update failure should not latch store error: %v", err) + } + + restore() + + if _, err := store.Update(blocked.ID, UpdateParams{Status: "success", Content: "two", Completed: true}); err != nil { + t.Fatalf("update after restore failed: %v", err) + } + if _, err := store.Start(StartParams{UserInput: "later"}); err != nil { + t.Fatalf("start after restore failed: %v", err) + } + full, err := store.Get(blocked.ID) + if err != nil { + t.Fatalf("get restored entry failed: %v", err) + } + if full.Content != "two" || full.Status != "success" { + t.Fatalf("expected restored entry persisted, got %#v", full) + } +} + +func TestStoreWritesOnlyChangedDetailFiles(t *testing.T) { + path := filepath.Join(t.TempDir(), "chat_history.json") + store := New(path) + + first, err := store.Start(StartParams{UserInput: "one"}) + if err != nil { + t.Fatalf("start first failed: %v", err) + } + if _, err := store.Update(first.ID, UpdateParams{Status: "success", Content: "first", Completed: true}); err != nil { + t.Fatalf("update first failed: %v", err) + } + second, err := store.Start(StartParams{UserInput: "two"}) + if err != nil { + t.Fatalf("start second failed: %v", err) + } + if _, err := store.Update(second.ID, UpdateParams{Status: "success", Content: "second", Completed: true}); err != nil { + t.Fatalf("update second failed: %v", err) + } + + firstPath := filepath.Join(store.DetailDir(), first.ID+".json") + secondPath := filepath.Join(store.DetailDir(), second.ID+".json") + beforeFirst, err := os.ReadFile(firstPath) + if err != nil { + t.Fatalf("read first detail before update failed: %v", err) + } + beforeSecond, err := os.ReadFile(secondPath) + if err != nil { + t.Fatalf("read second detail before update failed: %v", err) + } + + if _, err := store.Update(first.ID, UpdateParams{Status: "success", Content: "first-updated", Completed: true}); err != nil { + t.Fatalf("update first again failed: %v", err) + } + + afterFirst, err := os.ReadFile(firstPath) + if err != nil { + t.Fatalf("read first detail after update failed: %v", err) + } + afterSecond, err := os.ReadFile(secondPath) + if err != nil { + t.Fatalf("read second detail after update failed: %v", err) + } + + if bytes.Equal(beforeFirst, afterFirst) { + t.Fatalf("expected first detail file to change after update") + } + if !bytes.Equal(beforeSecond, afterSecond) { + t.Fatalf("expected untouched detail file to remain byte-identical") + } +} diff --git a/internal/sse/consumer.go b/internal/sse/consumer.go index 11dc291..0af4746 100644 --- a/internal/sse/consumer.go +++ b/internal/sse/consumer.go @@ -36,9 +36,13 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co currentType = "thinking" } _ = deepseek.ScanSSELines(resp, func(line []byte) bool { - if chunk, done, parsed := ParseDeepSeekSSELine(line); parsed && !done { + chunk, done, parsed := ParseDeepSeekSSELine(line) + if parsed && !done { collector.ingestChunk(chunk) } + if done { + return false + } if stopped { return true } @@ -52,7 +56,8 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co contentFilter = true } // Keep scanning to collect late-arriving citation metadata lines - // that can appear after response/status=FINISHED. + // that can appear after response/status=FINISHED, but stop as soon + // as [DONE] arrives. stopped = true return true } diff --git a/internal/sse/consumer_edge_test.go b/internal/sse/consumer_edge_test.go index 9e751c7..99679c5 100644 --- a/internal/sse/consumer_edge_test.go +++ b/internal/sse/consumer_edge_test.go @@ -5,6 +5,7 @@ import ( "net/http" "strings" "testing" + "time" ) // ─── CollectStream edge cases ──────────────────────────────────────── @@ -227,6 +228,39 @@ func TestCollectStreamStatusFinished(t *testing.T) { } } +func TestCollectStreamStopsOnDoneAfterFinished(t *testing.T) { + pr, pw := io.Pipe() + defer func() { _ = pw.Close() }() + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: pr, + } + + resultCh := make(chan CollectResult, 1) + go func() { + resultCh <- CollectStream(resp, false, false) + }() + + _, _ = io.WriteString(pw, "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n") + _, _ = io.WriteString(pw, "data: {\"p\":\"response/status\",\"v\":\"FINISHED\"}\n") + _, _ = io.WriteString(pw, "data: {\"p\":\"response/fragments/-1/results\",\"v\":[{\"url\":\"https://example.com/a\",\"cite_index\":1}]}\n") + _, _ = io.WriteString(pw, "data: [DONE]\n") + + select { + case result := <-resultCh: + if result.Text != "Hello" { + t.Fatalf("expected text to freeze at FINISHED, got %q", result.Text) + } + if got := result.CitationLinks[1]; got != "https://example.com/a" { + t.Fatalf("expected citation metadata after FINISHED, got %q", got) + } + case <-time.After(500 * time.Millisecond): + t.Fatal("CollectStream did not stop on [DONE] after FINISHED") + } +} + func TestCollectStreamStopsOnContentFilterStatus(t *testing.T) { resp := makeHTTPResponse( "data: {\"p\":\"response/content\",\"v\":\"safe\"}\n" +