refactor backend API structure

This commit is contained in:
CJACK
2026-04-26 06:58:20 +08:00
parent 8a91fef6ab
commit abc96a37d8
207 changed files with 2675 additions and 1344 deletions

View File

@@ -0,0 +1,27 @@
package rawsamples
import (
"net/http"
"ds2api/internal/chathistory"
adminshared "ds2api/internal/httpapi/admin/shared"
)
type Handler struct {
Store adminshared.ConfigStore
Pool adminshared.PoolController
DS adminshared.DeepSeekCaller
OpenAI adminshared.OpenAIChatCaller
ChatHistory *chathistory.Store
}
var writeJSON = adminshared.WriteJSON
func intFromQuery(r *http.Request, key string, d int) int {
return adminshared.IntFromQuery(r, key, d)
}
func nilIfEmpty(s string) any { return adminshared.NilIfEmpty(s) }
func toStringSlice(v any) ([]string, bool) { return adminshared.ToStringSlice(v) }
func fieldString(m map[string]any, key string) string {
return adminshared.FieldString(m, key)
}

View File

@@ -0,0 +1,549 @@
package rawsamples
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"sort"
"strings"
"ds2api/internal/config"
"ds2api/internal/devcapture"
adminshared "ds2api/internal/httpapi/admin/shared"
"ds2api/internal/rawsample"
)
type captureChain struct {
Key string
Entries []devcapture.Entry
}
func (h *Handler) captureRawSample(w http.ResponseWriter, r *http.Request) {
if h.OpenAI == nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": "OpenAI handler is not configured"})
return
}
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
return
}
payload, sampleID, apiKey, err := prepareRawSampleCaptureRequest(h.Store, req)
if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
body, err := json.Marshal(payload)
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": "failed to encode capture request"})
return
}
traceID := rawsample.NormalizeSampleID(sampleID)
if traceID == "" {
traceID = rawsample.DefaultSampleID("capture")
}
before := devcapture.Global().Snapshot()
rec := httptest.NewRecorder()
captureReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions?__trace_id="+url.QueryEscape(traceID), bytes.NewReader(body))
captureReq.Header.Set("Authorization", "Bearer "+apiKey)
captureReq.Header.Set("Content-Type", "application/json")
h.OpenAI.ChatCompletions(rec, captureReq)
after := devcapture.Global().Snapshot()
if rec.Code >= http.StatusBadRequest {
copyHeader(w.Header(), rec.Header())
w.WriteHeader(rec.Code)
_, _ = io.Copy(w, bytes.NewReader(rec.Body.Bytes()))
return
}
captureEntries, err := collectNewCaptureEntries(before, after)
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
saved, err := rawsample.Persist(rawsample.PersistOptions{
RootDir: config.RawStreamSampleRoot(),
SampleID: sampleID,
Source: "admin/dev/raw-samples/capture",
Request: payload,
Capture: captureSummaryFromEntries(captureEntries),
UpstreamBody: combineCaptureBodies(captureEntries),
})
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
copyHeader(w.Header(), rec.Header())
w.Header().Set("X-Ds2-Sample-Id", saved.SampleID)
w.Header().Set("X-Ds2-Sample-Dir", saved.Dir)
w.Header().Set("X-Ds2-Sample-Meta", saved.MetaPath)
w.Header().Set("X-Ds2-Sample-Upstream", saved.UpstreamPath)
w.WriteHeader(rec.Code)
_, _ = io.Copy(w, bytes.NewReader(rec.Body.Bytes()))
}
func prepareRawSampleCaptureRequest(store adminshared.ConfigStore, req map[string]any) (map[string]any, string, string, error) {
payload := cloneMap(req)
sampleID := strings.TrimSpace(fieldString(payload, "sample_id"))
apiKey := strings.TrimSpace(fieldString(payload, "api_key"))
for _, k := range []string{"sample_id", "api_key", "promote_default", "persist", "source"} {
delete(payload, k)
}
if apiKey == "" {
if store == nil {
return nil, "", "", fmt.Errorf("no api key provided")
}
keys := store.Keys()
if len(keys) == 0 {
return nil, "", "", fmt.Errorf("no api key available")
}
apiKey = strings.TrimSpace(keys[0])
}
if model := strings.TrimSpace(fieldString(payload, "model")); model == "" {
payload["model"] = "deepseek-v4-flash"
}
if _, ok := payload["stream"]; !ok {
payload["stream"] = true
}
if messagesRaw, ok := payload["messages"].([]any); !ok || len(messagesRaw) == 0 {
message := strings.TrimSpace(fieldString(payload, "message"))
if message == "" {
message = "你好"
}
payload["messages"] = []map[string]any{{"role": "user", "content": message}}
}
delete(payload, "message")
if sampleID == "" {
model := strings.TrimSpace(fieldString(payload, "model"))
if model == "" {
model = "capture"
}
sampleID = rawsample.DefaultSampleID(model)
}
return payload, sampleID, apiKey, nil
}
func collectNewCaptureEntries(before, after []devcapture.Entry) ([]devcapture.Entry, error) {
beforeIDs := make(map[string]struct{}, len(before))
for _, entry := range before {
beforeIDs[entry.ID] = struct{}{}
}
entries := make([]devcapture.Entry, 0, len(after))
for _, entry := range after {
if _, ok := beforeIDs[entry.ID]; ok {
continue
}
if strings.TrimSpace(entry.ResponseBody) == "" {
continue
}
entries = append(entries, entry)
}
if len(entries) == 0 {
return nil, fmt.Errorf("no upstream capture was recorded")
}
// Snapshot order is newest-first; reverse to preserve the actual request order.
for i, j := 0, len(entries)-1; i < j; i, j = i+1, j-1 {
entries[i], entries[j] = entries[j], entries[i]
}
return entries, nil
}
func captureSummaryFromEntries(entries []devcapture.Entry) rawsample.CaptureSummary {
if len(entries) == 0 {
return rawsample.CaptureSummary{}
}
// Primary metadata comes from the first (initial) capture.
summary := rawsample.CaptureSummary{
Label: strings.TrimSpace(entries[0].Label),
URL: strings.TrimSpace(entries[0].URL),
StatusCode: entries[0].StatusCode,
}
// Record every round (initial + continuations) so replay/debug
// can reconstruct the full multi-round interaction.
totalBytes := 0
rounds := make([]rawsample.CaptureRound, 0, len(entries))
for _, entry := range entries {
n := len(entry.ResponseBody)
totalBytes += n
rounds = append(rounds, rawsample.CaptureRound{
Label: strings.TrimSpace(entry.Label),
URL: strings.TrimSpace(entry.URL),
StatusCode: entry.StatusCode,
ResponseBytes: n,
})
}
summary.ResponseBytes = totalBytes
if len(rounds) > 1 {
summary.Rounds = rounds
}
return summary
}
func combineCaptureBodies(entries []devcapture.Entry) []byte {
if len(entries) == 0 {
return nil
}
var buf bytes.Buffer
for _, entry := range entries {
if buf.Len() > 0 {
last := buf.Bytes()[buf.Len()-1]
if last != '\n' {
buf.WriteByte('\n')
}
}
buf.WriteString(entry.ResponseBody)
}
return buf.Bytes()
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
dst.Del(k)
for _, v := range vv {
dst.Add(k, v)
}
}
}
func cloneMap(in map[string]any) map[string]any {
if len(in) == 0 {
return map[string]any{}
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func (h *Handler) queryRawSampleCaptures(w http.ResponseWriter, r *http.Request) {
query := strings.TrimSpace(r.URL.Query().Get("q"))
limit := intFromQuery(r, "limit", 20)
if limit <= 0 {
limit = 20
}
if limit > 50 {
limit = 50
}
chains := buildCaptureChains(devcapture.Global().Snapshot())
items := make([]map[string]any, 0, len(chains))
for _, chain := range chains {
if query != "" && !captureChainMatchesQuery(chain, query) {
continue
}
items = append(items, buildCaptureChainQueryItem(chain, query))
if len(items) >= limit {
break
}
}
writeJSON(w, http.StatusOK, map[string]any{
"query": query,
"limit": limit,
"count": len(items),
"items": items,
})
}
func (h *Handler) saveRawSampleFromCaptures(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
return
}
snapshot := devcapture.Global().Snapshot()
if len(snapshot) == 0 {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "no capture logs available"})
return
}
chain, err := resolveCaptureChainSelection(snapshot, req)
if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
sampleID := strings.TrimSpace(fieldString(req, "sample_id"))
source := strings.TrimSpace(fieldString(req, "source"))
if source == "" {
source = "admin/dev/raw-samples/save"
}
requestPayload := captureChainRequestPayload(chain)
saved, err := rawsample.Persist(rawsample.PersistOptions{
RootDir: config.RawStreamSampleRoot(),
SampleID: sampleID,
Source: source,
Request: requestPayload,
Capture: captureSummaryFromEntries(chain.Entries),
UpstreamBody: combineCaptureBodies(chain.Entries),
})
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
writeJSON(w, http.StatusOK, map[string]any{
"success": true,
"sample_id": saved.SampleID,
"sample_dir": saved.Dir,
"meta_path": saved.MetaPath,
"upstream_path": saved.UpstreamPath,
"chain_key": chain.Key,
"capture_ids": captureChainIDs(chain),
"round_count": len(chain.Entries),
})
}
func buildCaptureChains(snapshot []devcapture.Entry) []captureChain {
if len(snapshot) == 0 {
return nil
}
ordered := make([]devcapture.Entry, len(snapshot))
// devcapture snapshots are newest-first because the store prepends entries.
// Reverse once so equal-second timestamps can preserve the actual capture
// order (completion before continue) under the stable CreatedAt sort below.
for i := range snapshot {
ordered[len(snapshot)-1-i] = snapshot[i]
}
sort.SliceStable(ordered, func(i, j int) bool {
return ordered[i].CreatedAt < ordered[j].CreatedAt
})
byKey := make(map[string]*captureChain, len(ordered))
keys := make([]string, 0, len(ordered))
for _, entry := range ordered {
key := captureChainKey(entry)
if key == "" {
key = "capture:" + entry.ID
}
if _, ok := byKey[key]; !ok {
byKey[key] = &captureChain{Key: key}
keys = append(keys, key)
}
byKey[key].Entries = append(byKey[key].Entries, entry)
}
chains := make([]captureChain, 0, len(keys))
for _, key := range keys {
chains = append(chains, *byKey[key])
}
sort.SliceStable(chains, func(i, j int) bool {
return latestCreatedAt(chains[i]) > latestCreatedAt(chains[j])
})
return chains
}
func captureChainKey(entry devcapture.Entry) string {
req := parseCaptureRequestBody(entry.RequestBody)
if sessionID := strings.TrimSpace(fieldString(req, "chat_session_id")); sessionID != "" {
return "session:" + sessionID
}
return "capture:" + entry.ID
}
func parseCaptureRequestBody(raw string) map[string]any {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
var out map[string]any
if err := json.Unmarshal([]byte(raw), &out); err != nil {
return nil
}
return out
}
func latestCreatedAt(chain captureChain) int64 {
var latest int64
for _, entry := range chain.Entries {
if entry.CreatedAt > latest {
latest = entry.CreatedAt
}
}
return latest
}
func captureChainMatchesQuery(chain captureChain, query string) bool {
query = strings.ToLower(strings.TrimSpace(query))
if query == "" {
return true
}
for _, entry := range chain.Entries {
hay := strings.ToLower(strings.Join([]string{
entry.Label,
entry.URL,
entry.AccountID,
entry.RequestBody,
entry.ResponseBody,
}, "\n"))
if strings.Contains(hay, query) {
return true
}
}
return false
}
func buildCaptureChainQueryItem(chain captureChain, query string) map[string]any {
first := chain.Entries[0]
last := chain.Entries[len(chain.Entries)-1]
requestPreview := previewCaptureChainRequest(chain)
responsePreview := previewCaptureChainResponse(chain)
return map[string]any{
"chain_key": chain.Key,
"capture_ids": captureChainIDs(chain),
"created_at": latestCreatedAt(chain),
"round_count": len(chain.Entries),
"account_id": nilIfEmpty(strings.TrimSpace(first.AccountID)),
"initial_label": first.Label,
"initial_url": first.URL,
"latest_label": last.Label,
"latest_url": last.URL,
"request_preview": requestPreview,
"response_preview": responsePreview,
"query": query,
"response_truncated": captureChainHasTruncatedResponse(chain),
}
}
func captureChainIDs(chain captureChain) []string {
out := make([]string, 0, len(chain.Entries))
for _, entry := range chain.Entries {
out = append(out, entry.ID)
}
return out
}
func previewCaptureChainRequest(chain captureChain) string {
for _, entry := range chain.Entries {
req := parseCaptureRequestBody(entry.RequestBody)
if prompt := strings.TrimSpace(fieldString(req, "prompt")); prompt != "" {
return previewText(prompt, 280)
}
if messages, ok := req["messages"].([]any); ok {
var parts []string
for _, item := range messages {
m, _ := item.(map[string]any)
content := strings.TrimSpace(fieldString(m, "content"))
if content != "" {
parts = append(parts, content)
}
}
if len(parts) > 0 {
return previewText(strings.Join(parts, "\n"), 280)
}
}
}
return previewText(strings.TrimSpace(chain.Entries[0].RequestBody), 280)
}
func previewCaptureChainResponse(chain captureChain) string {
var b strings.Builder
for _, entry := range chain.Entries {
if b.Len() > 0 {
b.WriteByte('\n')
}
b.WriteString(strings.TrimSpace(entry.ResponseBody))
if b.Len() >= 280 {
break
}
}
return previewText(b.String(), 280)
}
func previewText(text string, limit int) string {
text = strings.TrimSpace(text)
if limit <= 0 || len(text) <= limit {
return text
}
return text[:limit] + "..."
}
func captureChainHasTruncatedResponse(chain captureChain) bool {
for _, entry := range chain.Entries {
if entry.ResponseTruncated {
return true
}
}
return false
}
func resolveCaptureChainSelection(snapshot []devcapture.Entry, req map[string]any) (captureChain, error) {
chains := buildCaptureChains(snapshot)
if len(chains) == 0 {
return captureChain{}, fmt.Errorf("no capture logs available")
}
if chainKey := strings.TrimSpace(fieldString(req, "chain_key")); chainKey != "" {
for _, chain := range chains {
if chain.Key == chainKey {
return chain, nil
}
}
return captureChain{}, fmt.Errorf("capture chain not found")
}
captureID := strings.TrimSpace(fieldString(req, "capture_id"))
if captureID == "" {
if ids, ok := toStringSlice(req["capture_ids"]); ok && len(ids) > 0 {
captureID = strings.TrimSpace(ids[0])
}
}
if captureID != "" {
for _, chain := range chains {
for _, entry := range chain.Entries {
if entry.ID == captureID {
return chain, nil
}
}
}
return captureChain{}, fmt.Errorf("capture id not found")
}
query := strings.TrimSpace(fieldString(req, "query"))
if query != "" {
for _, chain := range chains {
if captureChainMatchesQuery(chain, query) {
return chain, nil
}
}
return captureChain{}, fmt.Errorf("no capture chain matched query")
}
return captureChain{}, fmt.Errorf("capture_id, chain_key, or query is required")
}
func captureChainRequestPayload(chain captureChain) any {
for _, entry := range chain.Entries {
if req := parseCaptureRequestBody(entry.RequestBody); req != nil {
return req
}
}
return strings.TrimSpace(chain.Entries[0].RequestBody)
}

