feat: centralize DeepSeek SSE parsing, improve account identifier resolution, and simplify CORS configuration.

This commit is contained in:
CJACK
2026-02-17 03:45:55 +08:00
parent 2cde0a1d84
commit 23d5ac7fa2
12 changed files with 263 additions and 75 deletions

View File

@@ -106,7 +106,7 @@
### 13. CORS 配置矛盾
- **位置**: `internal/server/router.go`
- **问题**: 同时设置 `Access-Control-Allow-Origin: *` 和 `Access-Control-Allow-Credentials: true` 是无效的(浏览器安全规范)。
- **建议**: 动态反射 Origin 或移除 Credentials 允许
- **建议**: 若采用宽松模式,保持 `Access-Control-Allow-Origin: *`,并移除 `Access-Control-Allow-Credentials`
---
@@ -116,4 +116,4 @@
1. **Phase 1 (Fix Critical) ✅ 已完成:** ~~修复 `Save()` 锁问题、WASM 重复创建、Admin 默认密码警告、Graceful Shutdown。删除无用大文件。~~ 同时修复了 `itoa` 低效实现。
2. **Phase 2 (Refactor) ✅ 已完成:** ~~统一 API Key/Account 的索引机制,重构 SSE 解析逻辑 (DRY),优化 `testAllAccounts` 并发。~~ 同时完成了重复工具函数的统一清理(`writeJSON`/`toBool`/`intFrom` → `internal/util`)。
3. **Phase 3 (Cleanup) ✅ 已完成:** ~~优化 CORS,改进 Token 估算等微小性能点。~~ CORS 改为动态反射 Origin,Token 估算区分 ASCII/非 ASCII 字符。
3. **Phase 3 (Cleanup) ✅ 已完成:** ~~优化 CORS,改进 Token 估算等微小性能点。~~ CORS 采用宽松模式(`Access-Control-Allow-Origin: *`,不启用 Credentials),Token 估算区分 ASCII/非 ASCII 字符。

View File

