diff --git a/OPTIMIZATION_REPORT.md b/OPTIMIZATION_REPORT.md index cb307f2..1fda1a6 100644 --- a/OPTIMIZATION_REPORT.md +++ b/OPTIMIZATION_REPORT.md @@ -115,5 +115,5 @@ 为了稳健地优化项目,建议按照以下顺序执行: 1. **Phase 1 (Fix Critical) ✅ 已完成:** ~~修复 `Save()` 锁问题、WASM 重复创建、Admin 默认密码警告、Graceful Shutdown。删除无用大文件。~~ 同时修复了 `itoa` 低效实现。 -2. **Phase 2 (Refactor):** 统一 API Key/Account 的索引机制,重构 SSE 解析逻辑 (DRY),优化 `testAllAccounts` 并发。 -3. **Phase 3 (Cleanup):** 清理重复工具函数,优化 CORS,改进 Token 估算等微小性能点。 +2. **Phase 2 (Refactor) ✅ 已完成:** ~~统一 API Key/Account 的索引机制,重构 SSE 解析逻辑 (DRY),优化 `testAllAccounts` 并发。~~ 同时完成了重复工具函数的统一清理(`writeJSON`/`toBool`/`intFrom` → `internal/util`)。 +3. **Phase 3 (Cleanup):** 优化 CORS,改进 Token 估算等微小性能点。 diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index 19d32ae..c230139 100644 --- a/internal/adapter/claude/handler.go +++ b/internal/adapter/claude/handler.go @@ -18,6 +18,9 @@ import ( "ds2api/internal/util" ) +// writeJSON is a package-internal alias to avoid mass-renaming all call-sites. +var writeJSON = util.WriteJSON + type Handler struct { Store *config.Store Auth *auth.Resolver @@ -113,11 +116,13 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { } toolNames := extractClaudeToolNames(toolsRequested) - if toBool(req["stream"]) { + if util.ToBool(req["stream"]) { h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames) return } - fullText, fullThinking := collectDeepSeek(resp, thinkingEnabled) + result := sse.CollectStream(resp, thinkingEnabled, true) + fullText := result.Text + fullThinking := result.Thinking detected := util.ParseToolCalls(fullText, toolNames) content := make([]map[string]any, 0, 4) if fullThinking != "" { @@ -198,41 +203,6 @@ func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens}) } -func collectDeepSeek(resp *http.Response, thinkingEnabled bool) (string, string) { - defer resp.Body.Close() - text := strings.Builder{} - thinking := strings.Builder{} - currentType := "text" - if thinkingEnabled { - currentType = "thinking" - } - scanner := bufio.NewScanner(resp.Body) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - chunk, done, ok := sse.ParseDeepSeekSSELine(scanner.Bytes()) - if !ok { - continue - } - if done { - break - } - parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType) - currentType = newType - if finished { - break - } - for _, p := range parts { - if p.Type == "thinking" { - thinking.WriteString(p.Text) - } else { - text.WriteString(p.Text) - } - } - } - return text.String(), thinking.String() -} - func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { @@ -657,14 +627,3 @@ func cloneMap(in map[string]any) map[string]any { } return out } - -func toBool(v any) bool { - b, _ := v.(bool) - return b -} - -func writeJSON(w http.ResponseWriter, status int, payload any) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(payload) -} diff --git a/internal/adapter/claude/handler_stream_test.go b/internal/adapter/claude/handler_stream_test.go index 56f7ea9..74086ae 100644 --- a/internal/adapter/claude/handler_stream_test.go +++ b/internal/adapter/claude/handler_stream_test.go @@ -1,6 +1,7 @@ package claude import ( + "ds2api/internal/sse" "encoding/json" "io" "net/http" @@ -241,12 +242,12 @@ func TestCollectDeepSeekRegression(t *testing.T) { `data: {"p":"response/content","v":"答"}`, `data: [DONE]`, ) - text, thinking := collectDeepSeek(resp, true) - if thinking != "想" { - t.Fatalf("unexpected thinking: %q", thinking) + result := sse.CollectStream(resp, true, true) + if result.Thinking != "想" { + t.Fatalf("unexpected thinking: %q", result.Thinking) } - if text != "答" { - t.Fatalf("unexpected text: %q", text) + if result.Text != "答" { + t.Fatalf("unexpected text: %q", result.Text) } } diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index d78bff3..601e152 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -20,6 +20,10 @@ import ( "ds2api/internal/util" ) +// writeJSON is a package-internal alias kept to avoid mass-renaming across +// every call-site in this file. It delegates to the shared util version. +var writeJSON = util.WriteJSON + type Handler struct { Store *config.Store Auth *auth.Resolver @@ -117,7 +121,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") return } - if toBool(req["stream"]) { + if util.ToBool(req["stream"]) { h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) return } @@ -125,50 +129,17 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { } func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { - defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) writeOpenAIError(w, resp.StatusCode, string(body)) return } - thinking := strings.Builder{} - text := strings.Builder{} - currentType := "text" - if thinkingEnabled { - currentType = "thinking" - } _ = ctx - _ = deepseek.ScanSSELines(resp, func(line []byte) bool { - chunk, done, ok := sse.ParseDeepSeekSSELine(line) - if !ok { - return true - } - if done { - return false - } - if _, hasErr := chunk["error"]; hasErr { - return false - } - parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType) - currentType = newType - if finished { - return false - } - for _, p := range parts { - if searchEnabled && sse.IsCitation(p.Text) { - continue - } - if p.Type == "thinking" { - thinking.WriteString(p.Text) - } else { - text.WriteString(p.Text) - } - } - return true - }) + result := sse.CollectStream(resp, thinkingEnabled, true) - finalThinking := thinking.String() - finalText := text.String() + finalThinking := result.Thinking + finalText := result.Text detected := util.ParseToolCalls(finalText, toolNames) finishReason := "stop" messageObj := map[string]any{"role": "assistant", "content": finalText} @@ -507,19 +478,6 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, return messages, names } -func toBool(v any) bool { - if b, ok := v.(bool); ok { - return b - } - return false -} - -func writeJSON(w http.ResponseWriter, status int, payload any) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(payload) -} - func writeOpenAIError(w http.ResponseWriter, status int, message string) { writeJSON(w, status, map[string]any{ "error": map[string]any{ diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go index 3e75f47..7fceb93 100644 --- a/internal/adapter/openai/vercel_stream.go +++ b/internal/adapter/openai/vercel_stream.go @@ -52,7 +52,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque writeOpenAIError(w, http.StatusBadRequest, "invalid json") return } - if !toBool(req["stream"]) { + if !util.ToBool(req["stream"]) { writeOpenAIError(w, http.StatusBadRequest, "stream must be true") return } diff --git a/internal/admin/handler_accounts.go b/internal/admin/handler_accounts.go index 3dc43d1..080e84b 100644 --- a/internal/admin/handler_accounts.go +++ b/internal/admin/handler_accounts.go @@ -1,7 +1,6 @@ package admin import ( - "bufio" "bytes" "context" "encoding/json" @@ -9,6 +8,7 @@ import ( "io" "net/http" "strings" + "sync" "time" "github.com/go-chi/chi/v5" @@ -151,15 +151,29 @@ func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, map[string]any{"total": 0, "success": 0, "failed": 0, "results": []any{}}) return } - results := make([]map[string]any, 0, len(accounts)) + + // Concurrent testing with a semaphore to limit parallelism. + const maxConcurrency = 5 + sem := make(chan struct{}, maxConcurrency) + results := make([]map[string]any, len(accounts)) + var wg sync.WaitGroup + + for i, acc := range accounts { + wg.Add(1) + go func(idx int, account config.Account) { + defer wg.Done() + sem <- struct{}{} // acquire + defer func() { <-sem }() // release + results[idx] = h.testAccount(r.Context(), account, model, "") + }(i, acc) + } + wg.Wait() + success := 0 - for _, acc := range accounts { - res := h.testAccount(r.Context(), acc, model, "") + for _, res := range results { if ok, _ := res["success"].(bool); ok { success++ } - results = append(results, res) - time.Sleep(time.Second) } writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results}) } @@ -204,6 +218,7 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me if !ok { thinking, search = false, false } + _ = search pow, err := h.DS.GetPow(ctx, authCtx, 1) if err != nil { result["message"] = "获取 PoW 失败: " + err.Error() @@ -215,50 +230,21 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me result["message"] = "请求失败: " + err.Error() return result } - defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() result["message"] = fmt.Sprintf("请求失败: HTTP %d", resp.StatusCode) return result } - text := strings.Builder{} - think := strings.Builder{} - currentType := "text" - if thinking { - currentType = "thinking" - } - scanner := bufio.NewScanner(resp.Body) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - chunk, done, parsed := sse.ParseDeepSeekSSELine(scanner.Bytes()) - if !parsed { - continue - } - if done { - break - } - parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinking, currentType) - currentType = newType - if finished { - break - } - for _, p := range parts { - if p.Type == "thinking" { - think.WriteString(p.Text) - } else { - text.WriteString(p.Text) - } - } - } + collected := sse.CollectStream(resp, thinking, true) result["success"] = true result["response_time"] = int(time.Since(start).Milliseconds()) - if text.Len() > 0 { - result["message"] = text.String() + if collected.Text != "" { + result["message"] = collected.Text } else { result["message"] = "(无回复内容)" } - if think.Len() > 0 { - result["thinking"] = think.String() + if collected.Thinking != "" { + result["thinking"] = collected.Thinking } return result } diff --git a/internal/admin/helpers.go b/internal/admin/helpers.go index a1d21b8..fa75b59 100644 --- a/internal/admin/helpers.go +++ b/internal/admin/helpers.go @@ -1,15 +1,19 @@ package admin import ( - "encoding/json" "fmt" "net/http" "strconv" "strings" "ds2api/internal/config" + "ds2api/internal/util" ) +// writeJSON and intFrom are package-internal aliases for the shared util versions. +var writeJSON = util.WriteJSON +var intFrom = util.IntFrom + func reverseAccounts(a []config.Account) { for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 { a[i], a[j] = a[j], a[i] @@ -28,19 +32,6 @@ func intFromQuery(r *http.Request, key string, d int) int { return n } -func intFrom(v any) int { - switch n := v.(type) { - case float64: - return int(n) - case int: - return n - case int64: - return int(n) - default: - return 0 - } -} - func nilIfEmpty(s string) any { if s == "" { return nil @@ -90,9 +81,3 @@ func statusOr(v int, d int) int { } return v } - -func writeJSON(w http.ResponseWriter, status int, payload any) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(payload) -} diff --git a/internal/config/config.go b/internal/config/config.go index 49b5857..41c1696 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -154,6 +154,8 @@ type Store struct { cfg Config path string fromEnv bool + keyMap map[string]struct{} // O(1) API key lookup index + accMap map[string]int // O(1) account lookup: identifier -> slice index } func BaseDir() string { @@ -199,7 +201,24 @@ func LoadStore() *Store { if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 { Logger.Warn("[config] empty config loaded") } - return &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv} + s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv} + s.rebuildIndexes() + return s +} + +// rebuildIndexes must be called with the lock already held (or during init). +func (s *Store) rebuildIndexes() { + s.keyMap = make(map[string]struct{}, len(s.cfg.Keys)) + for _, k := range s.cfg.Keys { + s.keyMap[k] = struct{}{} + } + s.accMap = make(map[string]int, len(s.cfg.Accounts)) + for i, acc := range s.cfg.Accounts { + id := acc.Identifier() + if id != "" { + s.accMap[id] = i + } + } } func loadConfig() (Config, bool, error) { @@ -247,12 +266,8 @@ func (s *Store) Snapshot() Config { func (s *Store) HasAPIKey(k string) bool { s.mu.RLock() defer s.mu.RUnlock() - for _, key := range s.cfg.Keys { - if key == k { - return true - } - } - return false + _, ok := s.keyMap[k] + return ok } func (s *Store) Keys() []string { @@ -271,10 +286,8 @@ func (s *Store) FindAccount(identifier string) (Account, bool) { identifier = strings.TrimSpace(identifier) s.mu.RLock() defer s.mu.RUnlock() - for _, acc := range s.cfg.Accounts { - if acc.Identifier() == identifier { - return acc, true - } + if idx, ok := s.accMap[identifier]; ok && idx < len(s.cfg.Accounts) { + return s.cfg.Accounts[idx], true } return Account{}, false } @@ -282,11 +295,9 @@ func (s *Store) FindAccount(identifier string) (Account, bool) { func (s *Store) UpdateAccountToken(identifier, token string) error { s.mu.Lock() defer s.mu.Unlock() - for i := range s.cfg.Accounts { - if s.cfg.Accounts[i].Identifier() == identifier { - s.cfg.Accounts[i].Token = token - return s.saveLocked() - } + if idx, ok := s.accMap[identifier]; ok && idx < len(s.cfg.Accounts) { + s.cfg.Accounts[idx].Token = token + return s.saveLocked() } return errors.New("account not found") } @@ -295,6 +306,7 @@ func (s *Store) Replace(cfg Config) error { s.mu.Lock() defer s.mu.Unlock() s.cfg = cfg.Clone() + s.rebuildIndexes() return s.saveLocked() } @@ -306,6 +318,7 @@ func (s *Store) Update(mutator func(*Config) error) error { return err } s.cfg = cfg + s.rebuildIndexes() return s.saveLocked() } diff --git a/internal/deepseek/client.go b/internal/deepseek/client.go index be596df..c42443d 100644 --- a/internal/deepseek/client.go +++ b/internal/deepseek/client.go @@ -16,10 +16,14 @@ import ( "ds2api/internal/auth" "ds2api/internal/config" trans "ds2api/internal/deepseek/transport" + "ds2api/internal/util" "github.com/andybalholm/brotli" ) +// intFrom is a package-internal alias for the shared util version. +var intFrom = util.IntFrom + type Client struct { Store *config.Store Auth *auth.Resolver @@ -288,19 +292,6 @@ func isTokenInvalid(status int, code int, msg string) bool { return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized") } -func intFrom(v any) int { - switch n := v.(type) { - case float64: - return int(n) - case int: - return n - case int64: - return int(n) - default: - return 0 - } -} - func readResponseBody(resp *http.Response) ([]byte, error) { encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding"))) var reader io.Reader = resp.Body diff --git a/internal/sse/consumer.go b/internal/sse/consumer.go new file mode 100644 index 0000000..736c1ed --- /dev/null +++ b/internal/sse/consumer.go @@ -0,0 +1,59 @@ +package sse + +import ( + "net/http" + "strings" + + "ds2api/internal/deepseek" +) + +// CollectResult holds the aggregated text and thinking content from a +// DeepSeek SSE stream, consumed to completion (non-streaming use case). +type CollectResult struct { + Text string + Thinking string +} + +// CollectStream fully consumes a DeepSeek SSE response and separates +// thinking content from text content. This replaces the duplicated +// stream-collection logic in openai.handleNonStream, claude.collectDeepSeek, +// and admin.testAccount. +// +// The caller is responsible for closing resp.Body unless closeBody is true. +func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) CollectResult { + if closeBody { + defer resp.Body.Close() + } + text := strings.Builder{} + thinking := strings.Builder{} + currentType := "text" + if thinkingEnabled { + currentType = "thinking" + } + _ = deepseek.ScanSSELines(resp, func(line []byte) bool { + chunk, done, ok := ParseDeepSeekSSELine(line) + if !ok { + return true + } + if done { + return false + } + if _, hasErr := chunk["error"]; hasErr { + return false + } + parts, finished, newType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType) + currentType = newType + if finished { + return false + } + for _, p := range parts { + if p.Type == "thinking" { + thinking.WriteString(p.Text) + } else { + text.WriteString(p.Text) + } + } + return true + }) + return CollectResult{Text: text.String(), Thinking: thinking.String()} +} diff --git a/internal/util/helpers.go b/internal/util/helpers.go new file mode 100644 index 0000000..15e6de7 --- /dev/null +++ b/internal/util/helpers.go @@ -0,0 +1,37 @@ +package util + +import ( + "encoding/json" + "net/http" +) + +// WriteJSON writes a JSON response with the given status code. +// This is a shared helper to avoid duplicate writeJSON functions +// in openai, claude, and admin packages. +func WriteJSON(w http.ResponseWriter, status int, payload any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(payload) +} + +// ToBool loosely converts an interface value to bool. +func ToBool(v any) bool { + if b, ok := v.(bool); ok { + return b + } + return false +} + +// IntFrom converts a JSON-decoded numeric value (float64, int, int64) to int. +func IntFrom(v any) int { + switch n := v.(type) { + case float64: + return int(n) + case int: + return n + case int64: + return int(n) + default: + return 0 + } +}