feat: centralize utility functions, abstract SSE stream collection, and add concurrency to admin account testing.

This commit is contained in:
CJACK
2026-02-17 03:31:19 +08:00
parent 4251438ff5
commit 534fd1d14b
11 changed files with 186 additions and 197 deletions

View File

@@ -115,5 +115,5 @@
为了稳健地优化项目,建议按照以下顺序执行:
1. **Phase 1 (Fix Critical) ✅ 已完成:** ~~修复 `Save()` 锁问题、WASM 重复创建、Admin 默认密码警告、Graceful Shutdown。删除无用大文件。~~ 同时修复了 `itoa` 低效实现。
2. **Phase 2 (Refactor):** 统一 API Key/Account 的索引机制,重构 SSE 解析逻辑 (DRY),优化 `testAllAccounts` 并发。
3. **Phase 3 (Cleanup):** 清理重复工具函数,优化 CORS改进 Token 估算等微小性能点。
2. **Phase 2 (Refactor) ✅ 已完成:** ~~统一 API Key/Account 的索引机制,重构 SSE 解析逻辑 (DRY),优化 `testAllAccounts` 并发。~~ 同时完成了重复工具函数的统一清理(`writeJSON`/`toBool`/`intFrom` → `internal/util`)。
3. **Phase 3 (Cleanup):** 优化 CORS改进 Token 估算等微小性能点。

View File

@@ -18,6 +18,9 @@ import (
"ds2api/internal/util"
)
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
var writeJSON = util.WriteJSON
type Handler struct {
Store *config.Store
Auth *auth.Resolver
@@ -113,11 +116,13 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
}
toolNames := extractClaudeToolNames(toolsRequested)
if toBool(req["stream"]) {
if util.ToBool(req["stream"]) {
h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames)
return
}
fullText, fullThinking := collectDeepSeek(resp, thinkingEnabled)
result := sse.CollectStream(resp, thinkingEnabled, true)
fullText := result.Text
fullThinking := result.Thinking
detected := util.ParseToolCalls(fullText, toolNames)
content := make([]map[string]any, 0, 4)
if fullThinking != "" {
@@ -198,41 +203,6 @@ func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
}
func collectDeepSeek(resp *http.Response, thinkingEnabled bool) (string, string) {
defer resp.Body.Close()
text := strings.Builder{}
thinking := strings.Builder{}
currentType := "text"
if thinkingEnabled {
currentType = "thinking"
}
scanner := bufio.NewScanner(resp.Body)
buf := make([]byte, 0, 64*1024)
scanner.Buffer(buf, 2*1024*1024)
for scanner.Scan() {
chunk, done, ok := sse.ParseDeepSeekSSELine(scanner.Bytes())
if !ok {
continue
}
if done {
break
}
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
currentType = newType
if finished {
break
}
for _, p := range parts {
if p.Type == "thinking" {
thinking.WriteString(p.Text)
} else {
text.WriteString(p.Text)
}
}
}
return text.String(), thinking.String()
}
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
@@ -657,14 +627,3 @@ func cloneMap(in map[string]any) map[string]any {
}
return out
}
func toBool(v any) bool {
b, _ := v.(bool)
return b
}
func writeJSON(w http.ResponseWriter, status int, payload any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(payload)
}

View File