View File

@@ -0,0 +1,389 @@
package rawsamples
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"ds2api/internal/devcapture"
)
type stubOpenAIChatCaller struct{}
func (stubOpenAIChatCaller) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
store := devcapture.Global()
session := store.Start("deepseek_completion", "https://chat.deepseek.com/api/v0/chat/completion", "acct-test", map[string]any{"model": "deepseek-v4-flash"})
raw := io.NopCloser(strings.NewReader(
"data: {\"v\":\"hello [reference:1]\"}\n\n" +
"data: {\"v\":\"FINISHED\",\"p\":\"response/status\"}\n\n",
))
if session != nil {
raw = session.WrapBody(raw, http.StatusOK)
}
_, _ = io.ReadAll(raw)
_ = raw.Close()
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"},\"index\":0}],\"created\":1,\"id\":\"id\",\"model\":\"m\",\"object\":\"chat.completion.chunk\"}\n\n")
}
type stubOpenAIChatCallerWithContinuations struct{}
func (stubOpenAIChatCallerWithContinuations) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
recordCapturedResponse("deepseek_completion", "https://chat.deepseek.com/api/v0/chat/completion", http.StatusOK, map[string]any{"model": "deepseek-v4-flash"}, "data: {\"v\":\"hello [reference:1]\"}\n\n"+"data: [DONE]\n\n")
recordCapturedResponse("deepseek_continue", "https://chat.deepseek.com/api/v0/chat/continue", http.StatusOK, map[string]any{"chat_session_id": "session-1", "message_id": 2}, "data: {\"v\":\"continued\"}\n\n"+"data: [DONE]\n\n")
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"hello continued\"},\"index\":0}],\"created\":1,\"id\":\"id\",\"model\":\"m\",\"object\":\"chat.completion.chunk\"}\n\n")
}
type stubOpenAIChatCallerWithoutCapture struct{}
func (stubOpenAIChatCallerWithoutCapture) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"},\"index\":0}],\"created\":1,\"id\":\"id\",\"model\":\"m\",\"object\":\"chat.completion.chunk\"}\n\n")
}
func recordCapturedResponse(label, rawURL string, statusCode int, request any, body string) {
store := devcapture.Global()
session := store.Start(label, rawURL, "acct-test", request)
raw := io.NopCloser(strings.NewReader(body))
if session != nil {
raw = session.WrapBody(raw, statusCode)
}
_, _ = io.ReadAll(raw)
_ = raw.Close()
}
func TestCaptureRawSampleWritesPersistentSample(t *testing.T) {
t.Setenv("DS2API_RAW_STREAM_SAMPLE_ROOT", t.TempDir())
devcapture.Global().Clear()
defer devcapture.Global().Clear()
h := &Handler{OpenAI: stubOpenAIChatCaller{}}
reqBody := `{
"sample_id":"My Sample 01",
"api_key":"local-key",
"model":"deepseek-v4-flash",
"message":"广州天气",
"stream":true
}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/admin/dev/raw-samples/capture", strings.NewReader(reqBody))
h.captureRawSample(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
if got := rec.Header().Get("X-Ds2-Sample-Id"); got != "my-sample-01" {
t.Fatalf("expected sample id header my-sample-01, got %q", got)
}
if got := rec.Header().Get("X-Ds2-Sample-Upstream"); got != filepath.Join(os.Getenv("DS2API_RAW_STREAM_SAMPLE_ROOT"), "my-sample-01", "upstream.stream.sse") {
t.Fatalf("unexpected sample upstream header: %q", got)
}
if !strings.Contains(rec.Body.String(), `"content":"hello"`) {
t.Fatalf("expected proxied openai output, got %s", rec.Body.String())
}
sampleDir := filepath.Join(os.Getenv("DS2API_RAW_STREAM_SAMPLE_ROOT"), "my-sample-01")
if _, err := os.Stat(sampleDir); err != nil {
t.Fatalf("sample dir missing: %v", err)
}
metaBytes, err := os.ReadFile(filepath.Join(sampleDir, "meta.json"))
if err != nil {
t.Fatalf("read meta: %v", err)
}
var meta map[string]any
if err := json.Unmarshal(metaBytes, &meta); err != nil {
t.Fatalf("decode meta: %v", err)
}
if meta["sample_id"] != "my-sample-01" {
t.Fatalf("unexpected meta sample_id: %#v", meta["sample_id"])
}
capture, _ := meta["capture"].(map[string]any)
if capture == nil {
t.Fatalf("missing capture meta: %#v", meta)
}
if got := int(capture["response_bytes"].(float64)); got == 0 {
t.Fatalf("expected capture bytes to be recorded, got %#v", capture)
}
if _, ok := meta["processed"]; ok {
t.Fatalf("unexpected processed meta: %#v", meta["processed"])
}
}
func TestCaptureRawSampleCombinesContinuationCaptures(t *testing.T) {
t.Setenv("DS2API_RAW_STREAM_SAMPLE_ROOT", t.TempDir())
devcapture.Global().Clear()
defer devcapture.Global().Clear()
h := &Handler{OpenAI: stubOpenAIChatCallerWithContinuations{}}
reqBody := `{
"sample_id":"My Sample 02",
"api_key":"local-key",
"model":"deepseek-v4-flash",
"message":"广州天气",
"stream":true
}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/admin/dev/raw-samples/capture", strings.NewReader(reqBody))
h.captureRawSample(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
sampleDir := filepath.Join(os.Getenv("DS2API_RAW_STREAM_SAMPLE_ROOT"), "my-sample-02")
upstreamBytes, err := os.ReadFile(filepath.Join(sampleDir, "upstream.stream.sse"))
if err != nil {
t.Fatalf("read upstream: %v", err)
}
upstream := string(upstreamBytes)
if !strings.Contains(upstream, "hello [reference:1]") {
t.Fatalf("expected initial capture in combined upstream, got %s", upstream)
}
if !strings.Contains(upstream, "continued") {
t.Fatalf("expected continuation capture in combined upstream, got %s", upstream)
}
if strings.Index(upstream, "hello [reference:1]") > strings.Index(upstream, "continued") {
t.Fatalf("expected initial capture before continuation, got %s", upstream)
}
metaBytes, err := os.ReadFile(filepath.Join(sampleDir, "meta.json"))
if err != nil {
t.Fatalf("read meta: %v", err)
}
var meta map[string]any
if err := json.Unmarshal(metaBytes, &meta); err != nil {
t.Fatalf("decode meta: %v", err)
}
capture, _ := meta["capture"].(map[string]any)
if capture == nil {
t.Fatalf("missing capture meta: %#v", meta)
}
if got := int(capture["response_bytes"].(float64)); got != len(upstreamBytes) {
t.Fatalf("expected combined response_bytes %d, got %#v", len(upstreamBytes), capture["response_bytes"])
}
rounds, _ := capture["rounds"].([]any)
if len(rounds) != 2 {
t.Fatalf("expected 2 capture rounds, got %d: %#v", len(rounds), capture)
}
r0, _ := rounds[0].(map[string]any)
r1, _ := rounds[1].(map[string]any)
if r0["label"] != "deepseek_completion" {
t.Fatalf("expected first round label deepseek_completion, got %v", r0["label"])
}
if r1["label"] != "deepseek_continue" {
t.Fatalf("expected second round label deepseek_continue, got %v", r1["label"])
}
}
func TestCaptureRawSampleReturnsErrorWhenNoNewCaptureRecorded(t *testing.T) {
root := t.TempDir()
t.Setenv("DS2API_RAW_STREAM_SAMPLE_ROOT", root)
devcapture.Global().Clear()
defer devcapture.Global().Clear()
recordCapturedResponse("preexisting", "https://chat.deepseek.com/api/v0/chat/completion", http.StatusOK, map[string]any{"model": "deepseek-v4-flash"}, "data: {\"v\":\"old\"}\n\n")
h := &Handler{OpenAI: stubOpenAIChatCallerWithoutCapture{}}
reqBody := `{
"sample_id":"My Sample 03",
"api_key":"local-key",
"model":"deepseek-v4-flash",
"message":"广州天气",
"stream":true
}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/admin/dev/raw-samples/capture", strings.NewReader(reqBody))
h.captureRawSample(rec, req)
if rec.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d body=%s", rec.Code, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "no upstream capture was recorded") {
t.Fatalf("expected no-capture error, got %s", rec.Body.String())
}
if _, err := os.Stat(filepath.Join(root, "my-sample-03")); !os.IsNotExist(err) {
t.Fatalf("expected no sample dir to be created, stat err=%v", err)
}
}
func TestCombineCaptureBodiesPreservesOrderAndSeparators(t *testing.T) {
entries := []devcapture.Entry{
{ResponseBody: "first"},
{ResponseBody: "second"},
}
got := combineCaptureBodies(entries)
if !bytes.Equal(got, []byte("first\nsecond")) {
t.Fatalf("unexpected combined body: %q", string(got))
}
}
func TestQueryRawSampleCapturesGroupsBySessionAndMatchesQuestion(t *testing.T) {
devcapture.Global().Clear()
defer devcapture.Global().Clear()
recordCapturedResponse(
"deepseek_completion",
"https://chat.deepseek.com/api/v0/chat/completion",
http.StatusOK,
map[string]any{
"chat_session_id": "session-query-1",
"prompt": "用户问题:广州天气怎么样?",
},
"data: {\"v\":\"先看天气\"}\n\n",
)
recordCapturedResponse(
"deepseek_continue",
"https://chat.deepseek.com/api/v0/chat/continue",
http.StatusOK,
map[string]any{
"chat_session_id": "session-query-1",
"message_id": 2,
},
"data: {\"v\":\"再补充一点\"}\n\n",
)
h := &Handler{}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/admin/dev/raw-samples/query?q=广州天气", nil)
h.queryRawSampleCaptures(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
var out map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
t.Fatalf("decode failed: %v", err)
}
items, _ := out["items"].([]any)
if len(items) != 1 {
t.Fatalf("expected 1 item, got %d body=%s", len(items), rec.Body.String())
}
item, _ := items[0].(map[string]any)
if item["chain_key"] != "session:session-query-1" {
t.Fatalf("unexpected chain key: %#v", item["chain_key"])
}
if int(item["round_count"].(float64)) != 2 {
t.Fatalf("expected 2 rounds, got %#v", item["round_count"])
}
reqPreview, _ := item["request_preview"].(string)
if !strings.Contains(reqPreview, "广州天气") {
t.Fatalf("expected request preview to contain query, got %q", reqPreview)
}
}
func TestBuildCaptureChainsPreservesCaptureOrderWhenTimestampsCollide(t *testing.T) {
snapshot := []devcapture.Entry{
{
ID: "cap_continue",
CreatedAt: 1712365200,
Label: "deepseek_continue",
RequestBody: `{"chat_session_id":"session-collision","message_id":2}`,
ResponseBody: "data: {\"v\":\"第二段\"}\n\n",
},
{
ID: "cap_completion",
CreatedAt: 1712365200,
Label: "deepseek_completion",
RequestBody: `{"chat_session_id":"session-collision","prompt":"题目"}`,
ResponseBody: "data: {\"v\":\"第一段\"}\n\n",
},
}
chains := buildCaptureChains(snapshot)
if len(chains) != 1 {
t.Fatalf("expected 1 chain, got %d", len(chains))
}
if len(chains[0].Entries) != 2 {
t.Fatalf("expected 2 entries, got %d", len(chains[0].Entries))
}
if chains[0].Entries[0].Label != "deepseek_completion" {
t.Fatalf("expected completion first, got %#v", chains[0].Entries)
}
if chains[0].Entries[1].Label != "deepseek_continue" {
t.Fatalf("expected continue second, got %#v", chains[0].Entries)
}
}
func TestSaveRawSampleFromCapturesPersistsSelectedChain(t *testing.T) {
root := t.TempDir()
t.Setenv("DS2API_RAW_STREAM_SAMPLE_ROOT", root)
devcapture.Global().Clear()
defer devcapture.Global().Clear()
recordCapturedResponse(
"deepseek_completion",
"https://chat.deepseek.com/api/v0/chat/completion",
http.StatusOK,
map[string]any{
"chat_session_id": "session-save-1",
"prompt": "请回答深圳天气",
},
"data: {\"v\":\"第一段\"}\n\n",
)
recordCapturedResponse(
"deepseek_continue",
"https://chat.deepseek.com/api/v0/chat/continue",
http.StatusOK,
map[string]any{
"chat_session_id": "session-save-1",
"message_id": 2,
},
"data: {\"v\":\"第二段\"}\n\n",
)
h := &Handler{}
rec := httptest.NewRecorder()
reqBody := `{"query":"深圳天气","sample_id":"saved-from-memory"}`
req := httptest.NewRequest(http.MethodPost, "/admin/dev/raw-samples/save", strings.NewReader(reqBody))
h.saveRawSampleFromCaptures(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
var out map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
t.Fatalf("decode failed: %v", err)
}
if out["sample_id"] != "saved-from-memory" {
t.Fatalf("unexpected sample id: %#v", out["sample_id"])
}
if int(out["round_count"].(float64)) != 2 {
t.Fatalf("expected round_count=2, got %#v", out["round_count"])
}
sampleDir := filepath.Join(root, "saved-from-memory")
upstreamBytes, err := os.ReadFile(filepath.Join(sampleDir, "upstream.stream.sse"))
if err != nil {
t.Fatalf("read upstream: %v", err)
}
upstream := string(upstreamBytes)
if !strings.Contains(upstream, "第一段") || !strings.Contains(upstream, "第二段") {
t.Fatalf("expected combined upstream, got %q", upstream)
}
metaBytes, err := os.ReadFile(filepath.Join(sampleDir, "meta.json"))
if err != nil {
t.Fatalf("read meta: %v", err)
}
var meta map[string]any
if err := json.Unmarshal(metaBytes, &meta); err != nil {
t.Fatalf("decode meta: %v", err)
}
reqMeta, _ := meta["request"].(map[string]any)
if fieldString(reqMeta, "chat_session_id") != "session-save-1" {
t.Fatalf("expected request to come from selected chain, got %#v", meta["request"])
}
}

View File

@@ -0,0 +1,9 @@
package rawsamples
import "github.com/go-chi/chi/v5"
func RegisterRoutes(r chi.Router, h *Handler) {
r.Post("/dev/raw-samples/capture", h.captureRawSample)
r.Get("/dev/raw-samples/query", h.queryRawSampleCaptures)
r.Post("/dev/raw-samples/save", h.saveRawSampleFromCaptures)
}