diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 9c67492..b81378f 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -36,6 +36,8 @@ func RegisterRoutes(r chi.Router, h *Handler) { pr.Post("/import", h.batchImport) pr.Post("/test", h.testAPI) pr.Post("/dev/raw-samples/capture", h.captureRawSample) + pr.Get("/dev/raw-samples/query", h.queryRawSampleCaptures) + pr.Post("/dev/raw-samples/save", h.saveRawSampleFromCaptures) pr.Post("/vercel/sync", h.syncVercel) pr.Get("/vercel/status", h.vercelStatus) pr.Post("/vercel/status", h.vercelStatus) diff --git a/internal/admin/handler_raw_samples.go b/internal/admin/handler_raw_samples.go index 76ce638..4bf85a8 100644 --- a/internal/admin/handler_raw_samples.go +++ b/internal/admin/handler_raw_samples.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "sort" "strings" "ds2api/internal/config" @@ -15,6 +16,11 @@ import ( "ds2api/internal/rawsample" ) +type captureChain struct { + Key string + Entries []devcapture.Entry +} + func (h *Handler) captureRawSample(w http.ResponseWriter, r *http.Request) { if h.OpenAI == nil { writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": "OpenAI handler is not configured"}) @@ -231,3 +237,310 @@ func cloneMap(in map[string]any) map[string]any { } return out } + +func (h *Handler) queryRawSampleCaptures(w http.ResponseWriter, r *http.Request) { + query := strings.TrimSpace(r.URL.Query().Get("q")) + limit := intFromQuery(r, "limit", 20) + if limit <= 0 { + limit = 20 + } + if limit > 50 { + limit = 50 + } + + chains := buildCaptureChains(devcapture.Global().Snapshot()) + items := make([]map[string]any, 0, len(chains)) + for _, chain := range chains { + if query != "" && !captureChainMatchesQuery(chain, query) { + continue + } + items = append(items, buildCaptureChainQueryItem(chain, query)) + if len(items) >= limit { + break + } + } + + writeJSON(w, http.StatusOK, map[string]any{ + "query": query, + "limit": limit, + "count": len(items), + "items": items, + }) +} + +func (h *Handler) saveRawSampleFromCaptures(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + + snapshot := devcapture.Global().Snapshot() + if len(snapshot) == 0 { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "no capture logs available"}) + return + } + + chain, err := resolveCaptureChainSelection(snapshot, req) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + + sampleID := strings.TrimSpace(fieldString(req, "sample_id")) + source := strings.TrimSpace(fieldString(req, "source")) + if source == "" { + source = "admin/dev/raw-samples/save" + } + requestPayload := captureChainRequestPayload(chain) + + saved, err := rawsample.Persist(rawsample.PersistOptions{ + RootDir: config.RawStreamSampleRoot(), + SampleID: sampleID, + Source: source, + Request: requestPayload, + Capture: captureSummaryFromEntries(chain.Entries), + UpstreamBody: combineCaptureBodies(chain.Entries), + }) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "sample_id": saved.SampleID, + "sample_dir": saved.Dir, + "meta_path": saved.MetaPath, + "upstream_path": saved.UpstreamPath, + "chain_key": chain.Key, + "capture_ids": captureChainIDs(chain), + "round_count": len(chain.Entries), + }) +} + +func buildCaptureChains(snapshot []devcapture.Entry) []captureChain { + if len(snapshot) == 0 { + return nil + } + ordered := make([]devcapture.Entry, len(snapshot)) + copy(ordered, snapshot) + sort.SliceStable(ordered, func(i, j int) bool { + if ordered[i].CreatedAt == ordered[j].CreatedAt { + return ordered[i].ID < ordered[j].ID + } + return ordered[i].CreatedAt < ordered[j].CreatedAt + }) + + byKey := make(map[string]*captureChain, len(ordered)) + keys := make([]string, 0, len(ordered)) + for _, entry := range ordered { + key := captureChainKey(entry) + if key == "" { + key = "capture:" + entry.ID + } + if _, ok := byKey[key]; !ok { + byKey[key] = &captureChain{Key: key} + keys = append(keys, key) + } + byKey[key].Entries = append(byKey[key].Entries, entry) + } + + chains := make([]captureChain, 0, len(keys)) + for _, key := range keys { + chains = append(chains, *byKey[key]) + } + sort.SliceStable(chains, func(i, j int) bool { + return latestCreatedAt(chains[i]) > latestCreatedAt(chains[j]) + }) + return chains +} + +func captureChainKey(entry devcapture.Entry) string { + req := parseCaptureRequestBody(entry.RequestBody) + if sessionID := strings.TrimSpace(fieldString(req, "chat_session_id")); sessionID != "" { + return "session:" + sessionID + } + return "capture:" + entry.ID +} + +func parseCaptureRequestBody(raw string) map[string]any { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + var out map[string]any + if err := json.Unmarshal([]byte(raw), &out); err != nil { + return nil + } + return out +} + +func latestCreatedAt(chain captureChain) int64 { + var latest int64 + for _, entry := range chain.Entries { + if entry.CreatedAt > latest { + latest = entry.CreatedAt + } + } + return latest +} + +func captureChainMatchesQuery(chain captureChain, query string) bool { + query = strings.ToLower(strings.TrimSpace(query)) + if query == "" { + return true + } + for _, entry := range chain.Entries { + hay := strings.ToLower(strings.Join([]string{ + entry.Label, + entry.URL, + entry.AccountID, + entry.RequestBody, + entry.ResponseBody, + }, "\n")) + if strings.Contains(hay, query) { + return true + } + } + return false +} + +func buildCaptureChainQueryItem(chain captureChain, query string) map[string]any { + first := chain.Entries[0] + last := chain.Entries[len(chain.Entries)-1] + requestPreview := previewCaptureChainRequest(chain) + responsePreview := previewCaptureChainResponse(chain) + + return map[string]any{ + "chain_key": chain.Key, + "capture_ids": captureChainIDs(chain), + "created_at": latestCreatedAt(chain), + "round_count": len(chain.Entries), + "account_id": nilIfEmpty(strings.TrimSpace(first.AccountID)), + "initial_label": first.Label, + "initial_url": first.URL, + "latest_label": last.Label, + "latest_url": last.URL, + "request_preview": requestPreview, + "response_preview": responsePreview, + "query": query, + "response_truncated": captureChainHasTruncatedResponse(chain), + } +} + +func captureChainIDs(chain captureChain) []string { + out := make([]string, 0, len(chain.Entries)) + for _, entry := range chain.Entries { + out = append(out, entry.ID) + } + return out +} + +func previewCaptureChainRequest(chain captureChain) string { + for _, entry := range chain.Entries { + req := parseCaptureRequestBody(entry.RequestBody) + if prompt := strings.TrimSpace(fieldString(req, "prompt")); prompt != "" { + return previewText(prompt, 280) + } + if messages, ok := req["messages"].([]any); ok { + var parts []string + for _, item := range messages { + m, _ := item.(map[string]any) + content := strings.TrimSpace(fieldString(m, "content")) + if content != "" { + parts = append(parts, content) + } + } + if len(parts) > 0 { + return previewText(strings.Join(parts, "\n"), 280) + } + } + } + return previewText(strings.TrimSpace(chain.Entries[0].RequestBody), 280) +} + +func previewCaptureChainResponse(chain captureChain) string { + var b strings.Builder + for _, entry := range chain.Entries { + if b.Len() > 0 { + b.WriteByte('\n') + } + b.WriteString(strings.TrimSpace(entry.ResponseBody)) + if b.Len() >= 280 { + break + } + } + return previewText(b.String(), 280) +} + +func previewText(text string, limit int) string { + text = strings.TrimSpace(text) + if limit <= 0 || len(text) <= limit { + return text + } + return text[:limit] + "..." +} + +func captureChainHasTruncatedResponse(chain captureChain) bool { + for _, entry := range chain.Entries { + if entry.ResponseTruncated { + return true + } + } + return false +} + +func resolveCaptureChainSelection(snapshot []devcapture.Entry, req map[string]any) (captureChain, error) { + chains := buildCaptureChains(snapshot) + if len(chains) == 0 { + return captureChain{}, fmt.Errorf("no capture logs available") + } + + if chainKey := strings.TrimSpace(fieldString(req, "chain_key")); chainKey != "" { + for _, chain := range chains { + if chain.Key == chainKey { + return chain, nil + } + } + return captureChain{}, fmt.Errorf("capture chain not found") + } + + captureID := strings.TrimSpace(fieldString(req, "capture_id")) + if captureID == "" { + if ids, ok := toStringSlice(req["capture_ids"]); ok && len(ids) > 0 { + captureID = strings.TrimSpace(ids[0]) + } + } + if captureID != "" { + for _, chain := range chains { + for _, entry := range chain.Entries { + if entry.ID == captureID { + return chain, nil + } + } + } + return captureChain{}, fmt.Errorf("capture id not found") + } + + query := strings.TrimSpace(fieldString(req, "query")) + if query != "" { + for _, chain := range chains { + if captureChainMatchesQuery(chain, query) { + return chain, nil + } + } + return captureChain{}, fmt.Errorf("no capture chain matched query") + } + + return captureChain{}, fmt.Errorf("capture_id, chain_key, or query is required") +} + +func captureChainRequestPayload(chain captureChain) any { + for _, entry := range chain.Entries { + if req := parseCaptureRequestBody(entry.RequestBody); req != nil { + return req + } + } + return strings.TrimSpace(chain.Entries[0].RequestBody) +} diff --git a/internal/admin/handler_raw_samples_test.go b/internal/admin/handler_raw_samples_test.go index 4566b70..fa15e45 100644 --- a/internal/admin/handler_raw_samples_test.go +++ b/internal/admin/handler_raw_samples_test.go @@ -230,3 +230,127 @@ func TestCombineCaptureBodiesPreservesOrderAndSeparators(t *testing.T) { t.Fatalf("unexpected combined body: %q", string(got)) } } + +func TestQueryRawSampleCapturesGroupsBySessionAndMatchesQuestion(t *testing.T) { + devcapture.Global().Clear() + defer devcapture.Global().Clear() + + recordCapturedResponse( + "deepseek_completion", + "https://chat.deepseek.com/api/v0/chat/completion", + http.StatusOK, + map[string]any{ + "chat_session_id": "session-query-1", + "prompt": "用户问题:广州天气怎么样?", + }, + "data: {\"v\":\"先看天气\"}\n\n", + ) + recordCapturedResponse( + "deepseek_continue", + "https://chat.deepseek.com/api/v0/chat/continue", + http.StatusOK, + map[string]any{ + "chat_session_id": "session-query-1", + "message_id": 2, + }, + "data: {\"v\":\"再补充一点\"}\n\n", + ) + + h := &Handler{} + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/dev/raw-samples/query?q=广州天气", nil) + h.queryRawSampleCaptures(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode failed: %v", err) + } + items, _ := out["items"].([]any) + if len(items) != 1 { + t.Fatalf("expected 1 item, got %d body=%s", len(items), rec.Body.String()) + } + item, _ := items[0].(map[string]any) + if item["chain_key"] != "session:session-query-1" { + t.Fatalf("unexpected chain key: %#v", item["chain_key"]) + } + if int(item["round_count"].(float64)) != 2 { + t.Fatalf("expected 2 rounds, got %#v", item["round_count"]) + } + reqPreview, _ := item["request_preview"].(string) + if !strings.Contains(reqPreview, "广州天气") { + t.Fatalf("expected request preview to contain query, got %q", reqPreview) + } +} + +func TestSaveRawSampleFromCapturesPersistsSelectedChain(t *testing.T) { + root := t.TempDir() + t.Setenv("DS2API_RAW_STREAM_SAMPLE_ROOT", root) + devcapture.Global().Clear() + defer devcapture.Global().Clear() + + recordCapturedResponse( + "deepseek_completion", + "https://chat.deepseek.com/api/v0/chat/completion", + http.StatusOK, + map[string]any{ + "chat_session_id": "session-save-1", + "prompt": "请回答深圳天气", + }, + "data: {\"v\":\"第一段\"}\n\n", + ) + recordCapturedResponse( + "deepseek_continue", + "https://chat.deepseek.com/api/v0/chat/continue", + http.StatusOK, + map[string]any{ + "chat_session_id": "session-save-1", + "message_id": 2, + }, + "data: {\"v\":\"第二段\"}\n\n", + ) + + h := &Handler{} + rec := httptest.NewRecorder() + reqBody := `{"query":"深圳天气","sample_id":"saved-from-memory"}` + req := httptest.NewRequest(http.MethodPost, "/admin/dev/raw-samples/save", strings.NewReader(reqBody)) + h.saveRawSampleFromCaptures(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode failed: %v", err) + } + if out["sample_id"] != "saved-from-memory" { + t.Fatalf("unexpected sample id: %#v", out["sample_id"]) + } + if int(out["round_count"].(float64)) != 2 { + t.Fatalf("expected round_count=2, got %#v", out["round_count"]) + } + + sampleDir := filepath.Join(root, "saved-from-memory") + upstreamBytes, err := os.ReadFile(filepath.Join(sampleDir, "upstream.stream.sse")) + if err != nil { + t.Fatalf("read upstream: %v", err) + } + upstream := string(upstreamBytes) + if !strings.Contains(upstream, "第一段") || !strings.Contains(upstream, "第二段") { + t.Fatalf("expected combined upstream, got %q", upstream) + } + metaBytes, err := os.ReadFile(filepath.Join(sampleDir, "meta.json")) + if err != nil { + t.Fatalf("read meta: %v", err) + } + var meta map[string]any + if err := json.Unmarshal(metaBytes, &meta); err != nil { + t.Fatalf("decode meta: %v", err) + } + reqMeta, _ := meta["request"].(map[string]any) + if fieldString(reqMeta, "chat_session_id") != "session-save-1" { + t.Fatalf("expected request to come from selected chain, got %#v", meta["request"]) + } +} diff --git a/internal/devcapture/store.go b/internal/devcapture/store.go index 6d0d8cd..c5b3cec 100644 --- a/internal/devcapture/store.go +++ b/internal/devcapture/store.go @@ -14,8 +14,8 @@ import ( ) const ( - defaultLimit = 5 - defaultMaxBodyBytes = 2 * 1024 * 1024 + defaultLimit = 20 + defaultMaxBodyBytes = 5 * 1024 * 1024 maxLimit = 50 ) diff --git a/internal/devcapture/store_test.go b/internal/devcapture/store_test.go index 1dd58b4..3bbbf2d 100644 --- a/internal/devcapture/store_test.go +++ b/internal/devcapture/store_test.go @@ -6,6 +6,35 @@ import ( "testing" ) +func TestNewFromEnvDefaults(t *testing.T) { + t.Setenv("DS2API_DEV_PACKET_CAPTURE_LIMIT", "") + t.Setenv("DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES", "") + t.Setenv("VERCEL", "") + t.Setenv("NOW_REGION", "") + + s := NewFromEnv() + if s.Limit() != 20 { + t.Fatalf("expected default limit 20, got %d", s.Limit()) + } + if s.MaxBodyBytes() != 5*1024*1024 { + t.Fatalf("expected default max body bytes 5MB, got %d", s.MaxBodyBytes()) + } +} + +func TestNewFromEnvHonorsOverrides(t *testing.T) { + t.Setenv("DS2API_DEV_PACKET_CAPTURE_LIMIT", "7") + t.Setenv("DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES", "8192") + t.Setenv("VERCEL", "") + t.Setenv("NOW_REGION", "") + s := NewFromEnv() + if s.Limit() != 7 { + t.Fatalf("expected override limit 7, got %d", s.Limit()) + } + if s.MaxBodyBytes() != 8192 { + t.Fatalf("expected override max body bytes 8192, got %d", s.MaxBodyBytes()) + } +} + func TestStorePushKeepsNewestWithinLimit(t *testing.T) { s := &Store{enabled: true, limit: 2, maxBodyBytes: 1024} for i := 0; i < 3; i++ { diff --git a/webui/src/locales/en.json b/webui/src/locales/en.json index 1047134..6aa6ca0 100644 --- a/webui/src/locales/en.json +++ b/webui/src/locales/en.json @@ -115,7 +115,7 @@ "addAccount": "Add account", "testingAllAccounts": "Refreshing tokens for all accounts...", "sessionActive": "Session active", - "reauthRequired": "Re-auth required", + "reauthRequired": "Retest status required", "runtimeStatusUnknown": "Will be determined after sync", "testStatusFailed": "Last test failed", "noAccounts": "No accounts found.", @@ -325,4 +325,4 @@ "four": "Trigger a redeploy to apply the updated environment variables." } } -} +} \ No newline at end of file diff --git a/webui/src/locales/zh.json b/webui/src/locales/zh.json index 11895c6..9dfc4e4 100644 --- a/webui/src/locales/zh.json +++ b/webui/src/locales/zh.json @@ -115,7 +115,7 @@ "addAccount": "添加账号", "testingAllAccounts": "正在刷新所有账号 Token...", "sessionActive": "已建立会话", - "reauthRequired": "需重新登录", + "reauthRequired": "需重新测试状态", "runtimeStatusUnknown": "状态以同步后为准", "testStatusFailed": "上次测试失败", "noAccounts": "未找到任何账号", @@ -325,4 +325,4 @@ "four": "触发重新部署以应用新的环境变量。" } } -} +} \ No newline at end of file