mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-01 23:15:27 +08:00
feat: centralize DeepSeek SSE parsing, improve account identifier resolution, and simplify CORS configuration.
This commit is contained in:
@@ -106,7 +106,7 @@
|
||||
### 13. CORS 配置矛盾
|
||||
- **位置**: `internal/server/router.go`
|
||||
- **问题**: 同时设置 `Access-Control-Allow-Origin: *` 和 `Access-Control-Allow-Credentials: true` 是无效的(浏览器安全规范)。
|
||||
- **建议**: 动态反射 Origin 或移除 Credentials 允许。
|
||||
- **建议**: 若采用宽松模式,保持 `Access-Control-Allow-Origin: *`,并移除 `Access-Control-Allow-Credentials`。
|
||||
|
||||
---
|
||||
|
||||
@@ -116,4 +116,4 @@
|
||||
|
||||
1. **Phase 1 (Fix Critical) ✅ 已完成:** ~~修复 `Save()` 锁问题、WASM 重复创建、Admin 默认密码警告、Graceful Shutdown。删除无用大文件。~~ 同时修复了 `itoa` 低效实现。
|
||||
2. **Phase 2 (Refactor) ✅ 已完成:** ~~统一 API Key/Account 的索引机制,重构 SSE 解析逻辑 (DRY),优化 `testAllAccounts` 并发。~~ 同时完成了重复工具函数的统一清理(`writeJSON`/`toBool`/`intFrom` → `internal/util`)。
|
||||
3. **Phase 3 (Cleanup) ✅ 已完成:** ~~优化 CORS,改进 Token 估算等微小性能点。~~ CORS 改为动态反射 Origin;Token 估算区分 ASCII/非 ASCII 字符。
|
||||
3. **Phase 3 (Cleanup) ✅ 已完成:** ~~优化 CORS,改进 Token 估算等微小性能点。~~ CORS 采用宽松模式(`Access-Control-Allow-Origin: *`,不启用 Credentials);Token 估算区分 ASCII/非 ASCII 字符。
|
||||
|
||||
@@ -267,7 +267,6 @@ module.exports = async function handler(req, res) {
|
||||
|
||||
function setCorsHeaders(res) {
|
||||
res.setHeader('Access-Control-Allow-Origin', '*');
|
||||
res.setHeader('Access-Control-Allow-Credentials', 'true');
|
||||
res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE');
|
||||
res.setHeader(
|
||||
'Access-Control-Allow-Headers',
|
||||
|
||||
@@ -422,30 +422,21 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
chunk, doneSignal, parsed := sse.ParseDeepSeekSSELine(line)
|
||||
if !parsed {
|
||||
parsed := sse.ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
|
||||
currentType = parsed.NextType
|
||||
if !parsed.Parsed {
|
||||
continue
|
||||
}
|
||||
if doneSignal {
|
||||
finalize("end_turn")
|
||||
if parsed.ErrorMessage != "" {
|
||||
sendError(parsed.ErrorMessage)
|
||||
return
|
||||
}
|
||||
if errObj, hasErr := chunk["error"]; hasErr {
|
||||
sendError(fmt.Sprintf("%v", errObj))
|
||||
return
|
||||
}
|
||||
if code, _ := chunk["code"].(string); code == "content_filter" {
|
||||
sendError("content filtered by upstream")
|
||||
return
|
||||
}
|
||||
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
||||
currentType = newType
|
||||
if finished {
|
||||
if parsed.Stop {
|
||||
finalize("end_turn")
|
||||
return
|
||||
}
|
||||
|
||||
for _, p := range parts {
|
||||
for _, p := range parsed.Parts {
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -329,26 +329,21 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
chunk, doneSignal, parsed := sse.ParseDeepSeekSSELine(line)
|
||||
if !parsed {
|
||||
parsed := sse.ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
|
||||
currentType = parsed.NextType
|
||||
if !parsed.Parsed {
|
||||
continue
|
||||
}
|
||||
if doneSignal {
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
if _, hasErr := chunk["error"]; hasErr || chunk["code"] == "content_filter" {
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" {
|
||||
finalize("content_filter")
|
||||
return
|
||||
}
|
||||
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
||||
currentType = newType
|
||||
if finished {
|
||||
if parsed.Stop {
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
newChoices := make([]map[string]any, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
newChoices := make([]map[string]any, 0, len(parsed.Parts))
|
||||
for _, p := range parsed.Parts {
|
||||
if searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -154,20 +154,9 @@ func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Concurrent testing with a semaphore to limit parallelism.
|
||||
const maxConcurrency = 5
|
||||
sem := make(chan struct{}, maxConcurrency)
|
||||
results := make([]map[string]any, len(accounts))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, acc := range accounts {
|
||||
wg.Add(1)
|
||||
go func(idx int, account config.Account) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{} // acquire
|
||||
defer func() { <-sem }() // release
|
||||
results[idx] = h.testAccount(r.Context(), account, model, "")
|
||||
}(i, acc)
|
||||
}
|
||||
wg.Wait()
|
||||
results := runAccountTestsConcurrently(accounts, maxConcurrency, func(_ int, account config.Account) map[string]any {
|
||||
return h.testAccount(r.Context(), account, model, "")
|
||||
})
|
||||
|
||||
success := 0
|
||||
for _, res := range results {
|
||||
@@ -178,6 +167,26 @@ func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results})
|
||||
}
|
||||
|
||||
func runAccountTestsConcurrently(accounts []config.Account, maxConcurrency int, testFn func(int, config.Account) map[string]any) []map[string]any {
|
||||
if maxConcurrency <= 0 {
|
||||
maxConcurrency = 1
|
||||
}
|
||||
sem := make(chan struct{}, maxConcurrency)
|
||||
results := make([]map[string]any, len(accounts))
|
||||
var wg sync.WaitGroup
|
||||
for i, acc := range accounts {
|
||||
wg.Add(1)
|
||||
go func(idx int, account config.Account) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{} // acquire
|
||||
defer func() { <-sem }() // release
|
||||
results[idx] = testFn(idx, account)
|
||||
}(i, acc)
|
||||
}
|
||||
wg.Wait()
|
||||
return results
|
||||
}
|
||||
|
||||
func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
|
||||
start := time.Now()
|
||||
result := map[string]any{"account": acc.Identifier(), "success": false, "response_time": 0, "message": "", "model": model}
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package admin
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func TestToAccountMissingFieldsRemainEmpty(t *testing.T) {
|
||||
acc := toAccount(map[string]any{
|
||||
@@ -26,3 +32,62 @@ func TestFieldStringNilToEmpty(t *testing.T) {
|
||||
t.Fatalf("expected empty string for missing field, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunAccountTestsConcurrentlyKeepsInputOrder(t *testing.T) {
|
||||
accounts := []config.Account{
|
||||
{Email: "a@example.com"},
|
||||
{Email: "b@example.com"},
|
||||
{Email: "c@example.com"},
|
||||
}
|
||||
results := runAccountTestsConcurrently(accounts, 2, func(idx int, acc config.Account) map[string]any {
|
||||
return map[string]any{
|
||||
"idx": idx,
|
||||
"account": acc.Identifier(),
|
||||
}
|
||||
})
|
||||
if len(results) != len(accounts) {
|
||||
t.Fatalf("unexpected result length: got %d want %d", len(results), len(accounts))
|
||||
}
|
||||
for i := range accounts {
|
||||
gotIdx, _ := results[i]["idx"].(int)
|
||||
if gotIdx != i {
|
||||
t.Fatalf("result index mismatch at %d: got %d", i, gotIdx)
|
||||
}
|
||||
gotID, _ := results[i]["account"].(string)
|
||||
if gotID != accounts[i].Identifier() {
|
||||
t.Fatalf("result order mismatch at %d: got %q want %q", i, gotID, accounts[i].Identifier())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunAccountTestsConcurrentlyRespectsLimit(t *testing.T) {
|
||||
const limit = 3
|
||||
accounts := []config.Account{
|
||||
{Email: "1@example.com"},
|
||||
{Email: "2@example.com"},
|
||||
{Email: "3@example.com"},
|
||||
{Email: "4@example.com"},
|
||||
{Email: "5@example.com"},
|
||||
{Email: "6@example.com"},
|
||||
}
|
||||
var current int32
|
||||
var maxSeen int32
|
||||
_ = runAccountTestsConcurrently(accounts, limit, func(_ int, _ config.Account) map[string]any {
|
||||
c := atomic.AddInt32(¤t, 1)
|
||||
for {
|
||||
m := atomic.LoadInt32(&maxSeen)
|
||||
if c <= m || atomic.CompareAndSwapInt32(&maxSeen, m, c) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
atomic.AddInt32(¤t, -1)
|
||||
return map[string]any{"success": true}
|
||||
})
|
||||
if maxSeen > limit {
|
||||
t.Fatalf("concurrency exceeded limit: got %d > %d", maxSeen, limit)
|
||||
}
|
||||
if maxSeen < 2 {
|
||||
t.Fatalf("expected concurrent execution, max seen %d", maxSeen)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -286,20 +286,35 @@ func (s *Store) FindAccount(identifier string) (Account, bool) {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if idx, ok := s.accMap[identifier]; ok && idx < len(s.cfg.Accounts) {
|
||||
if idx, ok := s.findAccountIndexLocked(identifier); ok {
|
||||
return s.cfg.Accounts[idx], true
|
||||
}
|
||||
return Account{}, false
|
||||
}
|
||||
|
||||
func (s *Store) UpdateAccountToken(identifier, token string) error {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if idx, ok := s.accMap[identifier]; ok && idx < len(s.cfg.Accounts) {
|
||||
s.cfg.Accounts[idx].Token = token
|
||||
return s.saveLocked()
|
||||
idx, ok := s.findAccountIndexLocked(identifier)
|
||||
if !ok {
|
||||
return errors.New("account not found")
|
||||
}
|
||||
return errors.New("account not found")
|
||||
oldID := s.cfg.Accounts[idx].Identifier()
|
||||
s.cfg.Accounts[idx].Token = token
|
||||
newID := s.cfg.Accounts[idx].Identifier()
|
||||
// Keep historical aliases usable for long-lived queues while also adding
|
||||
// the latest identifier after token refresh.
|
||||
if identifier != "" {
|
||||
s.accMap[identifier] = idx
|
||||
}
|
||||
if oldID != "" {
|
||||
s.accMap[oldID] = idx
|
||||
}
|
||||
if newID != "" {
|
||||
s.accMap[newID] = idx
|
||||
}
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *Store) Replace(cfg Config) error {
|
||||
@@ -348,6 +363,21 @@ func (s *Store) saveLocked() error {
|
||||
return os.WriteFile(s.path, b, 0o644)
|
||||
}
|
||||
|
||||
// findAccountIndexLocked expects the store lock to already be held.
|
||||
func (s *Store) findAccountIndexLocked(identifier string) (int, bool) {
|
||||
if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) {
|
||||
return idx, true
|
||||
}
|
||||
// Fallback for token-only accounts whose derived identifier changed after
|
||||
// a token refresh; this preserves correctness on map misses.
|
||||
for i, acc := range s.cfg.Accounts {
|
||||
if acc.Identifier() == identifier {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
return -1, false
|
||||
}
|
||||
|
||||
func (s *Store) IsEnvBacked() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
@@ -39,3 +39,34 @@ func TestStoreFindAccountWithTokenOnlyIdentifier(t *testing.T) {
|
||||
t.Fatalf("unexpected token value: %q", found.Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"accounts":[{"token":"old-token"}]
|
||||
}`)
|
||||
|
||||
store := LoadStore()
|
||||
before := store.Accounts()
|
||||
if len(before) != 1 {
|
||||
t.Fatalf("expected 1 account, got %d", len(before))
|
||||
}
|
||||
oldID := before[0].Identifier()
|
||||
if oldID == "" {
|
||||
t.Fatal("expected old identifier")
|
||||
}
|
||||
if err := store.UpdateAccountToken(oldID, "new-token"); err != nil {
|
||||
t.Fatalf("update token failed: %v", err)
|
||||
}
|
||||
|
||||
after := store.Accounts()
|
||||
newID := after[0].Identifier()
|
||||
if newID == "" || newID == oldID {
|
||||
t.Fatalf("expected changed identifier, old=%q new=%q", oldID, newID)
|
||||
}
|
||||
if got, ok := store.FindAccount(newID); !ok || got.Token != "new-token" {
|
||||
t.Fatalf("expected find by new identifier")
|
||||
}
|
||||
if got, ok := store.FindAccount(oldID); !ok || got.Token != "new-token" {
|
||||
t.Fatalf("expected find by old identifier alias")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,18 +90,7 @@ func timeout(d time.Duration) func(http.Handler) http.Handler {
|
||||
|
||||
func cors(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
// Dynamically reflect the request origin to allow credentials.
|
||||
// Using "*" with Access-Control-Allow-Credentials: true is
|
||||
// invalid per the CORS spec and will be rejected by browsers.
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Vary", "Origin")
|
||||
} else {
|
||||
// No Origin header (e.g. server-to-server requests); allow all.
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
if r.Method == http.MethodOptions {
|
||||
|
||||
@@ -31,22 +31,15 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
currentType = "thinking"
|
||||
}
|
||||
_ = deepseek.ScanSSELines(resp, func(line []byte) bool {
|
||||
chunk, done, ok := ParseDeepSeekSSELine(line)
|
||||
if !ok {
|
||||
result := ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
|
||||
currentType = result.NextType
|
||||
if !result.Parsed {
|
||||
return true
|
||||
}
|
||||
if done {
|
||||
if result.Stop {
|
||||
return false
|
||||
}
|
||||
if _, hasErr := chunk["error"]; hasErr {
|
||||
return false
|
||||
}
|
||||
parts, finished, newType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
||||
currentType = newType
|
||||
if finished {
|
||||
return false
|
||||
}
|
||||
for _, p := range parts {
|
||||
for _, p := range result.Parts {
|
||||
if p.Type == "thinking" {
|
||||
thinking.WriteString(p.Text)
|
||||
} else {
|
||||
|
||||
49
internal/sse/line.go
Normal file
49
internal/sse/line.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package sse
|
||||
|
||||
import "fmt"
|
||||
|
||||
// LineResult is the normalized parse result for one DeepSeek SSE line.
|
||||
type LineResult struct {
|
||||
Parsed bool
|
||||
Stop bool
|
||||
ContentFilter bool
|
||||
ErrorMessage string
|
||||
Parts []ContentPart
|
||||
NextType string
|
||||
}
|
||||
|
||||
// ParseDeepSeekContentLine centralizes one-line DeepSeek SSE parsing for both
|
||||
// streaming and non-streaming handlers.
|
||||
func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType string) LineResult {
|
||||
chunk, done, parsed := ParseDeepSeekSSELine(raw)
|
||||
if !parsed {
|
||||
return LineResult{NextType: currentType}
|
||||
}
|
||||
if done {
|
||||
return LineResult{Parsed: true, Stop: true, NextType: currentType}
|
||||
}
|
||||
if errObj, hasErr := chunk["error"]; hasErr {
|
||||
return LineResult{
|
||||
Parsed: true,
|
||||
Stop: true,
|
||||
ErrorMessage: fmt.Sprintf("%v", errObj),
|
||||
NextType: currentType,
|
||||
}
|
||||
}
|
||||
if code, _ := chunk["code"].(string); code == "content_filter" {
|
||||
return LineResult{
|
||||
Parsed: true,
|
||||
Stop: true,
|
||||
ContentFilter: true,
|
||||
ErrorMessage: "content filtered by upstream",
|
||||
NextType: currentType,
|
||||
}
|
||||
}
|
||||
parts, finished, nextType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
||||
return LineResult{
|
||||
Parsed: true,
|
||||
Stop: finished,
|
||||
Parts: parts,
|
||||
NextType: nextType,
|
||||
}
|
||||
}
|
||||
37
internal/sse/line_test.go
Normal file
37
internal/sse/line_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package sse
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseDeepSeekContentLineDone(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte("data: [DONE]"), false, "text")
|
||||
if !res.Parsed || !res.Stop {
|
||||
t.Fatalf("expected parsed stop result: %#v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineError(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"error":"boom"}`), false, "text")
|
||||
if !res.Parsed || !res.Stop {
|
||||
t.Fatalf("expected stop on error: %#v", res)
|
||||
}
|
||||
if res.ErrorMessage == "" {
|
||||
t.Fatalf("expected non-empty error message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineContentFilter(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"code":"content_filter"}`), false, "text")
|
||||
if !res.Parsed || !res.Stop || !res.ContentFilter {
|
||||
t.Fatalf("expected content-filter stop result: %#v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineContent(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"hi"}`), false, "text")
|
||||
if !res.Parsed || res.Stop {
|
||||
t.Fatalf("expected parsed non-stop result: %#v", res)
|
||||
}
|
||||
if len(res.Parts) != 1 || res.Parts[0].Text != "hi" || res.Parts[0].Type != "text" {
|
||||
t.Fatalf("unexpected parts: %#v", res.Parts)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user