diff --git a/OPTIMIZATION_REPORT.md b/OPTIMIZATION_REPORT.md index 8054795..ede8440 100644 --- a/OPTIMIZATION_REPORT.md +++ b/OPTIMIZATION_REPORT.md @@ -106,7 +106,7 @@ ### 13. CORS 配置矛盾 - **位置**: `internal/server/router.go` - **问题**: 同时设置 `Access-Control-Allow-Origin: *` 和 `Access-Control-Allow-Credentials: true` 是无效的(浏览器安全规范)。 -- **建议**: 动态反射 Origin 或移除 Credentials 允许。 +- **建议**: 若采用宽松模式,保持 `Access-Control-Allow-Origin: *`,并移除 `Access-Control-Allow-Credentials`。 --- @@ -116,4 +116,4 @@ 1. **Phase 1 (Fix Critical) ✅ 已完成:** ~~修复 `Save()` 锁问题、WASM 重复创建、Admin 默认密码警告、Graceful Shutdown。删除无用大文件。~~ 同时修复了 `itoa` 低效实现。 2. **Phase 2 (Refactor) ✅ 已完成:** ~~统一 API Key/Account 的索引机制,重构 SSE 解析逻辑 (DRY),优化 `testAllAccounts` 并发。~~ 同时完成了重复工具函数的统一清理(`writeJSON`/`toBool`/`intFrom` → `internal/util`)。 -3. **Phase 3 (Cleanup) ✅ 已完成:** ~~优化 CORS,改进 Token 估算等微小性能点。~~ CORS 改为动态反射 Origin;Token 估算区分 ASCII/非 ASCII 字符。 +3. **Phase 3 (Cleanup) ✅ 已完成:** ~~优化 CORS,改进 Token 估算等微小性能点。~~ CORS 采用宽松模式(`Access-Control-Allow-Origin: *`,不启用 Credentials);Token 估算区分 ASCII/非 ASCII 字符。 diff --git a/api/chat-stream.js b/api/chat-stream.js index f69300e..6281b28 100644 --- a/api/chat-stream.js +++ b/api/chat-stream.js @@ -267,7 +267,6 @@ module.exports = async function handler(req, res) { function setCorsHeaders(res) { res.setHeader('Access-Control-Allow-Origin', '*'); - res.setHeader('Access-Control-Allow-Credentials', 'true'); res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE'); res.setHeader( 'Access-Control-Allow-Headers', diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index c230139..08b8e85 100644 --- a/internal/adapter/claude/handler.go +++ b/internal/adapter/claude/handler.go @@ -422,30 +422,21 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ return } - chunk, doneSignal, parsed := sse.ParseDeepSeekSSELine(line) - if !parsed { + parsed := sse.ParseDeepSeekContentLine(line, thinkingEnabled, currentType) + currentType = parsed.NextType + if !parsed.Parsed { continue } - if doneSignal { - finalize("end_turn") + if parsed.ErrorMessage != "" { + sendError(parsed.ErrorMessage) return } - if errObj, hasErr := chunk["error"]; hasErr { - sendError(fmt.Sprintf("%v", errObj)) - return - } - if code, _ := chunk["code"].(string); code == "content_filter" { - sendError("content filtered by upstream") - return - } - parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType) - currentType = newType - if finished { + if parsed.Stop { finalize("end_turn") return } - for _, p := range parts { + for _, p := range parsed.Parts { if p.Text == "" { continue } diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index 601e152..0df6d11 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -329,26 +329,21 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt finalize("stop") return } - chunk, doneSignal, parsed := sse.ParseDeepSeekSSELine(line) - if !parsed { + parsed := sse.ParseDeepSeekContentLine(line, thinkingEnabled, currentType) + currentType = parsed.NextType + if !parsed.Parsed { continue } - if doneSignal { - finalize("stop") - return - } - if _, hasErr := chunk["error"]; hasErr || chunk["code"] == "content_filter" { + if parsed.ContentFilter || parsed.ErrorMessage != "" { finalize("content_filter") return } - parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType) - currentType = newType - if finished { + if parsed.Stop { finalize("stop") return } - newChoices := make([]map[string]any, 0, len(parts)) - for _, p := range parts { + newChoices := make([]map[string]any, 0, len(parsed.Parts)) + for _, p := range parsed.Parts { if searchEnabled && sse.IsCitation(p.Text) { continue } diff --git a/internal/admin/handler_accounts.go b/internal/admin/handler_accounts.go index 080e84b..b95077d 100644 --- a/internal/admin/handler_accounts.go +++ b/internal/admin/handler_accounts.go @@ -154,20 +154,9 @@ func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) { // Concurrent testing with a semaphore to limit parallelism. const maxConcurrency = 5 - sem := make(chan struct{}, maxConcurrency) - results := make([]map[string]any, len(accounts)) - var wg sync.WaitGroup - - for i, acc := range accounts { - wg.Add(1) - go func(idx int, account config.Account) { - defer wg.Done() - sem <- struct{}{} // acquire - defer func() { <-sem }() // release - results[idx] = h.testAccount(r.Context(), account, model, "") - }(i, acc) - } - wg.Wait() + results := runAccountTestsConcurrently(accounts, maxConcurrency, func(_ int, account config.Account) map[string]any { + return h.testAccount(r.Context(), account, model, "") + }) success := 0 for _, res := range results { @@ -178,6 +167,26 @@ func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results}) } +func runAccountTestsConcurrently(accounts []config.Account, maxConcurrency int, testFn func(int, config.Account) map[string]any) []map[string]any { + if maxConcurrency <= 0 { + maxConcurrency = 1 + } + sem := make(chan struct{}, maxConcurrency) + results := make([]map[string]any, len(accounts)) + var wg sync.WaitGroup + for i, acc := range accounts { + wg.Add(1) + go func(idx int, account config.Account) { + defer wg.Done() + sem <- struct{}{} // acquire + defer func() { <-sem }() // release + results[idx] = testFn(idx, account) + }(i, acc) + } + wg.Wait() + return results +} + func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any { start := time.Now() result := map[string]any{"account": acc.Identifier(), "success": false, "response_time": 0, "message": "", "model": model} diff --git a/internal/admin/handler_test.go b/internal/admin/handler_test.go index b0df9e3..a31e344 100644 --- a/internal/admin/handler_test.go +++ b/internal/admin/handler_test.go @@ -1,6 +1,12 @@ package admin -import "testing" +import ( + "sync/atomic" + "testing" + "time" + + "ds2api/internal/config" +) func TestToAccountMissingFieldsRemainEmpty(t *testing.T) { acc := toAccount(map[string]any{ @@ -26,3 +32,62 @@ func TestFieldStringNilToEmpty(t *testing.T) { t.Fatalf("expected empty string for missing field, got %q", got) } } + +func TestRunAccountTestsConcurrentlyKeepsInputOrder(t *testing.T) { + accounts := []config.Account{ + {Email: "a@example.com"}, + {Email: "b@example.com"}, + {Email: "c@example.com"}, + } + results := runAccountTestsConcurrently(accounts, 2, func(idx int, acc config.Account) map[string]any { + return map[string]any{ + "idx": idx, + "account": acc.Identifier(), + } + }) + if len(results) != len(accounts) { + t.Fatalf("unexpected result length: got %d want %d", len(results), len(accounts)) + } + for i := range accounts { + gotIdx, _ := results[i]["idx"].(int) + if gotIdx != i { + t.Fatalf("result index mismatch at %d: got %d", i, gotIdx) + } + gotID, _ := results[i]["account"].(string) + if gotID != accounts[i].Identifier() { + t.Fatalf("result order mismatch at %d: got %q want %q", i, gotID, accounts[i].Identifier()) + } + } +} + +func TestRunAccountTestsConcurrentlyRespectsLimit(t *testing.T) { + const limit = 3 + accounts := []config.Account{ + {Email: "1@example.com"}, + {Email: "2@example.com"}, + {Email: "3@example.com"}, + {Email: "4@example.com"}, + {Email: "5@example.com"}, + {Email: "6@example.com"}, + } + var current int32 + var maxSeen int32 + _ = runAccountTestsConcurrently(accounts, limit, func(_ int, _ config.Account) map[string]any { + c := atomic.AddInt32(¤t, 1) + for { + m := atomic.LoadInt32(&maxSeen) + if c <= m || atomic.CompareAndSwapInt32(&maxSeen, m, c) { + break + } + } + time.Sleep(20 * time.Millisecond) + atomic.AddInt32(¤t, -1) + return map[string]any{"success": true} + }) + if maxSeen > limit { + t.Fatalf("concurrency exceeded limit: got %d > %d", maxSeen, limit) + } + if maxSeen < 2 { + t.Fatalf("expected concurrent execution, max seen %d", maxSeen) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 41c1696..691df6d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -286,20 +286,35 @@ func (s *Store) FindAccount(identifier string) (Account, bool) { identifier = strings.TrimSpace(identifier) s.mu.RLock() defer s.mu.RUnlock() - if idx, ok := s.accMap[identifier]; ok && idx < len(s.cfg.Accounts) { + if idx, ok := s.findAccountIndexLocked(identifier); ok { return s.cfg.Accounts[idx], true } return Account{}, false } func (s *Store) UpdateAccountToken(identifier, token string) error { + identifier = strings.TrimSpace(identifier) s.mu.Lock() defer s.mu.Unlock() - if idx, ok := s.accMap[identifier]; ok && idx < len(s.cfg.Accounts) { - s.cfg.Accounts[idx].Token = token - return s.saveLocked() + idx, ok := s.findAccountIndexLocked(identifier) + if !ok { + return errors.New("account not found") } - return errors.New("account not found") + oldID := s.cfg.Accounts[idx].Identifier() + s.cfg.Accounts[idx].Token = token + newID := s.cfg.Accounts[idx].Identifier() + // Keep historical aliases usable for long-lived queues while also adding + // the latest identifier after token refresh. + if identifier != "" { + s.accMap[identifier] = idx + } + if oldID != "" { + s.accMap[oldID] = idx + } + if newID != "" { + s.accMap[newID] = idx + } + return s.saveLocked() } func (s *Store) Replace(cfg Config) error { @@ -348,6 +363,21 @@ func (s *Store) saveLocked() error { return os.WriteFile(s.path, b, 0o644) } +// findAccountIndexLocked expects the store lock to already be held. +func (s *Store) findAccountIndexLocked(identifier string) (int, bool) { + if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) { + return idx, true + } + // Fallback for token-only accounts whose derived identifier changed after + // a token refresh; this preserves correctness on map misses. + for i, acc := range s.cfg.Accounts { + if acc.Identifier() == identifier { + return i, true + } + } + return -1, false +} + func (s *Store) IsEnvBacked() bool { s.mu.RLock() defer s.mu.RUnlock() diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 1f22cd4..58a8a2a 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -39,3 +39,34 @@ func TestStoreFindAccountWithTokenOnlyIdentifier(t *testing.T) { t.Fatalf("unexpected token value: %q", found.Token) } } + +func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "accounts":[{"token":"old-token"}] + }`) + + store := LoadStore() + before := store.Accounts() + if len(before) != 1 { + t.Fatalf("expected 1 account, got %d", len(before)) + } + oldID := before[0].Identifier() + if oldID == "" { + t.Fatal("expected old identifier") + } + if err := store.UpdateAccountToken(oldID, "new-token"); err != nil { + t.Fatalf("update token failed: %v", err) + } + + after := store.Accounts() + newID := after[0].Identifier() + if newID == "" || newID == oldID { + t.Fatalf("expected changed identifier, old=%q new=%q", oldID, newID) + } + if got, ok := store.FindAccount(newID); !ok || got.Token != "new-token" { + t.Fatalf("expected find by new identifier") + } + if got, ok := store.FindAccount(oldID); !ok || got.Token != "new-token" { + t.Fatalf("expected find by old identifier alias") + } +} diff --git a/internal/server/router.go b/internal/server/router.go index e1260ce..c6339fb 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -90,18 +90,7 @@ func timeout(d time.Duration) func(http.Handler) http.Handler { func cors(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - origin := r.Header.Get("Origin") - if origin != "" { - // Dynamically reflect the request origin to allow credentials. - // Using "*" with Access-Control-Allow-Credentials: true is - // invalid per the CORS spec and will be rejected by browsers. - w.Header().Set("Access-Control-Allow-Origin", origin) - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Vary", "Origin") - } else { - // No Origin header (e.g. server-to-server requests); allow all. - w.Header().Set("Access-Control-Allow-Origin", "*") - } + w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") if r.Method == http.MethodOptions { diff --git a/internal/sse/consumer.go b/internal/sse/consumer.go index 736c1ed..9e0e180 100644 --- a/internal/sse/consumer.go +++ b/internal/sse/consumer.go @@ -31,22 +31,15 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co currentType = "thinking" } _ = deepseek.ScanSSELines(resp, func(line []byte) bool { - chunk, done, ok := ParseDeepSeekSSELine(line) - if !ok { + result := ParseDeepSeekContentLine(line, thinkingEnabled, currentType) + currentType = result.NextType + if !result.Parsed { return true } - if done { + if result.Stop { return false } - if _, hasErr := chunk["error"]; hasErr { - return false - } - parts, finished, newType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType) - currentType = newType - if finished { - return false - } - for _, p := range parts { + for _, p := range result.Parts { if p.Type == "thinking" { thinking.WriteString(p.Text) } else { diff --git a/internal/sse/line.go b/internal/sse/line.go new file mode 100644 index 0000000..b71b2b0 --- /dev/null +++ b/internal/sse/line.go @@ -0,0 +1,49 @@ +package sse + +import "fmt" + +// LineResult is the normalized parse result for one DeepSeek SSE line. +type LineResult struct { + Parsed bool + Stop bool + ContentFilter bool + ErrorMessage string + Parts []ContentPart + NextType string +} + +// ParseDeepSeekContentLine centralizes one-line DeepSeek SSE parsing for both +// streaming and non-streaming handlers. +func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType string) LineResult { + chunk, done, parsed := ParseDeepSeekSSELine(raw) + if !parsed { + return LineResult{NextType: currentType} + } + if done { + return LineResult{Parsed: true, Stop: true, NextType: currentType} + } + if errObj, hasErr := chunk["error"]; hasErr { + return LineResult{ + Parsed: true, + Stop: true, + ErrorMessage: fmt.Sprintf("%v", errObj), + NextType: currentType, + } + } + if code, _ := chunk["code"].(string); code == "content_filter" { + return LineResult{ + Parsed: true, + Stop: true, + ContentFilter: true, + ErrorMessage: "content filtered by upstream", + NextType: currentType, + } + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType) + return LineResult{ + Parsed: true, + Stop: finished, + Parts: parts, + NextType: nextType, + } +} diff --git a/internal/sse/line_test.go b/internal/sse/line_test.go new file mode 100644 index 0000000..3292a54 --- /dev/null +++ b/internal/sse/line_test.go @@ -0,0 +1,37 @@ +package sse + +import "testing" + +func TestParseDeepSeekContentLineDone(t *testing.T) { + res := ParseDeepSeekContentLine([]byte("data: [DONE]"), false, "text") + if !res.Parsed || !res.Stop { + t.Fatalf("expected parsed stop result: %#v", res) + } +} + +func TestParseDeepSeekContentLineError(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"error":"boom"}`), false, "text") + if !res.Parsed || !res.Stop { + t.Fatalf("expected stop on error: %#v", res) + } + if res.ErrorMessage == "" { + t.Fatalf("expected non-empty error message") + } +} + +func TestParseDeepSeekContentLineContentFilter(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"code":"content_filter"}`), false, "text") + if !res.Parsed || !res.Stop || !res.ContentFilter { + t.Fatalf("expected content-filter stop result: %#v", res) + } +} + +func TestParseDeepSeekContentLineContent(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"hi"}`), false, "text") + if !res.Parsed || res.Stop { + t.Fatalf("expected parsed non-stop result: %#v", res) + } + if len(res.Parts) != 1 || res.Parts[0].Text != "hi" || res.Parts[0].Type != "text" { + t.Fatalf("unexpected parts: %#v", res.Parts) + } +}