@@ -1,6 +1,7 @@
package claude
import (
"ds2api/internal/sse"
"encoding/json"
"io"
"net/http"
@@ -241,12 +242,12 @@ func TestCollectDeepSeekRegression(t *testing.T) {
`data: {"p":"response/content","v":"答"}`,
`data: [DONE]`,
)
text, thinking := collectDeepSeek(resp, true)
if thinking != "想" {
t.Fatalf("unexpected thinking: %q", thinking)
result := sse.CollectStream(resp, true, true)
if result.Thinking != "想" {
t.Fatalf("unexpected thinking: %q", result.Thinking)
}
if text != "答" {
t.Fatalf("unexpected text: %q", text)
if result.Text != "答" {
t.Fatalf("unexpected text: %q", result.Text)
}
}

View File

@@ -20,6 +20,10 @@ import (
"ds2api/internal/util"
)
// writeJSON is a package-internal alias kept to avoid mass-renaming across
// every call-site in this file. It delegates to the shared util version.
var writeJSON = util.WriteJSON
type Handler struct {
Store *config.Store
Auth *auth.Resolver
@@ -117,7 +121,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
return
}
if toBool(req["stream"]) {
if util.ToBool(req["stream"]) {
h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
return
}
@@ -125,50 +129,17 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
}
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
writeOpenAIError(w, resp.StatusCode, string(body))
return
}
thinking := strings.Builder{}
text := strings.Builder{}
currentType := "text"
if thinkingEnabled {
currentType = "thinking"
}
_ = ctx
_ = deepseek.ScanSSELines(resp, func(line []byte) bool {
chunk, done, ok := sse.ParseDeepSeekSSELine(line)
if !ok {
return true
}
if done {
return false
}
if _, hasErr := chunk["error"]; hasErr {
return false
}
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
currentType = newType
if finished {
return false
}
for _, p := range parts {
if searchEnabled && sse.IsCitation(p.Text) {
continue
}
if p.Type == "thinking" {
thinking.WriteString(p.Text)
} else {
text.WriteString(p.Text)
}
}
return true
})
result := sse.CollectStream(resp, thinkingEnabled, true)
finalThinking := thinking.String()
finalText := text.String()
finalThinking := result.Thinking
finalText := result.Text
detected := util.ParseToolCalls(finalText, toolNames)
finishReason := "stop"
messageObj := map[string]any{"role": "assistant", "content": finalText}
@@ -507,19 +478,6 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
return messages, names
}
func toBool(v any) bool {
if b, ok := v.(bool); ok {
return b
}
return false
}
func writeJSON(w http.ResponseWriter, status int, payload any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(payload)
}
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, map[string]any{
"error": map[string]any{

View File

@@ -52,7 +52,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
if !toBool(req["stream"]) {
if !util.ToBool(req["stream"]) {
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
return
}

View File

@@ -1,7 +1,6 @@
package admin
import (
"bufio"
"bytes"
"context"
"encoding/json"
@@ -9,6 +8,7 @@ import (
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
@@ -151,15 +151,29 @@ func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"total": 0, "success": 0, "failed": 0, "results": []any{}})
return
}
results := make([]map[string]any, 0, len(accounts))
// Concurrent testing with a semaphore to limit parallelism.
const maxConcurrency = 5
sem := make(chan struct{}, maxConcurrency)
results := make([]map[string]any, len(accounts))
var wg sync.WaitGroup
for i, acc := range accounts {
wg.Add(1)
go func(idx int, account config.Account) {
defer wg.Done()
sem <- struct{}{} // acquire
defer func() { <-sem }() // release
results[idx] = h.testAccount(r.Context(), account, model, "")
}(i, acc)
}
wg.Wait()
success := 0
for _, acc := range accounts {
res := h.testAccount(r.Context(), acc, model, "")
for _, res := range results {
if ok, _ := res["success"].(bool); ok {
success++
}
results = append(results, res)
time.Sleep(time.Second)
}
writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results})
}
@@ -204,6 +218,7 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me
if !ok {
thinking, search = false, false
}
_ = search
pow, err := h.DS.GetPow(ctx, authCtx, 1)
if err != nil {
result["message"] = "获取 PoW 失败: " + err.Error()
@@ -215,50 +230,21 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me
result["message"] = "请求失败: " + err.Error()
return result
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
result["message"] = fmt.Sprintf("请求失败: HTTP %d", resp.StatusCode)
return result
}
text := strings.Builder{}
think := strings.Builder{}
currentType := "text"
if thinking {
currentType = "thinking"
}
scanner := bufio.NewScanner(resp.Body)
buf := make([]byte, 0, 64*1024)
scanner.Buffer(buf, 2*1024*1024)
for scanner.Scan() {
chunk, done, parsed := sse.ParseDeepSeekSSELine(scanner.Bytes())
if !parsed {
continue
}
if done {
break
}
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinking, currentType)
currentType = newType
if finished {
break
}
for _, p := range parts {
if p.Type == "thinking" {
think.WriteString(p.Text)
} else {
text.WriteString(p.Text)
}
}
}
collected := sse.CollectStream(resp, thinking, true)
result["success"] = true
result["response_time"] = int(time.Since(start).Milliseconds())
if text.Len() > 0 {
result["message"] = text.String()
if collected.Text != "" {
result["message"] = collected.Text
} else {
result["message"] = "(无回复内容)"
}
if think.Len() > 0 {
result["thinking"] = think.String()
if collected.Thinking != "" {
result["thinking"] = collected.Thinking
}
return result
}

View File

@@ -1,15 +1,19 @@
package admin
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"ds2api/internal/config"
"ds2api/internal/util"
)
// writeJSON and intFrom are package-internal aliases for the shared util versions.
var writeJSON = util.WriteJSON
var intFrom = util.IntFrom
func reverseAccounts(a []config.Account) {
for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 {
a[i], a[j] = a[j], a[i]
@@ -28,19 +32,6 @@ func intFromQuery(r *http.Request, key string, d int) int {
return n
}
func intFrom(v any) int {
switch n := v.(type) {
case float64:
return int(n)
case int:
return n
case int64:
return int(n)
default:
return 0
}
}
func nilIfEmpty(s string) any {
if s == "" {
return nil
@@ -90,9 +81,3 @@ func statusOr(v int, d int) int {
}
return v
}
func writeJSON(w http.ResponseWriter, status int, payload any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(payload)
}

View File

@@ -154,6 +154,8 @@ type Store struct {
cfg Config
path string
fromEnv bool
keyMap map[string]struct{} // O(1) API key lookup index
accMap map[string]int // O(1) account lookup: identifier -> slice index
}
func BaseDir() string {
@@ -199,7 +201,24 @@ func LoadStore() *Store {
if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 {
Logger.Warn("[config] empty config loaded")
}
return &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv}
s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv}
s.rebuildIndexes()
return s
}
// rebuildIndexes must be called with the lock already held (or during init).
func (s *Store) rebuildIndexes() {
s.keyMap = make(map[string]struct{}, len(s.cfg.Keys))
for _, k := range s.cfg.Keys {
s.keyMap[k] = struct{}{}
}
s.accMap = make(map[string]int, len(s.cfg.Accounts))
for i, acc := range s.cfg.Accounts {
id := acc.Identifier()
if id != "" {
s.accMap[id] = i
}
}
}
func loadConfig() (Config, bool, error) {
@@ -247,12 +266,8 @@ func (s *Store) Snapshot() Config {
func (s *Store) HasAPIKey(k string) bool {
s.mu.RLock()
defer s.mu.RUnlock()
for _, key := range s.cfg.Keys {
if key == k {
return true
}
}
return false
_, ok := s.keyMap[k]
return ok
}
func (s *Store) Keys() []string {
@@ -271,10 +286,8 @@ func (s *Store) FindAccount(identifier string) (Account, bool) {
identifier = strings.TrimSpace(identifier)
s.mu.RLock()
defer s.mu.RUnlock()
for _, acc := range s.cfg.Accounts {
if acc.Identifier() == identifier {
return acc, true
}
if idx, ok := s.accMap[identifier]; ok && idx < len(s.cfg.Accounts) {
return s.cfg.Accounts[idx], true
}
return Account{}, false
}
@@ -282,11 +295,9 @@ func (s *Store) FindAccount(identifier string) (Account, bool) {
func (s *Store) UpdateAccountToken(identifier, token string) error {
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.cfg.Accounts {
if s.cfg.Accounts[i].Identifier() == identifier {
s.cfg.Accounts[i].Token = token
return s.saveLocked()
}
if idx, ok := s.accMap[identifier]; ok && idx < len(s.cfg.Accounts) {
s.cfg.Accounts[idx].Token = token
return s.saveLocked()
}
return errors.New("account not found")
}
@@ -295,6 +306,7 @@ func (s *Store) Replace(cfg Config) error {
s.mu.Lock()
defer s.mu.Unlock()
s.cfg = cfg.Clone()
s.rebuildIndexes()
return s.saveLocked()
}
@@ -306,6 +318,7 @@ func (s *Store) Update(mutator func(*Config) error) error {
return err
}
s.cfg = cfg
s.rebuildIndexes()
return s.saveLocked()
}

View File

@@ -16,10 +16,14 @@ import (
"ds2api/internal/auth"
"ds2api/internal/config"
trans "ds2api/internal/deepseek/transport"
"ds2api/internal/util"
"github.com/andybalholm/brotli"
)
// intFrom is a package-internal alias for the shared util version.
var intFrom = util.IntFrom
type Client struct {
Store *config.Store
Auth *auth.Resolver
@@ -288,19 +292,6 @@ func isTokenInvalid(status int, code int, msg string) bool {
return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
}
func intFrom(v any) int {
switch n := v.(type) {
case float64:
return int(n)
case int:
return n
case int64:
return int(n)
default:
return 0
}
}
func readResponseBody(resp *http.Response) ([]byte, error) {
encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding")))
var reader io.Reader = resp.Body

59
internal/sse/consumer.go Normal file
View File

@@ -0,0 +1,59 @@
package sse
import (
"net/http"
"strings"
"ds2api/internal/deepseek"
)
// CollectResult holds the aggregated text and thinking content from a
// DeepSeek SSE stream, consumed to completion (non-streaming use case).
type CollectResult struct {
Text string
Thinking string
}
// CollectStream fully consumes a DeepSeek SSE response and separates
// thinking content from text content. This replaces the duplicated
// stream-collection logic in openai.handleNonStream, claude.collectDeepSeek,
// and admin.testAccount.
//
// The caller is responsible for closing resp.Body unless closeBody is true.
func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) CollectResult {
if closeBody {
defer resp.Body.Close()
}
text := strings.Builder{}
thinking := strings.Builder{}
currentType := "text"
if thinkingEnabled {
currentType = "thinking"
}
_ = deepseek.ScanSSELines(resp, func(line []byte) bool {
chunk, done, ok := ParseDeepSeekSSELine(line)
if !ok {
return true
}
if done {
return false
}
if _, hasErr := chunk["error"]; hasErr {
return false
}
parts, finished, newType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
currentType = newType
if finished {
return false
}
for _, p := range parts {
if p.Type == "thinking" {
thinking.WriteString(p.Text)
} else {
text.WriteString(p.Text)
}
}
return true
})
return CollectResult{Text: text.String(), Thinking: thinking.String()}
}

37
internal/util/helpers.go Normal file
View File

@@ -0,0 +1,37 @@
package util
import (
"encoding/json"
"net/http"
)
// WriteJSON writes a JSON response with the given status code.
// This is a shared helper to avoid duplicate writeJSON functions
// in openai, claude, and admin packages.
func WriteJSON(w http.ResponseWriter, status int, payload any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(payload)
}
// ToBool loosely converts an interface value to bool.
func ToBool(v any) bool {
if b, ok := v.(bool); ok {
return b
}
return false
}
// IntFrom converts a JSON-decoded numeric value (float64, int, int64) to int.
func IntFrom(v any) int {
switch n := v.(type) {
case float64:
return int(n)
case int:
return n
case int64:
return int(n)
default:
return 0
}
}