@@ -267,7 +267,6 @@ module.exports = async function handler(req, res) {
function setCorsHeaders(res) {
res.setHeader('Access-Control-Allow-Origin', '*');
res.setHeader('Access-Control-Allow-Credentials', 'true');
res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE');
res.setHeader(
'Access-Control-Allow-Headers',

View File

@@ -422,30 +422,21 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
return
}
chunk, doneSignal, parsed := sse.ParseDeepSeekSSELine(line)
if !parsed {
parsed := sse.ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
currentType = parsed.NextType
if !parsed.Parsed {
continue
}
if doneSignal {
finalize("end_turn")
if parsed.ErrorMessage != "" {
sendError(parsed.ErrorMessage)
return
}
if errObj, hasErr := chunk["error"]; hasErr {
sendError(fmt.Sprintf("%v", errObj))
return
}
if code, _ := chunk["code"].(string); code == "content_filter" {
sendError("content filtered by upstream")
return
}
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
currentType = newType
if finished {
if parsed.Stop {
finalize("end_turn")
return
}
for _, p := range parts {
for _, p := range parsed.Parts {
if p.Text == "" {
continue
}

View File

@@ -329,26 +329,21 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
finalize("stop")
return
}
chunk, doneSignal, parsed := sse.ParseDeepSeekSSELine(line)
if !parsed {
parsed := sse.ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
currentType = parsed.NextType
if !parsed.Parsed {
continue
}
if doneSignal {
finalize("stop")
return
}
if _, hasErr := chunk["error"]; hasErr || chunk["code"] == "content_filter" {
if parsed.ContentFilter || parsed.ErrorMessage != "" {
finalize("content_filter")
return
}
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
currentType = newType
if finished {
if parsed.Stop {
finalize("stop")
return
}
newChoices := make([]map[string]any, 0, len(parts))
for _, p := range parts {
newChoices := make([]map[string]any, 0, len(parsed.Parts))
for _, p := range parsed.Parts {
if searchEnabled && sse.IsCitation(p.Text) {
continue
}

View File

@@ -154,20 +154,9 @@ func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
// Concurrent testing with a semaphore to limit parallelism.
const maxConcurrency = 5
sem := make(chan struct{}, maxConcurrency)
results := make([]map[string]any, len(accounts))
var wg sync.WaitGroup
for i, acc := range accounts {
wg.Add(1)
go func(idx int, account config.Account) {
defer wg.Done()
sem <- struct{}{} // acquire
defer func() { <-sem }() // release
results[idx] = h.testAccount(r.Context(), account, model, "")
}(i, acc)
}
wg.Wait()
results := runAccountTestsConcurrently(accounts, maxConcurrency, func(_ int, account config.Account) map[string]any {
return h.testAccount(r.Context(), account, model, "")
})
success := 0
for _, res := range results {
@@ -178,6 +167,26 @@ func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results})
}
// runAccountTestsConcurrently invokes testFn once per account, keeping at
// most maxConcurrency invocations in flight at any time. Results are placed
// at the same index as their input account, so output order always matches
// input order. A non-positive maxConcurrency is treated as 1.
func runAccountTestsConcurrently(accounts []config.Account, maxConcurrency int, testFn func(int, config.Account) map[string]any) []map[string]any {
	limit := maxConcurrency
	if limit < 1 {
		limit = 1
	}
	out := make([]map[string]any, len(accounts))
	slots := make(chan struct{}, limit)
	var wg sync.WaitGroup
	wg.Add(len(accounts))
	for i := range accounts {
		go func(pos int, account config.Account) {
			defer wg.Done()
			slots <- struct{}{} // take a slot before doing any work
			defer func() { <-slots }()
			out[pos] = testFn(pos, account)
		}(i, accounts[i])
	}
	wg.Wait()
	return out
}
func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
start := time.Now()
result := map[string]any{"account": acc.Identifier(), "success": false, "response_time": 0, "message": "", "model": model}

View File

@@ -1,6 +1,12 @@
package admin
import "testing"
import (
"sync/atomic"
"testing"
"time"
"ds2api/internal/config"
)
func TestToAccountMissingFieldsRemainEmpty(t *testing.T) {
acc := toAccount(map[string]any{
@@ -26,3 +32,62 @@ func TestFieldStringNilToEmpty(t *testing.T) {
t.Fatalf("expected empty string for missing field, got %q", got)
}
}
// TestRunAccountTestsConcurrentlyKeepsInputOrder verifies that each result is
// stored at the index of the account that produced it, regardless of the
// order in which the worker goroutines happen to finish.
func TestRunAccountTestsConcurrentlyKeepsInputOrder(t *testing.T) {
	accounts := []config.Account{
		{Email: "a@example.com"},
		{Email: "b@example.com"},
		{Email: "c@example.com"},
	}
	probe := func(idx int, acc config.Account) map[string]any {
		return map[string]any{"idx": idx, "account": acc.Identifier()}
	}
	results := runAccountTestsConcurrently(accounts, 2, probe)
	if len(results) != len(accounts) {
		t.Fatalf("unexpected result length: got %d want %d", len(results), len(accounts))
	}
	for i, acc := range accounts {
		res := results[i]
		if idx, _ := res["idx"].(int); idx != i {
			t.Fatalf("result index mismatch at %d: got %d", i, idx)
		}
		if id, _ := res["account"].(string); id != acc.Identifier() {
			t.Fatalf("result order mismatch at %d: got %q want %q", i, id, acc.Identifier())
		}
	}
}
// TestRunAccountTestsConcurrentlyRespectsLimit tracks the peak number of
// simultaneously running testFn calls and asserts it never exceeds the
// configured limit, while still proving some real parallelism happened.
func TestRunAccountTestsConcurrentlyRespectsLimit(t *testing.T) {
	const limit = 3
	accounts := []config.Account{
		{Email: "1@example.com"},
		{Email: "2@example.com"},
		{Email: "3@example.com"},
		{Email: "4@example.com"},
		{Email: "5@example.com"},
		{Email: "6@example.com"},
	}
	var inFlight, maxSeen int32
	_ = runAccountTestsConcurrently(accounts, limit, func(_ int, _ config.Account) map[string]any {
		n := atomic.AddInt32(&inFlight, 1)
		// Record the high-water mark with a CAS loop so concurrent
		// workers never lose an update.
		for {
			prev := atomic.LoadInt32(&maxSeen)
			if n <= prev || atomic.CompareAndSwapInt32(&maxSeen, prev, n) {
				break
			}
		}
		// Hold the slot long enough for other workers to overlap.
		time.Sleep(20 * time.Millisecond)
		atomic.AddInt32(&inFlight, -1)
		return map[string]any{"success": true}
	})
	// Safe to read without atomics: runAccountTestsConcurrently waits on a
	// WaitGroup, which establishes a happens-before edge with every worker.
	switch {
	case maxSeen > limit:
		t.Fatalf("concurrency exceeded limit: got %d > %d", maxSeen, limit)
	case maxSeen < 2:
		t.Fatalf("expected concurrent execution, max seen %d", maxSeen)
	}
}

View File

@@ -286,20 +286,35 @@ func (s *Store) FindAccount(identifier string) (Account, bool) {
identifier = strings.TrimSpace(identifier)
s.mu.RLock()
defer s.mu.RUnlock()
if idx, ok := s.accMap[identifier]; ok && idx < len(s.cfg.Accounts) {
if idx, ok := s.findAccountIndexLocked(identifier); ok {
return s.cfg.Accounts[idx], true
}
return Account{}, false
}
func (s *Store) UpdateAccountToken(identifier, token string) error {
identifier = strings.TrimSpace(identifier)
s.mu.Lock()
defer s.mu.Unlock()
if idx, ok := s.accMap[identifier]; ok && idx < len(s.cfg.Accounts) {
s.cfg.Accounts[idx].Token = token
return s.saveLocked()
idx, ok := s.findAccountIndexLocked(identifier)
if !ok {
return errors.New("account not found")
}
return errors.New("account not found")
oldID := s.cfg.Accounts[idx].Identifier()
s.cfg.Accounts[idx].Token = token
newID := s.cfg.Accounts[idx].Identifier()
// Keep historical aliases usable for long-lived queues while also adding
// the latest identifier after token refresh.
if identifier != "" {
s.accMap[identifier] = idx
}
if oldID != "" {
s.accMap[oldID] = idx
}
if newID != "" {
s.accMap[newID] = idx
}
return s.saveLocked()
}
func (s *Store) Replace(cfg Config) error {
@@ -348,6 +363,21 @@ func (s *Store) saveLocked() error {
return os.WriteFile(s.path, b, 0o644)
}
// findAccountIndexLocked resolves identifier to an index into s.cfg.Accounts,
// returning (-1, false) when no account matches. The caller must already hold
// the store lock (read or write).
func (s *Store) findAccountIndexLocked(identifier string) (int, bool) {
	idx, ok := s.accMap[identifier]
	if ok && idx >= 0 && idx < len(s.cfg.Accounts) {
		return idx, true
	}
	// Linear-scan fallback: token-only accounts derive their identifier from
	// the token, so a refresh can leave the map keyed by a stale alias. The
	// scan keeps lookups correct on such map misses.
	for i := range s.cfg.Accounts {
		if s.cfg.Accounts[i].Identifier() == identifier {
			return i, true
		}
	}
	return -1, false
}
func (s *Store) IsEnvBacked() bool {
s.mu.RLock()
defer s.mu.RUnlock()

View File

@@ -39,3 +39,34 @@ func TestStoreFindAccountWithTokenOnlyIdentifier(t *testing.T) {
t.Fatalf("unexpected token value: %q", found.Token)
}
}
// TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable refreshes the
// token of a token-only account (whose identifier is derived from the token)
// and checks that both the pre-refresh and post-refresh identifiers still
// resolve to the updated account.
func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{
"accounts":[{"token":"old-token"}]
}`)
	store := LoadStore()
	accountsBefore := store.Accounts()
	if len(accountsBefore) != 1 {
		t.Fatalf("expected 1 account, got %d", len(accountsBefore))
	}
	oldID := accountsBefore[0].Identifier()
	if oldID == "" {
		t.Fatal("expected old identifier")
	}
	if err := store.UpdateAccountToken(oldID, "new-token"); err != nil {
		t.Fatalf("update token failed: %v", err)
	}
	newID := store.Accounts()[0].Identifier()
	if newID == "" || newID == oldID {
		t.Fatalf("expected changed identifier, old=%q new=%q", oldID, newID)
	}
	// Both the new identifier and the historical alias must keep working.
	if acc, ok := store.FindAccount(newID); !ok || acc.Token != "new-token" {
		t.Fatalf("expected find by new identifier")
	}
	if acc, ok := store.FindAccount(oldID); !ok || acc.Token != "new-token" {
		t.Fatalf("expected find by old identifier alias")
	}
}

View File

@@ -90,18 +90,7 @@ func timeout(d time.Duration) func(http.Handler) http.Handler {
func cors(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin != "" {
// Dynamically reflect the request origin to allow credentials.
// Using "*" with Access-Control-Allow-Credentials: true is
// invalid per the CORS spec and will be rejected by browsers.
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Vary", "Origin")
} else {
// No Origin header (e.g. server-to-server requests); allow all.
w.Header().Set("Access-Control-Allow-Origin", "*")
}
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
if r.Method == http.MethodOptions {

View File

@@ -31,22 +31,15 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
currentType = "thinking"
}
_ = deepseek.ScanSSELines(resp, func(line []byte) bool {
chunk, done, ok := ParseDeepSeekSSELine(line)
if !ok {
result := ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
currentType = result.NextType
if !result.Parsed {
return true
}
if done {
if result.Stop {
return false
}
if _, hasErr := chunk["error"]; hasErr {
return false
}
parts, finished, newType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
currentType = newType
if finished {
return false
}
for _, p := range parts {
for _, p := range result.Parts {
if p.Type == "thinking" {
thinking.WriteString(p.Text)
} else {

49
internal/sse/line.go Normal file
View File

@@ -0,0 +1,49 @@
package sse
import "fmt"
// LineResult is the normalized parse result for one DeepSeek SSE line.
type LineResult struct {
	Parsed        bool          // false when the line was not a usable data line; callers skip it
	Stop          bool          // true when the caller should stop consuming ([DONE], error, filter, or finish)
	ContentFilter bool          // true when the chunk carried code "content_filter"
	ErrorMessage  string        // set from the chunk's "error" field, or "content filtered by upstream" on filter
	Parts         []ContentPart // content parts extracted from a normal data chunk
	NextType      string        // content type to carry into the next line's parse call
}
// ParseDeepSeekContentLine is the single entry point for interpreting one
// DeepSeek SSE line; the streaming and non-streaming handlers both share it.
// The returned LineResult tells the caller whether to skip the line, stop the
// stream (optionally carrying an error), or emit the extracted content parts.
func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType string) LineResult {
	chunk, done, ok := ParseDeepSeekSSELine(raw)
	res := LineResult{NextType: currentType}
	if !ok {
		// Not a recognizable data line; callers simply continue.
		return res
	}
	res.Parsed = true
	if done {
		res.Stop = true
		return res
	}
	if errObj, hasErr := chunk["error"]; hasErr {
		res.Stop = true
		res.ErrorMessage = fmt.Sprintf("%v", errObj)
		return res
	}
	if code, _ := chunk["code"].(string); code == "content_filter" {
		res.Stop = true
		res.ContentFilter = true
		res.ErrorMessage = "content filtered by upstream"
		return res
	}
	// Normal content chunk: extract parts and propagate the content type.
	res.Parts, res.Stop, res.NextType = ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
	return res
}

37
internal/sse/line_test.go Normal file
View File

@@ -0,0 +1,37 @@
package sse
import "testing"
// TestParseDeepSeekContentLineDone checks that a [DONE] marker yields a
// parsed result that tells the caller to stop.
func TestParseDeepSeekContentLineDone(t *testing.T) {
	line := []byte("data: [DONE]")
	res := ParseDeepSeekContentLine(line, false, "text")
	if !(res.Parsed && res.Stop) {
		t.Fatalf("expected parsed stop result: %#v", res)
	}
}
// TestParseDeepSeekContentLineError checks that an upstream error chunk stops
// the stream and surfaces a non-empty error message.
func TestParseDeepSeekContentLineError(t *testing.T) {
	line := []byte(`data: {"error":"boom"}`)
	res := ParseDeepSeekContentLine(line, false, "text")
	if !(res.Parsed && res.Stop) {
		t.Fatalf("expected stop on error: %#v", res)
	}
	if res.ErrorMessage == "" {
		t.Fatalf("expected non-empty error message")
	}
}
// TestParseDeepSeekContentLineContentFilter checks that a content_filter code
// stops the stream with the ContentFilter flag set.
func TestParseDeepSeekContentLineContentFilter(t *testing.T) {
	line := []byte(`data: {"code":"content_filter"}`)
	res := ParseDeepSeekContentLine(line, false, "text")
	if !(res.Parsed && res.Stop && res.ContentFilter) {
		t.Fatalf("expected content-filter stop result: %#v", res)
	}
}
// TestParseDeepSeekContentLineContent checks that a normal content chunk is
// parsed into exactly one text part without stopping the stream.
func TestParseDeepSeekContentLineContent(t *testing.T) {
	line := []byte(`data: {"p":"response/content","v":"hi"}`)
	res := ParseDeepSeekContentLine(line, false, "text")
	if !res.Parsed || res.Stop {
		t.Fatalf("expected parsed non-stop result: %#v", res)
	}
	switch {
	case len(res.Parts) != 1, res.Parts[0].Text != "hi", res.Parts[0].Type != "text":
		t.Fatalf("unexpected parts: %#v", res.Parts)
	}
}