mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-08 02:15:27 +08:00
增加“对话记录”
This commit is contained in:
268
internal/adapter/openai/chat_history.go
Normal file
268
internal/adapter/openai/chat_history.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/prompt"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
const adminWebUISourceHeader = "X-Ds2-Source"
|
||||
const adminWebUISourceValue = "admin-webui-api-tester"
|
||||
|
||||
type chatHistorySession struct {
|
||||
store *chathistory.Store
|
||||
entryID string
|
||||
startedAt time.Time
|
||||
lastPersist time.Time
|
||||
finalPrompt string
|
||||
startParams chathistory.StartParams
|
||||
disabled bool
|
||||
}
|
||||
|
||||
func startChatHistory(store *chathistory.Store, r *http.Request, a *auth.RequestAuth, stdReq util.StandardRequest) *chatHistorySession {
|
||||
if store == nil || r == nil || a == nil {
|
||||
return nil
|
||||
}
|
||||
if !store.Enabled() {
|
||||
return nil
|
||||
}
|
||||
if !shouldCaptureChatHistory(r) {
|
||||
return nil
|
||||
}
|
||||
entry, err := store.Start(chathistory.StartParams{
|
||||
CallerID: strings.TrimSpace(a.CallerID),
|
||||
AccountID: strings.TrimSpace(a.AccountID),
|
||||
Model: strings.TrimSpace(stdReq.ResponseModel),
|
||||
Stream: stdReq.Stream,
|
||||
UserInput: extractSingleUserInput(stdReq.Messages),
|
||||
Messages: extractAllMessages(stdReq.Messages),
|
||||
FinalPrompt: stdReq.FinalPrompt,
|
||||
})
|
||||
if err != nil {
|
||||
config.Logger.Warn("[chat_history] start failed", "error", err)
|
||||
return nil
|
||||
}
|
||||
startParams := chathistory.StartParams{
|
||||
CallerID: strings.TrimSpace(a.CallerID),
|
||||
AccountID: strings.TrimSpace(a.AccountID),
|
||||
Model: strings.TrimSpace(stdReq.ResponseModel),
|
||||
Stream: stdReq.Stream,
|
||||
UserInput: extractSingleUserInput(stdReq.Messages),
|
||||
Messages: extractAllMessages(stdReq.Messages),
|
||||
FinalPrompt: stdReq.FinalPrompt,
|
||||
}
|
||||
return &chatHistorySession{
|
||||
store: store,
|
||||
entryID: entry.ID,
|
||||
startedAt: time.Now(),
|
||||
lastPersist: time.Now(),
|
||||
finalPrompt: stdReq.FinalPrompt,
|
||||
startParams: startParams,
|
||||
}
|
||||
}
|
||||
|
||||
func shouldCaptureChatHistory(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
if isVercelStreamPrepareRequest(r) || isVercelStreamReleaseRequest(r) {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(r.Header.Get(adminWebUISourceHeader)) != adminWebUISourceValue
|
||||
}
|
||||
|
||||
func extractSingleUserInput(messages []any) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
msg, ok := messages[i].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
if role != "user" {
|
||||
continue
|
||||
}
|
||||
if normalized := strings.TrimSpace(prompt.NormalizeContent(msg["content"])); normalized != "" {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractAllMessages(messages []any) []chathistory.Message {
|
||||
out := make([]chathistory.Message, 0, len(messages))
|
||||
for _, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
content := strings.TrimSpace(prompt.NormalizeContent(msg["content"]))
|
||||
if role == "" || content == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, chathistory.Message{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) progress(thinking, content string) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
if now.Sub(s.lastPersist) < 250*time.Millisecond {
|
||||
return
|
||||
}
|
||||
s.lastPersist = now
|
||||
if _, err := s.store.Update(s.entryID, chathistory.UpdateParams{
|
||||
Status: "streaming",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
StatusCode: http.StatusOK,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
}); err != nil {
|
||||
if !s.retryMissingEntry() {
|
||||
s.disableOnMissing(err)
|
||||
return
|
||||
}
|
||||
_, retryErr := s.store.Update(s.entryID, chathistory.UpdateParams{
|
||||
Status: "streaming",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
StatusCode: http.StatusOK,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
})
|
||||
s.disableOnMissing(retryErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) success(statusCode int, thinking, content, finishReason string, usage map[string]any) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
if _, err := s.store.Update(s.entryID, chathistory.UpdateParams{
|
||||
Status: "success",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
StatusCode: statusCode,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
Completed: true,
|
||||
}); err != nil {
|
||||
if !s.retryMissingEntry() {
|
||||
s.disableOnMissing(err)
|
||||
return
|
||||
}
|
||||
_, retryErr := s.store.Update(s.entryID, chathistory.UpdateParams{
|
||||
Status: "success",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
StatusCode: statusCode,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
Completed: true,
|
||||
})
|
||||
s.disableOnMissing(retryErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) error(statusCode int, message, finishReason, thinking, content string) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
if _, err := s.store.Update(s.entryID, chathistory.UpdateParams{
|
||||
Status: "error",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
Error: message,
|
||||
StatusCode: statusCode,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
FinishReason: finishReason,
|
||||
Completed: true,
|
||||
}); err != nil {
|
||||
if !s.retryMissingEntry() {
|
||||
s.disableOnMissing(err)
|
||||
return
|
||||
}
|
||||
_, retryErr := s.store.Update(s.entryID, chathistory.UpdateParams{
|
||||
Status: "error",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
Error: message,
|
||||
StatusCode: statusCode,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
FinishReason: finishReason,
|
||||
Completed: true,
|
||||
})
|
||||
s.disableOnMissing(retryErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) stopped(thinking, content, finishReason string) {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return
|
||||
}
|
||||
if _, err := s.store.Update(s.entryID, chathistory.UpdateParams{
|
||||
Status: "stopped",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
StatusCode: http.StatusOK,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
FinishReason: finishReason,
|
||||
Usage: openaifmt.BuildChatUsage(s.finalPrompt, thinking, content),
|
||||
Completed: true,
|
||||
}); err != nil {
|
||||
if !s.retryMissingEntry() {
|
||||
s.disableOnMissing(err)
|
||||
return
|
||||
}
|
||||
_, retryErr := s.store.Update(s.entryID, chathistory.UpdateParams{
|
||||
Status: "stopped",
|
||||
ReasoningContent: thinking,
|
||||
Content: content,
|
||||
StatusCode: http.StatusOK,
|
||||
ElapsedMs: time.Since(s.startedAt).Milliseconds(),
|
||||
FinishReason: finishReason,
|
||||
Usage: openaifmt.BuildChatUsage(s.finalPrompt, thinking, content),
|
||||
Completed: true,
|
||||
})
|
||||
s.disableOnMissing(retryErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) retryMissingEntry() bool {
|
||||
if s == nil || s.store == nil || s.disabled {
|
||||
return false
|
||||
}
|
||||
entry, err := s.store.Start(s.startParams)
|
||||
if err != nil {
|
||||
s.disableOnMissing(err)
|
||||
return false
|
||||
}
|
||||
s.entryID = entry.ID
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *chatHistorySession) disableOnMissing(err error) {
|
||||
if err == nil || s == nil {
|
||||
return
|
||||
}
|
||||
if strings.Contains(strings.ToLower(err.Error()), "not found") {
|
||||
s.disabled = true
|
||||
return
|
||||
}
|
||||
config.Logger.Warn("[chat_history] update disabled", "error", err)
|
||||
s.disabled = true
|
||||
}
|
||||
174
internal/adapter/openai/chat_history_test.go
Normal file
174
internal/adapter/openai/chat_history_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/chathistory"
|
||||
)
|
||||
|
||||
func newTestChatHistoryStore(t *testing.T) *chathistory.Store {
|
||||
t.Helper()
|
||||
store := chathistory.New(filepath.Join(t.TempDir(), "chat_history.json"))
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("chat history store unavailable: %v", err)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func TestChatCompletionsNonStreamPersistsHistory(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello world"}`, `data: [DONE]`)},
|
||||
ChatHistory: historyStore,
|
||||
}
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"system","content":"be precise"},{"role":"user","content":"hi there"},{"role":"assistant","content":"previous answer"}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one history item, got %d", len(snapshot.Items))
|
||||
}
|
||||
item := snapshot.Items[0]
|
||||
if item.Status != "success" || item.UserInput != "hi there" {
|
||||
t.Fatalf("unexpected persisted history summary: %#v", item)
|
||||
}
|
||||
full, err := historyStore.Get(item.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("expected detail item, got %v", err)
|
||||
}
|
||||
if full.Content != "hello world" {
|
||||
t.Fatalf("expected detail content persisted, got %#v", full)
|
||||
}
|
||||
if len(full.Messages) != 3 {
|
||||
t.Fatalf("expected all request messages persisted, got %#v", full.Messages)
|
||||
}
|
||||
if full.FinalPrompt == "" {
|
||||
t.Fatalf("expected final prompt to be persisted")
|
||||
}
|
||||
if item.CallerID != "caller:test" {
|
||||
t.Fatalf("expected caller hash persisted in summary, got %#v", item.CallerID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamContextCancelledMarksHistoryStopped(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
entry, err := historyStore.Start(chathistory.StartParams{
|
||||
CallerID: "caller:test",
|
||||
Model: "deepseek-chat",
|
||||
Stream: true,
|
||||
UserInput: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start history failed: %v", err)
|
||||
}
|
||||
session := &chatHistorySession{
|
||||
store: historyStore,
|
||||
entryID: entry.ID,
|
||||
startedAt: time.Now(),
|
||||
lastPersist: time.Now(),
|
||||
finalPrompt: "hello",
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil).WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
resp := makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, `data: [DONE]`)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-stop", "deepseek-chat", "prompt", false, false, nil, session)
|
||||
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one history item, got %d", len(snapshot.Items))
|
||||
}
|
||||
full, err := historyStore.Get(snapshot.Items[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get detail failed: %v", err)
|
||||
}
|
||||
if full.Status != "stopped" {
|
||||
t.Fatalf("expected stopped status, got %#v", full)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsSkipsAdminWebUISource(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello world"}`, `data: [DONE]`)},
|
||||
ChatHistory: historyStore,
|
||||
}
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi there"}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set(adminWebUISourceHeader, adminWebUISourceValue)
|
||||
rec := httptest.NewRecorder()
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 0 {
|
||||
t.Fatalf("expected admin webui source to be skipped, got %#v", snapshot.Items)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsSkipsHistoryWhenDisabled(t *testing.T) {
|
||||
historyStore := newTestChatHistoryStore(t)
|
||||
if _, err := historyStore.SetLimit(chathistory.DisabledLimit); err != nil {
|
||||
t.Fatalf("disable history store failed: %v", err)
|
||||
}
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello world"}`, `data: [DONE]`)},
|
||||
ChatHistory: historyStore,
|
||||
}
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi there"}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
h.ChatCompletions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 0 {
|
||||
t.Fatalf("expected disabled history to stay empty, got %#v", snapshot.Items)
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,14 @@ type chatStreamRuntime struct {
|
||||
streamToolNames map[int]string
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
|
||||
finalThinking string
|
||||
finalText string
|
||||
finalFinishReason string
|
||||
finalUsage map[string]any
|
||||
finalErrorStatus int
|
||||
finalErrorMessage string
|
||||
finalErrorCode string
|
||||
}
|
||||
|
||||
func newChatStreamRuntime(
|
||||
@@ -99,6 +107,9 @@ func (s *chatStreamRuntime) sendDone() {
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) sendFailedChunk(status int, message, code string) {
|
||||
s.finalErrorStatus = status
|
||||
s.finalErrorMessage = message
|
||||
s.finalErrorCode = code
|
||||
s.sendChunk(map[string]any{
|
||||
"status_code": status,
|
||||
"error": map[string]any{
|
||||
@@ -114,6 +125,8 @@ func (s *chatStreamRuntime) sendFailedChunk(status int, message, code string) {
|
||||
func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
|
||||
s.finalThinking = finalThinking
|
||||
s.finalText = finalText
|
||||
detected := toolcall.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
|
||||
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
|
||||
finishReason = "tool_calls"
|
||||
@@ -197,6 +210,8 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
return
|
||||
}
|
||||
usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText)
|
||||
s.finalFinishReason = finishReason
|
||||
s.finalUsage = usage
|
||||
s.sendChunk(openaifmt.BuildChatStreamChunk(
|
||||
s.completionID,
|
||||
s.created,
|
||||
|
||||
@@ -63,32 +63,45 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
historySession := startChatHistory(h.ChatHistory, r, a, stdReq)
|
||||
|
||||
sessionID, err = h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
}
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
if historySession != nil {
|
||||
historySession.error(http.StatusInternalServerError, "Failed to get completion.", "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
if stdReq.Stream {
|
||||
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
|
||||
return
|
||||
}
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
h.handleNonStream(w, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
|
||||
}
|
||||
|
||||
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
|
||||
@@ -124,14 +137,16 @@ func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAu
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if historySession != nil {
|
||||
historySession.error(resp.StatusCode, string(body), "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
_ = ctx
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
|
||||
stripReferenceMarkers := h.compatStripReferenceMarkers()
|
||||
@@ -140,17 +155,34 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
|
||||
if searchEnabled {
|
||||
finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks)
|
||||
}
|
||||
if writeUpstreamEmptyOutputError(w, finalText, result.ContentFilter) {
|
||||
if shouldWriteUpstreamEmptyOutputError(finalText, result.ContentFilter) {
|
||||
status, message, code := upstreamEmptyOutputDetail(result.ContentFilter, finalText, finalThinking)
|
||||
if historySession != nil {
|
||||
historySession.error(status, message, code, finalThinking, finalText)
|
||||
}
|
||||
writeUpstreamEmptyOutputError(w, finalText, result.ContentFilter)
|
||||
return
|
||||
}
|
||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
if choices, ok := respBody["choices"].([]map[string]any); ok && len(choices) > 0 {
|
||||
if fr, _ := choices[0]["finish_reason"].(string); strings.TrimSpace(fr) != "" {
|
||||
finishReason = fr
|
||||
}
|
||||
}
|
||||
if historySession != nil {
|
||||
historySession.success(http.StatusOK, finalThinking, finalText, finishReason, openaifmt.BuildChatUsage(finalPrompt, finalThinking, finalText))
|
||||
}
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if historySession != nil {
|
||||
historySession.error(resp.StatusCode, string(body), "error", "", "")
|
||||
}
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
@@ -201,13 +233,32 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
OnKeepAlive: func() {
|
||||
streamRuntime.sendKeepAlive()
|
||||
},
|
||||
OnParsed: streamRuntime.onParsed,
|
||||
OnParsed: func(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
decision := streamRuntime.onParsed(parsed)
|
||||
if historySession != nil {
|
||||
historySession.progress(streamRuntime.thinking.String(), streamRuntime.text.String())
|
||||
}
|
||||
return decision
|
||||
},
|
||||
OnFinalize: func(reason streamengine.StopReason, _ error) {
|
||||
if string(reason) == "content_filter" {
|
||||
streamRuntime.finalize("content_filter")
|
||||
} else {
|
||||
streamRuntime.finalize("stop")
|
||||
}
|
||||
if historySession == nil {
|
||||
return
|
||||
}
|
||||
streamRuntime.finalize("stop")
|
||||
if streamRuntime.finalErrorMessage != "" {
|
||||
historySession.error(streamRuntime.finalErrorStatus, streamRuntime.finalErrorMessage, streamRuntime.finalErrorCode, streamRuntime.thinking.String(), streamRuntime.text.String())
|
||||
return
|
||||
}
|
||||
historySession.success(http.StatusOK, streamRuntime.finalThinking, streamRuntime.finalText, streamRuntime.finalFinishReason, streamRuntime.finalUsage)
|
||||
},
|
||||
OnContextDone: func() {
|
||||
if historySession != nil {
|
||||
historySession.stopped(streamRuntime.thinking.String(), streamRuntime.text.String(), string(streamengine.StopReasonContextCancelled))
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
@@ -25,9 +26,10 @@ const (
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
ChatHistory *chathistory.Store
|
||||
|
||||
leaseMu sync.Mutex
|
||||
streamLeases map[string]streamLease
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -94,7 +93,7 @@ func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid-empty", "deepseek-chat", "prompt", false, false, nil)
|
||||
h.handleNonStream(rec, resp, "cid-empty", "deepseek-chat", "prompt", false, false, nil, nil)
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected status 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -113,7 +112,7 @@ func TestHandleNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutp
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid-empty-filtered", "deepseek-chat", "prompt", false, false, nil)
|
||||
h.handleNonStream(rec, resp, "cid-empty-filtered", "deepseek-chat", "prompt", false, false, nil, nil)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400 for filtered upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -132,7 +131,7 @@ func TestHandleNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid-thinking-only", "deepseek-reasoner", "prompt", true, false, nil)
|
||||
h.handleNonStream(rec, resp, "cid-thinking-only", "deepseek-reasoner", "prompt", true, false, nil, nil)
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected status 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
@@ -153,7 +152,7 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid6", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
h.handleStream(rec, req, resp, "cid6", "deepseek-chat", "prompt", false, false, []string{"search"}, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
@@ -190,7 +189,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid10", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
h.handleStream(rec, req, resp, "cid10", "deepseek-chat", "prompt", false, false, []string{"search"}, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
|
||||
@@ -2,14 +2,26 @@ package openai
|
||||
|
||||
import "net/http"
|
||||
|
||||
func shouldWriteUpstreamEmptyOutputError(text string, contentFilter bool) bool {
|
||||
return text == ""
|
||||
}
|
||||
|
||||
func upstreamEmptyOutputDetail(contentFilter bool, text, thinking string) (int, string, string) {
|
||||
_ = text
|
||||
if contentFilter {
|
||||
return http.StatusBadRequest, "Upstream content filtered the response and returned no output.", "content_filter"
|
||||
}
|
||||
if thinking != "" {
|
||||
return http.StatusTooManyRequests, "Upstream model returned reasoning without visible output.", "upstream_empty_output"
|
||||
}
|
||||
return http.StatusTooManyRequests, "Upstream model returned empty output.", "upstream_empty_output"
|
||||
}
|
||||
|
||||
func writeUpstreamEmptyOutputError(w http.ResponseWriter, text string, contentFilter bool) bool {
|
||||
if text != "" {
|
||||
if !shouldWriteUpstreamEmptyOutputError(text, contentFilter) {
|
||||
return false
|
||||
}
|
||||
if contentFilter {
|
||||
writeOpenAIErrorWithCode(w, http.StatusBadRequest, "Upstream content filtered the response and returned no output.", "content_filter")
|
||||
return true
|
||||
}
|
||||
writeOpenAIErrorWithCode(w, http.StatusTooManyRequests, "Upstream model returned empty output.", "upstream_empty_output")
|
||||
status, message, code := upstreamEmptyOutputDetail(contentFilter, text, "")
|
||||
writeOpenAIErrorWithCode(w, status, message, code)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -2,13 +2,16 @@ package admin
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/chathistory"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigStore
|
||||
Pool PoolController
|
||||
DS DeepSeekCaller
|
||||
OpenAI OpenAIChatCaller
|
||||
Store ConfigStore
|
||||
Pool PoolController
|
||||
DS DeepSeekCaller
|
||||
OpenAI OpenAIChatCaller
|
||||
ChatHistory *chathistory.Store
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
@@ -50,6 +53,11 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
pr.Get("/export", h.exportConfig)
|
||||
pr.Get("/dev/captures", h.getDevCaptures)
|
||||
pr.Delete("/dev/captures", h.clearDevCaptures)
|
||||
pr.Get("/chat-history", h.getChatHistory)
|
||||
pr.Get("/chat-history/{id}", h.getChatHistoryItem)
|
||||
pr.Delete("/chat-history", h.clearChatHistory)
|
||||
pr.Delete("/chat-history/{id}", h.deleteChatHistoryItem)
|
||||
pr.Put("/chat-history/settings", h.updateChatHistorySettings)
|
||||
pr.Get("/version", h.getVersion)
|
||||
})
|
||||
}
|
||||
|
||||
134
internal/admin/handler_chat_history.go
Normal file
134
internal/admin/handler_chat_history.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/chathistory"
|
||||
)
|
||||
|
||||
func (h *Handler) getChatHistory(w http.ResponseWriter, r *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{
|
||||
"detail": err.Error(),
|
||||
"path": store.Path(),
|
||||
})
|
||||
return
|
||||
}
|
||||
etag := chathistory.ListETag(snapshot.Revision)
|
||||
w.Header().Set("ETag", etag)
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
if strings.TrimSpace(r.Header.Get("If-None-Match")) == etag {
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"version": snapshot.Version,
|
||||
"limit": snapshot.Limit,
|
||||
"revision": snapshot.Revision,
|
||||
"items": snapshot.Items,
|
||||
"path": store.Path(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) getChatHistoryItem(w http.ResponseWriter, r *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(chi.URLParam(r, "id"))
|
||||
if id == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "history id is required"})
|
||||
return
|
||||
}
|
||||
item, err := store.Get(id)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(strings.ToLower(err.Error()), "not found") {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
writeJSON(w, status, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
etag := chathistory.DetailETag(item.ID, item.Revision)
|
||||
w.Header().Set("ETag", etag)
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
if strings.TrimSpace(r.Header.Get("If-None-Match")) == etag {
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"item": item,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) clearChatHistory(w http.ResponseWriter, _ *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
if err := store.Clear(); err != nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": err.Error(), "path": store.Path()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteChatHistoryItem(w http.ResponseWriter, r *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(chi.URLParam(r, "id"))
|
||||
if id == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "history id is required"})
|
||||
return
|
||||
}
|
||||
if err := store.Delete(id); err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(strings.ToLower(err.Error()), "not found") {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
writeJSON(w, status, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true})
|
||||
}
|
||||
|
||||
func (h *Handler) updateChatHistorySettings(w http.ResponseWriter, r *http.Request) {
|
||||
store := h.ChatHistory
|
||||
if store == nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"detail": "chat history store is not configured"})
|
||||
return
|
||||
}
|
||||
var body struct {
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
snapshot, err := store.SetLimit(body.Limit)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"limit": snapshot.Limit,
|
||||
"revision": snapshot.Revision,
|
||||
"items": snapshot.Items,
|
||||
})
|
||||
}
|
||||
176
internal/admin/handler_chat_history_test.go
Normal file
176
internal/admin/handler_chat_history_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func newChatHistoryAdminHarness(t *testing.T) (*Handler, *chathistory.Store) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatalf("write config failed: %v", err)
|
||||
}
|
||||
t.Setenv("DS2API_CONFIG_PATH", configPath)
|
||||
t.Setenv("DS2API_ADMIN_KEY", "admin")
|
||||
t.Setenv("DS2API_CONFIG_JSON", "")
|
||||
store, err := config.LoadStoreWithError()
|
||||
if err != nil {
|
||||
t.Fatalf("load config store failed: %v", err)
|
||||
}
|
||||
historyStore := chathistory.New(filepath.Join(dir, "chat_history.json"))
|
||||
return &Handler{Store: store, ChatHistory: historyStore}, historyStore
|
||||
}
|
||||
|
||||
func TestGetChatHistoryAndUpdateSettings(t *testing.T) {
|
||||
h, historyStore := newChatHistoryAdminHarness(t)
|
||||
entry, err := historyStore.Start(chathistory.StartParams{
|
||||
CallerID: "caller:test",
|
||||
AccountID: "user@example.com",
|
||||
Model: "deepseek-chat",
|
||||
UserInput: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start history failed: %v", err)
|
||||
}
|
||||
if _, err := historyStore.Update(entry.ID, chathistory.UpdateParams{
|
||||
Status: "success",
|
||||
Content: "world",
|
||||
Completed: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("update history failed: %v", err)
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/chat-history", nil)
|
||||
req.Header.Set("Authorization", "Bearer admin")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("decode payload failed: %v", err)
|
||||
}
|
||||
items, _ := payload["items"].([]any)
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("expected one history item, got %#v", payload)
|
||||
}
|
||||
if rec.Header().Get("ETag") == "" {
|
||||
t.Fatalf("expected list etag header")
|
||||
}
|
||||
|
||||
notModifiedReq := httptest.NewRequest(http.MethodGet, "/chat-history", nil)
|
||||
notModifiedReq.Header.Set("Authorization", "Bearer admin")
|
||||
notModifiedReq.Header.Set("If-None-Match", rec.Header().Get("ETag"))
|
||||
notModifiedRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(notModifiedRec, notModifiedReq)
|
||||
if notModifiedRec.Code != http.StatusNotModified {
|
||||
t.Fatalf("expected 304, got %d body=%s", notModifiedRec.Code, notModifiedRec.Body.String())
|
||||
}
|
||||
|
||||
itemReq := httptest.NewRequest(http.MethodGet, "/chat-history/"+entry.ID, nil)
|
||||
itemReq.Header.Set("Authorization", "Bearer admin")
|
||||
itemRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(itemRec, itemReq)
|
||||
if itemRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected item 200, got %d body=%s", itemRec.Code, itemRec.Body.String())
|
||||
}
|
||||
if itemRec.Header().Get("ETag") == "" {
|
||||
t.Fatalf("expected detail etag header")
|
||||
}
|
||||
|
||||
updateReq := httptest.NewRequest(http.MethodPut, "/chat-history/settings", bytes.NewReader([]byte(`{"limit":10}`)))
|
||||
updateReq.Header.Set("Authorization", "Bearer admin")
|
||||
updateRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(updateRec, updateReq)
|
||||
if updateRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 from settings update, got %d body=%s", updateRec.Code, updateRec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if snapshot.Limit != 10 {
|
||||
t.Fatalf("expected limit=10, got %d", snapshot.Limit)
|
||||
}
|
||||
|
||||
disableReq := httptest.NewRequest(http.MethodPut, "/chat-history/settings", bytes.NewReader([]byte(`{"limit":0}`)))
|
||||
disableReq.Header.Set("Authorization", "Bearer admin")
|
||||
disableRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(disableRec, disableReq)
|
||||
if disableRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 from disable update, got %d body=%s", disableRec.Code, disableRec.Body.String())
|
||||
}
|
||||
snapshot, err = historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot after disable failed: %v", err)
|
||||
}
|
||||
if snapshot.Limit != chathistory.DisabledLimit {
|
||||
t.Fatalf("expected limit=0, got %d", snapshot.Limit)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected history preserved when disabled, got %d", len(snapshot.Items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAndClearChatHistory(t *testing.T) {
|
||||
h, historyStore := newChatHistoryAdminHarness(t)
|
||||
entryA, err := historyStore.Start(chathistory.StartParams{UserInput: "a"})
|
||||
if err != nil {
|
||||
t.Fatalf("start A failed: %v", err)
|
||||
}
|
||||
if _, err := historyStore.Start(chathistory.StartParams{UserInput: "b"}); err != nil {
|
||||
t.Fatalf("start B failed: %v", err)
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, "/chat-history/"+entryA.ID, nil)
|
||||
deleteReq.Header.Set("Authorization", "Bearer admin")
|
||||
deleteRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(deleteRec, deleteReq)
|
||||
if deleteRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected delete 200, got %d body=%s", deleteRec.Code, deleteRec.Body.String())
|
||||
}
|
||||
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one item after delete, got %d", len(snapshot.Items))
|
||||
}
|
||||
|
||||
clearReq := httptest.NewRequest(http.MethodDelete, "/chat-history", nil)
|
||||
clearReq.Header.Set("Authorization", "Bearer admin")
|
||||
clearRec := httptest.NewRecorder()
|
||||
r.ServeHTTP(clearRec, clearReq)
|
||||
if clearRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected clear 200, got %d body=%s", clearRec.Code, clearRec.Body.String())
|
||||
}
|
||||
|
||||
snapshot, err = historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 0 {
|
||||
t.Fatalf("expected empty items after clear, got %d", len(snapshot.Items))
|
||||
}
|
||||
}
|
||||
711
internal/chathistory/store.go
Normal file
711
internal/chathistory/store.go
Normal file
@@ -0,0 +1,711 @@
|
||||
package chathistory
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
FileVersion = 2
|
||||
DisabledLimit = 0
|
||||
DefaultLimit = 20
|
||||
MaxLimit = 50
|
||||
defaultPreviewAt = 160
|
||||
)
|
||||
|
||||
var allowedLimits = map[int]struct{}{
|
||||
DisabledLimit: {},
|
||||
10: {},
|
||||
20: {},
|
||||
50: {},
|
||||
}
|
||||
|
||||
var ErrDisabled = errors.New("chat history disabled")
|
||||
|
||||
type Entry struct {
|
||||
ID string `json:"id"`
|
||||
Revision int64 `json:"revision"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
CompletedAt int64 `json:"completed_at,omitempty"`
|
||||
Status string `json:"status"`
|
||||
CallerID string `json:"caller_id,omitempty"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserInput string `json:"user_input,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
FinalPrompt string `json:"final_prompt,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
StatusCode int `json:"status_code,omitempty"`
|
||||
ElapsedMs int64 `json:"elapsed_ms,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Usage map[string]any `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type SummaryEntry struct {
|
||||
ID string `json:"id"`
|
||||
Revision int64 `json:"revision"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
CompletedAt int64 `json:"completed_at,omitempty"`
|
||||
Status string `json:"status"`
|
||||
CallerID string `json:"caller_id,omitempty"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserInput string `json:"user_input,omitempty"`
|
||||
Preview string `json:"preview,omitempty"`
|
||||
StatusCode int `json:"status_code,omitempty"`
|
||||
ElapsedMs int64 `json:"elapsed_ms,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
DetailRevision int64 `json:"detail_revision"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
Version int `json:"version"`
|
||||
Limit int `json:"limit"`
|
||||
Revision int64 `json:"revision"`
|
||||
Items []SummaryEntry `json:"items"`
|
||||
}
|
||||
|
||||
type StartParams struct {
|
||||
CallerID string
|
||||
AccountID string
|
||||
Model string
|
||||
Stream bool
|
||||
UserInput string
|
||||
Messages []Message
|
||||
FinalPrompt string
|
||||
}
|
||||
|
||||
type UpdateParams struct {
|
||||
Status string
|
||||
ReasoningContent string
|
||||
Content string
|
||||
Error string
|
||||
StatusCode int
|
||||
ElapsedMs int64
|
||||
FinishReason string
|
||||
Usage map[string]any
|
||||
Completed bool
|
||||
}
|
||||
|
||||
type detailEnvelope struct {
|
||||
Version int `json:"version"`
|
||||
Item Entry `json:"item"`
|
||||
}
|
||||
|
||||
type legacyFile struct {
|
||||
Version int `json:"version"`
|
||||
Limit int `json:"limit"`
|
||||
Items []Entry `json:"items"`
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
detailDir string
|
||||
state File
|
||||
details map[string]Entry
|
||||
err error
|
||||
}
|
||||
|
||||
func New(path string) *Store {
|
||||
s := &Store{
|
||||
path: strings.TrimSpace(path),
|
||||
detailDir: strings.TrimSpace(path) + ".d",
|
||||
state: File{
|
||||
Version: FileVersion,
|
||||
Limit: DefaultLimit,
|
||||
Revision: 0,
|
||||
Items: []SummaryEntry{},
|
||||
},
|
||||
details: map[string]Entry{},
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.err = s.loadLocked()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Store) Path() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.path
|
||||
}
|
||||
|
||||
func (s *Store) DetailDir() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.detailDir
|
||||
}
|
||||
|
||||
func (s *Store) Err() error {
|
||||
if s == nil {
|
||||
return errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s *Store) Snapshot() (File, error) {
|
||||
if s == nil {
|
||||
return File{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return File{}, s.err
|
||||
}
|
||||
return cloneFile(s.state), nil
|
||||
}
|
||||
|
||||
func (s *Store) Enabled() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return false
|
||||
}
|
||||
return s.state.Limit != DisabledLimit
|
||||
}
|
||||
|
||||
func (s *Store) Get(id string) (Entry, error) {
|
||||
if s == nil {
|
||||
return Entry{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return Entry{}, s.err
|
||||
}
|
||||
item, ok := s.details[strings.TrimSpace(id)]
|
||||
if !ok {
|
||||
return Entry{}, errors.New("chat history entry not found")
|
||||
}
|
||||
return cloneEntry(item), nil
|
||||
}
|
||||
|
||||
func (s *Store) Start(params StartParams) (Entry, error) {
|
||||
if s == nil {
|
||||
return Entry{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return Entry{}, s.err
|
||||
}
|
||||
if s.state.Limit == DisabledLimit {
|
||||
return Entry{}, ErrDisabled
|
||||
}
|
||||
now := time.Now().UnixMilli()
|
||||
revision := s.nextRevisionLocked()
|
||||
entry := Entry{
|
||||
ID: "chat_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
Revision: revision,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Status: "streaming",
|
||||
CallerID: strings.TrimSpace(params.CallerID),
|
||||
AccountID: strings.TrimSpace(params.AccountID),
|
||||
Model: strings.TrimSpace(params.Model),
|
||||
Stream: params.Stream,
|
||||
UserInput: strings.TrimSpace(params.UserInput),
|
||||
Messages: cloneMessages(params.Messages),
|
||||
FinalPrompt: strings.TrimSpace(params.FinalPrompt),
|
||||
}
|
||||
s.details[entry.ID] = entry
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return Entry{}, err
|
||||
}
|
||||
return cloneEntry(entry), nil
|
||||
}
|
||||
|
||||
func (s *Store) Update(id string, params UpdateParams) (Entry, error) {
|
||||
if s == nil {
|
||||
return Entry{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return Entry{}, s.err
|
||||
}
|
||||
target := strings.TrimSpace(id)
|
||||
if target == "" {
|
||||
return Entry{}, errors.New("history id is required")
|
||||
}
|
||||
item, ok := s.details[target]
|
||||
if !ok {
|
||||
return Entry{}, errors.New("chat history entry not found")
|
||||
}
|
||||
now := time.Now().UnixMilli()
|
||||
item.Revision = s.nextRevisionLocked()
|
||||
item.UpdatedAt = now
|
||||
if params.Status != "" {
|
||||
item.Status = params.Status
|
||||
}
|
||||
item.ReasoningContent = params.ReasoningContent
|
||||
item.Content = params.Content
|
||||
item.Error = strings.TrimSpace(params.Error)
|
||||
item.StatusCode = params.StatusCode
|
||||
item.ElapsedMs = params.ElapsedMs
|
||||
item.FinishReason = strings.TrimSpace(params.FinishReason)
|
||||
if params.Usage != nil {
|
||||
item.Usage = cloneMap(params.Usage)
|
||||
}
|
||||
if params.Completed {
|
||||
item.CompletedAt = now
|
||||
}
|
||||
s.details[target] = item
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return Entry{}, err
|
||||
}
|
||||
return cloneEntry(item), nil
|
||||
}
|
||||
|
||||
func (s *Store) Delete(id string) error {
|
||||
if s == nil {
|
||||
return errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
target := strings.TrimSpace(id)
|
||||
if target == "" {
|
||||
return errors.New("history id is required")
|
||||
}
|
||||
if _, ok := s.details[target]; !ok {
|
||||
return errors.New("chat history entry not found")
|
||||
}
|
||||
delete(s.details, target)
|
||||
s.nextRevisionLocked()
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) Clear() error {
|
||||
if s == nil {
|
||||
return errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
s.details = map[string]Entry{}
|
||||
s.nextRevisionLocked()
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) SetLimit(limit int) (File, error) {
|
||||
if s == nil {
|
||||
return File{}, errors.New("chat history store is nil")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return File{}, s.err
|
||||
}
|
||||
if !isAllowedLimit(limit) {
|
||||
return File{}, fmt.Errorf("unsupported chat history limit: %d", limit)
|
||||
}
|
||||
s.state.Limit = limit
|
||||
s.nextRevisionLocked()
|
||||
s.rebuildIndexLocked()
|
||||
if err := s.saveLocked(); err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
return cloneFile(s.state), nil
|
||||
}
|
||||
|
||||
func (s *Store) loadLocked() error {
|
||||
if strings.TrimSpace(s.path) == "" {
|
||||
return errors.New("chat history path is required")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(s.path), 0o755); err != nil && filepath.Dir(s.path) != "." {
|
||||
return fmt.Errorf("create chat history dir: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(s.detailDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create chat history detail dir: %w", err)
|
||||
}
|
||||
|
||||
raw, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return s.saveLocked()
|
||||
}
|
||||
return fmt.Errorf("read chat history index: %w", err)
|
||||
}
|
||||
|
||||
legacy, legacyOK, legacyErr := parseLegacy(raw)
|
||||
if legacyErr != nil {
|
||||
return legacyErr
|
||||
}
|
||||
if legacyOK && !hasDetailFiles(s.detailDir) {
|
||||
s.loadLegacyLocked(legacy)
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
var state File
|
||||
if err := json.Unmarshal(raw, &state); err != nil {
|
||||
return fmt.Errorf("decode chat history index: %w", err)
|
||||
}
|
||||
if state.Version == 0 {
|
||||
state.Version = FileVersion
|
||||
}
|
||||
if !isAllowedLimit(state.Limit) {
|
||||
state.Limit = DefaultLimit
|
||||
}
|
||||
s.state = cloneFile(state)
|
||||
s.details = map[string]Entry{}
|
||||
for _, item := range state.Items {
|
||||
detail, err := readDetailFile(filepath.Join(s.detailDir, item.ID+".json"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.details[item.ID] = detail
|
||||
}
|
||||
s.rebuildIndexLocked()
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *Store) loadLegacyLocked(legacy legacyFile) {
|
||||
s.state.Version = FileVersion
|
||||
s.state.Limit = legacy.Limit
|
||||
if !isAllowedLimit(s.state.Limit) {
|
||||
s.state.Limit = DefaultLimit
|
||||
}
|
||||
s.details = map[string]Entry{}
|
||||
maxRevision := int64(0)
|
||||
for _, item := range legacy.Items {
|
||||
if strings.TrimSpace(item.ID) == "" {
|
||||
continue
|
||||
}
|
||||
item.Messages = cloneMessages(item.Messages)
|
||||
if item.Revision == 0 {
|
||||
if item.UpdatedAt > 0 {
|
||||
item.Revision = item.UpdatedAt
|
||||
} else {
|
||||
item.Revision = time.Now().UnixNano()
|
||||
}
|
||||
}
|
||||
if item.Revision > maxRevision {
|
||||
maxRevision = item.Revision
|
||||
}
|
||||
s.details[item.ID] = item
|
||||
}
|
||||
s.state.Revision = maxRevision
|
||||
s.rebuildIndexLocked()
|
||||
}
|
||||
|
||||
func (s *Store) saveLocked() error {
|
||||
s.state.Version = FileVersion
|
||||
if !isAllowedLimit(s.state.Limit) {
|
||||
s.state.Limit = DefaultLimit
|
||||
}
|
||||
s.rebuildIndexLocked()
|
||||
|
||||
if err := os.MkdirAll(s.detailDir, 0o755); err != nil {
|
||||
s.err = err
|
||||
return err
|
||||
}
|
||||
activeFiles := make(map[string]struct{}, len(s.details))
|
||||
for id, item := range s.details {
|
||||
path := filepath.Join(s.detailDir, id+".json")
|
||||
activeFiles[path] = struct{}{}
|
||||
payload, err := json.MarshalIndent(detailEnvelope{
|
||||
Version: FileVersion,
|
||||
Item: item,
|
||||
}, "", " ")
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return err
|
||||
}
|
||||
if err := writeFileAtomic(path, append(payload, '\n')); err != nil {
|
||||
s.err = err
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := cleanupDetailDir(s.detailDir, activeFiles); err != nil {
|
||||
s.err = err
|
||||
return err
|
||||
}
|
||||
|
||||
payload, err := json.MarshalIndent(s.state, "", " ")
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return err
|
||||
}
|
||||
if err := writeFileAtomic(s.path, append(payload, '\n')); err != nil {
|
||||
s.err = err
|
||||
return err
|
||||
}
|
||||
s.err = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) rebuildIndexLocked() {
|
||||
summaries := make([]SummaryEntry, 0, len(s.details))
|
||||
for _, item := range s.details {
|
||||
summaries = append(summaries, summaryFromEntry(item))
|
||||
}
|
||||
sort.Slice(summaries, func(i, j int) bool {
|
||||
if summaries[i].UpdatedAt == summaries[j].UpdatedAt {
|
||||
return summaries[i].CreatedAt > summaries[j].CreatedAt
|
||||
}
|
||||
return summaries[i].UpdatedAt > summaries[j].UpdatedAt
|
||||
})
|
||||
if s.state.Limit < DisabledLimit || !isAllowedLimit(s.state.Limit) {
|
||||
s.state.Limit = DefaultLimit
|
||||
}
|
||||
if s.state.Limit == DisabledLimit {
|
||||
s.state.Items = summaries
|
||||
return
|
||||
}
|
||||
if len(summaries) > s.state.Limit {
|
||||
keep := make(map[string]struct{}, s.state.Limit)
|
||||
for _, item := range summaries[:s.state.Limit] {
|
||||
keep[item.ID] = struct{}{}
|
||||
}
|
||||
for id := range s.details {
|
||||
if _, ok := keep[id]; !ok {
|
||||
delete(s.details, id)
|
||||
}
|
||||
}
|
||||
summaries = summaries[:s.state.Limit]
|
||||
}
|
||||
s.state.Items = summaries
|
||||
}
|
||||
|
||||
func (s *Store) nextRevisionLocked() int64 {
|
||||
next := time.Now().UnixNano()
|
||||
if next <= s.state.Revision {
|
||||
next = s.state.Revision + 1
|
||||
}
|
||||
s.state.Revision = next
|
||||
return next
|
||||
}
|
||||
|
||||
func summaryFromEntry(item Entry) SummaryEntry {
|
||||
return SummaryEntry{
|
||||
ID: item.ID,
|
||||
Revision: item.Revision,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
CompletedAt: item.CompletedAt,
|
||||
Status: item.Status,
|
||||
CallerID: item.CallerID,
|
||||
AccountID: item.AccountID,
|
||||
Model: item.Model,
|
||||
Stream: item.Stream,
|
||||
UserInput: item.UserInput,
|
||||
Preview: buildPreview(item),
|
||||
StatusCode: item.StatusCode,
|
||||
ElapsedMs: item.ElapsedMs,
|
||||
FinishReason: item.FinishReason,
|
||||
DetailRevision: item.Revision,
|
||||
}
|
||||
}
|
||||
|
||||
func buildPreview(item Entry) string {
|
||||
candidate := strings.TrimSpace(item.Content)
|
||||
if candidate == "" {
|
||||
candidate = strings.TrimSpace(item.ReasoningContent)
|
||||
}
|
||||
if candidate == "" {
|
||||
candidate = strings.TrimSpace(item.Error)
|
||||
}
|
||||
if candidate == "" {
|
||||
candidate = strings.TrimSpace(item.UserInput)
|
||||
}
|
||||
if len(candidate) > defaultPreviewAt {
|
||||
return candidate[:defaultPreviewAt] + "..."
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
|
||||
func readDetailFile(path string) (Entry, error) {
|
||||
raw, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return Entry{}, fmt.Errorf("read chat history detail: %w", err)
|
||||
}
|
||||
var env detailEnvelope
|
||||
if err := json.Unmarshal(raw, &env); err != nil {
|
||||
return Entry{}, fmt.Errorf("decode chat history detail: %w", err)
|
||||
}
|
||||
return cloneEntry(env.Item), nil
|
||||
}
|
||||
|
||||
func hasDetailFiles(dir string) bool {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(strings.ToLower(entry.Name()), ".json") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func parseLegacy(raw []byte) (legacyFile, bool, error) {
|
||||
var legacy legacyFile
|
||||
if err := json.Unmarshal(raw, &legacy); err != nil {
|
||||
return legacyFile{}, false, nil
|
||||
}
|
||||
if len(legacy.Items) == 0 {
|
||||
return legacy, false, nil
|
||||
}
|
||||
for _, item := range legacy.Items {
|
||||
if item.Content != "" || item.ReasoningContent != "" || item.FinalPrompt != "" || len(item.Messages) > 0 {
|
||||
return legacy, true, nil
|
||||
}
|
||||
}
|
||||
return legacy, false, nil
|
||||
}
|
||||
|
||||
func cleanupDetailDir(dir string, active map[string]struct{}) error {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chat history detail dir: %w", err)
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(dir, entry.Name())
|
||||
if _, ok := active[path]; ok {
|
||||
continue
|
||||
}
|
||||
if err := os.Remove(path); err != nil {
|
||||
return fmt.Errorf("remove stale chat history detail: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeFileAtomic(path string, body []byte) error {
|
||||
dir := filepath.Dir(path)
|
||||
if dir == "" {
|
||||
dir = "."
|
||||
}
|
||||
if dir != "." {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("create chat history dir: %w", err)
|
||||
}
|
||||
}
|
||||
tmpFile, err := os.CreateTemp(dir, ".chat-history-*.tmp")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp chat history: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
cleanup := func() {
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
if _, err := tmpFile.Write(body); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
cleanup()
|
||||
return fmt.Errorf("write temp chat history: %w", err)
|
||||
}
|
||||
if err := tmpFile.Sync(); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
cleanup()
|
||||
return fmt.Errorf("sync temp chat history: %w", err)
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
cleanup()
|
||||
return fmt.Errorf("close temp chat history: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
cleanup()
|
||||
return fmt.Errorf("promote temp chat history: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ListETag(revision int64) string {
|
||||
return fmt.Sprintf(`W/"chat-history-list-%d"`, revision)
|
||||
}
|
||||
|
||||
func DetailETag(id string, revision int64) string {
|
||||
return fmt.Sprintf(`W/"chat-history-detail-%s-%d"`, strings.TrimSpace(id), revision)
|
||||
}
|
||||
|
||||
func isAllowedLimit(limit int) bool {
|
||||
_, ok := allowedLimits[limit]
|
||||
return ok
|
||||
}
|
||||
|
||||
func cloneFile(in File) File {
|
||||
out := File{
|
||||
Version: in.Version,
|
||||
Limit: in.Limit,
|
||||
Revision: in.Revision,
|
||||
Items: make([]SummaryEntry, len(in.Items)),
|
||||
}
|
||||
copy(out.Items, in.Items)
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneEntry(item Entry) Entry {
|
||||
item.Usage = cloneMap(item.Usage)
|
||||
item.Messages = cloneMessages(item.Messages)
|
||||
return item
|
||||
}
|
||||
|
||||
func cloneMap(in map[string]any) map[string]any {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneMessages(messages []Message) []Message {
|
||||
if len(messages) == 0 {
|
||||
return []Message{}
|
||||
}
|
||||
out := make([]Message, len(messages))
|
||||
copy(out, messages)
|
||||
return out
|
||||
}
|
||||
256
internal/chathistory/store_test.go
Normal file
256
internal/chathistory/store_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package chathistory
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStoreCreatesAndPersistsEntries(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
started, err := store.Start(StartParams{
|
||||
CallerID: "caller:abc",
|
||||
AccountID: "user@example.com",
|
||||
Model: "deepseek-chat",
|
||||
Stream: true,
|
||||
UserInput: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start entry failed: %v", err)
|
||||
}
|
||||
|
||||
updated, err := store.Update(started.ID, UpdateParams{
|
||||
Status: "success",
|
||||
ReasoningContent: "thinking",
|
||||
Content: "answer",
|
||||
StatusCode: 200,
|
||||
ElapsedMs: 321,
|
||||
FinishReason: "stop",
|
||||
Usage: map[string]any{"total_tokens": 9},
|
||||
Completed: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("update entry failed: %v", err)
|
||||
}
|
||||
if updated.Status != "success" || updated.Content != "answer" {
|
||||
t.Fatalf("unexpected updated entry: %#v", updated)
|
||||
}
|
||||
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if snapshot.Limit != DefaultLimit {
|
||||
t.Fatalf("unexpected default limit: %d", snapshot.Limit)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one item, got %d", len(snapshot.Items))
|
||||
}
|
||||
if snapshot.Items[0].CompletedAt == 0 {
|
||||
t.Fatalf("expected completed_at to be populated")
|
||||
}
|
||||
if snapshot.Items[0].Preview != "answer" {
|
||||
t.Fatalf("expected summary preview=answer, got %#v", snapshot.Items[0])
|
||||
}
|
||||
|
||||
reloaded := New(path)
|
||||
reloadedSnapshot, err := reloaded.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("reload snapshot failed: %v", err)
|
||||
}
|
||||
if len(reloadedSnapshot.Items) != 1 {
|
||||
t.Fatalf("unexpected reloaded summaries: %#v", reloadedSnapshot.Items)
|
||||
}
|
||||
full, err := reloaded.Get(started.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get detail failed: %v", err)
|
||||
}
|
||||
if full.Content != "answer" {
|
||||
t.Fatalf("expected detail content=answer, got %#v", full)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreTrimsToConfiguredLimit(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
if _, err := store.SetLimit(10); err != nil {
|
||||
t.Fatalf("set limit failed: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
entry, err := store.Start(StartParams{Model: "deepseek-chat", UserInput: "msg"})
|
||||
if err != nil {
|
||||
t.Fatalf("start %d failed: %v", i, err)
|
||||
}
|
||||
if _, err := store.Update(entry.ID, UpdateParams{Status: "success", Content: "ok", Completed: true}); err != nil {
|
||||
t.Fatalf("update %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 10 {
|
||||
t.Fatalf("expected 10 items, got %d", len(snapshot.Items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreDeleteClearAndLimitValidation(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
entry, err := store.Start(StartParams{UserInput: "hello"})
|
||||
if err != nil {
|
||||
t.Fatalf("start failed: %v", err)
|
||||
}
|
||||
if err := store.Delete(entry.ID); err != nil {
|
||||
t.Fatalf("delete failed: %v", err)
|
||||
}
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 0 {
|
||||
t.Fatalf("expected empty items after delete, got %d", len(snapshot.Items))
|
||||
}
|
||||
if _, err := store.SetLimit(999); err == nil {
|
||||
t.Fatalf("expected invalid limit error")
|
||||
}
|
||||
if err := store.Clear(); err != nil {
|
||||
t.Fatalf("clear failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreDisablePreservesHistoryAndBlocksNewEntries(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
entry, err := store.Start(StartParams{UserInput: "hello"})
|
||||
if err != nil {
|
||||
t.Fatalf("start failed: %v", err)
|
||||
}
|
||||
if _, err := store.Update(entry.ID, UpdateParams{Status: "success", Content: "world", Completed: true}); err != nil {
|
||||
t.Fatalf("update failed: %v", err)
|
||||
}
|
||||
|
||||
snapshot, err := store.SetLimit(DisabledLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("disable failed: %v", err)
|
||||
}
|
||||
if snapshot.Limit != DisabledLimit {
|
||||
t.Fatalf("expected disabled limit, got %d", snapshot.Limit)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected disabled mode to preserve summaries, got %d", len(snapshot.Items))
|
||||
}
|
||||
if store.Enabled() {
|
||||
t.Fatalf("expected store to report disabled")
|
||||
}
|
||||
if _, err := store.Start(StartParams{UserInput: "later"}); err != ErrDisabled {
|
||||
t.Fatalf("expected ErrDisabled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreConcurrentUpdatesKeepSplitFilesValid(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
store := New(path)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 8; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
entry, err := store.Start(StartParams{
|
||||
CallerID: "caller:test",
|
||||
Model: "deepseek-chat",
|
||||
UserInput: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("start failed: %v", err)
|
||||
return
|
||||
}
|
||||
_, err = store.Update(entry.ID, UpdateParams{
|
||||
Status: "success",
|
||||
Content: "answer",
|
||||
ElapsedMs: int64(idx),
|
||||
Completed: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("update failed: %v", err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 8 {
|
||||
t.Fatalf("expected 8 items, got %d", len(snapshot.Items))
|
||||
}
|
||||
|
||||
raw, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read index failed: %v", err)
|
||||
}
|
||||
var persisted File
|
||||
if err := json.Unmarshal(raw, &persisted); err != nil {
|
||||
t.Fatalf("persisted index is invalid json: %v", err)
|
||||
}
|
||||
if len(persisted.Items) != 8 {
|
||||
t.Fatalf("expected persisted items=8, got %d", len(persisted.Items))
|
||||
}
|
||||
|
||||
detailFiles, err := os.ReadDir(path + ".d")
|
||||
if err != nil {
|
||||
t.Fatalf("read detail dir failed: %v", err)
|
||||
}
|
||||
if len(detailFiles) != 8 {
|
||||
t.Fatalf("expected 8 detail files, got %d", len(detailFiles))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreAutoMigratesLegacyMonolith(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "chat_history.json")
|
||||
legacy := legacyFile{
|
||||
Version: 1,
|
||||
Limit: 20,
|
||||
Items: []Entry{{
|
||||
ID: "chat_legacy",
|
||||
CreatedAt: 1,
|
||||
UpdatedAt: 2,
|
||||
Status: "success",
|
||||
UserInput: "hello",
|
||||
Content: "world",
|
||||
ReasoningContent: "thinking",
|
||||
}},
|
||||
}
|
||||
body, _ := json.MarshalIndent(legacy, "", " ")
|
||||
if err := os.WriteFile(path, body, 0o644); err != nil {
|
||||
t.Fatalf("write legacy file failed: %v", err)
|
||||
}
|
||||
|
||||
store := New(path)
|
||||
if err := store.Err(); err != nil {
|
||||
t.Fatalf("expected legacy migration success, got %v", err)
|
||||
}
|
||||
snapshot, err := store.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot failed: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one migrated summary, got %#v", snapshot.Items)
|
||||
}
|
||||
full, err := store.Get("chat_legacy")
|
||||
if err != nil {
|
||||
t.Fatalf("get migrated detail failed: %v", err)
|
||||
}
|
||||
if full.Content != "world" {
|
||||
t.Fatalf("expected migrated detail content preserved, got %#v", full)
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,10 @@ func RawStreamSampleRoot() string {
|
||||
return ResolvePath("DS2API_RAW_STREAM_SAMPLE_ROOT", "tests/raw_stream_samples")
|
||||
}
|
||||
|
||||
func ChatHistoryPath() string {
|
||||
return ResolvePath("DS2API_CHAT_HISTORY_PATH", "data/chat_history.json")
|
||||
}
|
||||
|
||||
func StaticAdminDir() string {
|
||||
return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin")
|
||||
}
|
||||
|
||||
@@ -4,7 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -17,6 +20,7 @@ import (
|
||||
"ds2api/internal/adapter/openai"
|
||||
"ds2api/internal/admin"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/webui"
|
||||
@@ -46,17 +50,21 @@ func NewApp() (*App, error) {
|
||||
} else {
|
||||
config.Logger.Info("[PoW] pure Go solver ready")
|
||||
}
|
||||
chatHistoryStore := chathistory.New(config.ChatHistoryPath())
|
||||
if err := chatHistoryStore.Err(); err != nil {
|
||||
config.Logger.Warn("[chat_history] unavailable", "path", chatHistoryStore.Path(), "error", err)
|
||||
}
|
||||
|
||||
openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient}
|
||||
openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient, ChatHistory: chatHistoryStore}
|
||||
claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient, OpenAI: openaiHandler}
|
||||
geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient, OpenAI: openaiHandler}
|
||||
adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient, OpenAI: openaiHandler}
|
||||
adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient, OpenAI: openaiHandler, ChatHistory: chatHistoryStore}
|
||||
webuiHandler := webui.NewHandler()
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(filteredLogger())
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(cors)
|
||||
r.Use(timeout(0))
|
||||
@@ -99,11 +107,47 @@ func timeout(d time.Duration) func(http.Handler) http.Handler {
|
||||
return middleware.Timeout(d)
|
||||
}
|
||||
|
||||
func filteredLogger() func(http.Handler) http.Handler {
|
||||
color := true
|
||||
if isWindowsRuntime() {
|
||||
color = false
|
||||
}
|
||||
base := &middleware.DefaultLogFormatter{
|
||||
Logger: log.New(os.Stdout, "", log.LstdFlags),
|
||||
NoColor: !color,
|
||||
}
|
||||
return middleware.RequestLogger(&filteredLogFormatter{base: base})
|
||||
}
|
||||
|
||||
func isWindowsRuntime() bool {
|
||||
return runtime.GOOS == "windows"
|
||||
}
|
||||
|
||||
type filteredLogFormatter struct {
|
||||
base *middleware.DefaultLogFormatter
|
||||
}
|
||||
|
||||
func (f *filteredLogFormatter) NewLogEntry(r *http.Request) middleware.LogEntry {
|
||||
if r != nil && r.Method == http.MethodGet {
|
||||
path := strings.TrimSpace(r.URL.Path)
|
||||
if path == "/admin/chat-history" || strings.HasPrefix(path, "/admin/chat-history/") {
|
||||
return noopLogEntry{}
|
||||
}
|
||||
}
|
||||
return f.base.NewLogEntry(r)
|
||||
}
|
||||
|
||||
type noopLogEntry struct{}
|
||||
|
||||
func (noopLogEntry) Write(_ int, _ int, _ http.Header, _ time.Duration, _ interface{}) {}
|
||||
|
||||
func (noopLogEntry) Panic(_ interface{}, _ []byte) {}
|
||||
|
||||
func cors(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Ds2-Source, X-Vercel-Protection-Bypass")
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user