mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-18 23:25:10 +08:00
feat: centralize DeepSeek SSE parsing, improve account identifier resolution, and simplify CORS configuration.
This commit is contained in:
@@ -106,7 +106,7 @@
|
|||||||
### 13. CORS 配置矛盾
|
### 13. CORS 配置矛盾
|
||||||
- **位置**: `internal/server/router.go`
|
- **位置**: `internal/server/router.go`
|
||||||
- **问题**: 同时设置 `Access-Control-Allow-Origin: *` 和 `Access-Control-Allow-Credentials: true` 是无效的(浏览器安全规范)。
|
- **问题**: 同时设置 `Access-Control-Allow-Origin: *` 和 `Access-Control-Allow-Credentials: true` 是无效的(浏览器安全规范)。
|
||||||
- **建议**: 动态反射 Origin 或移除 Credentials 允许。
|
- **建议**: 若采用宽松模式,保持 `Access-Control-Allow-Origin: *`,并移除 `Access-Control-Allow-Credentials`。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -116,4 +116,4 @@
|
|||||||
|
|
||||||
1. **Phase 1 (Fix Critical) ✅ 已完成:** ~~修复 `Save()` 锁问题、WASM 重复创建、Admin 默认密码警告、Graceful Shutdown。删除无用大文件。~~ 同时修复了 `itoa` 低效实现。
|
1. **Phase 1 (Fix Critical) ✅ 已完成:** ~~修复 `Save()` 锁问题、WASM 重复创建、Admin 默认密码警告、Graceful Shutdown。删除无用大文件。~~ 同时修复了 `itoa` 低效实现。
|
||||||
2. **Phase 2 (Refactor) ✅ 已完成:** ~~统一 API Key/Account 的索引机制,重构 SSE 解析逻辑 (DRY),优化 `testAllAccounts` 并发。~~ 同时完成了重复工具函数的统一清理(`writeJSON`/`toBool`/`intFrom` → `internal/util`)。
|
2. **Phase 2 (Refactor) ✅ 已完成:** ~~统一 API Key/Account 的索引机制,重构 SSE 解析逻辑 (DRY),优化 `testAllAccounts` 并发。~~ 同时完成了重复工具函数的统一清理(`writeJSON`/`toBool`/`intFrom` → `internal/util`)。
|
||||||
3. **Phase 3 (Cleanup) ✅ 已完成:** ~~优化 CORS,改进 Token 估算等微小性能点。~~ CORS 改为动态反射 Origin;Token 估算区分 ASCII/非 ASCII 字符。
|
3. **Phase 3 (Cleanup) ✅ 已完成:** ~~优化 CORS,改进 Token 估算等微小性能点。~~ CORS 采用宽松模式(`Access-Control-Allow-Origin: *`,不启用 Credentials);Token 估算区分 ASCII/非 ASCII 字符。
|
||||||
|
|||||||
@@ -267,7 +267,6 @@ module.exports = async function handler(req, res) {
|
|||||||
|
|
||||||
function setCorsHeaders(res) {
|
function setCorsHeaders(res) {
|
||||||
res.setHeader('Access-Control-Allow-Origin', '*');
|
res.setHeader('Access-Control-Allow-Origin', '*');
|
||||||
res.setHeader('Access-Control-Allow-Credentials', 'true');
|
|
||||||
res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE');
|
res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE');
|
||||||
res.setHeader(
|
res.setHeader(
|
||||||
'Access-Control-Allow-Headers',
|
'Access-Control-Allow-Headers',
|
||||||
|
|||||||
@@ -422,30 +422,21 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
chunk, doneSignal, parsed := sse.ParseDeepSeekSSELine(line)
|
parsed := sse.ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
|
||||||
if !parsed {
|
currentType = parsed.NextType
|
||||||
|
if !parsed.Parsed {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if doneSignal {
|
if parsed.ErrorMessage != "" {
|
||||||
finalize("end_turn")
|
sendError(parsed.ErrorMessage)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errObj, hasErr := chunk["error"]; hasErr {
|
if parsed.Stop {
|
||||||
sendError(fmt.Sprintf("%v", errObj))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if code, _ := chunk["code"].(string); code == "content_filter" {
|
|
||||||
sendError("content filtered by upstream")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
|
||||||
currentType = newType
|
|
||||||
if finished {
|
|
||||||
finalize("end_turn")
|
finalize("end_turn")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range parts {
|
for _, p := range parsed.Parts {
|
||||||
if p.Text == "" {
|
if p.Text == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -329,26 +329,21 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
|||||||
finalize("stop")
|
finalize("stop")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
chunk, doneSignal, parsed := sse.ParseDeepSeekSSELine(line)
|
parsed := sse.ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
|
||||||
if !parsed {
|
currentType = parsed.NextType
|
||||||
|
if !parsed.Parsed {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if doneSignal {
|
if parsed.ContentFilter || parsed.ErrorMessage != "" {
|
||||||
finalize("stop")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, hasErr := chunk["error"]; hasErr || chunk["code"] == "content_filter" {
|
|
||||||
finalize("content_filter")
|
finalize("content_filter")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
if parsed.Stop {
|
||||||
currentType = newType
|
|
||||||
if finished {
|
|
||||||
finalize("stop")
|
finalize("stop")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
newChoices := make([]map[string]any, 0, len(parts))
|
newChoices := make([]map[string]any, 0, len(parsed.Parts))
|
||||||
for _, p := range parts {
|
for _, p := range parsed.Parts {
|
||||||
if searchEnabled && sse.IsCitation(p.Text) {
|
if searchEnabled && sse.IsCitation(p.Text) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -154,20 +154,9 @@ func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Concurrent testing with a semaphore to limit parallelism.
|
// Concurrent testing with a semaphore to limit parallelism.
|
||||||
const maxConcurrency = 5
|
const maxConcurrency = 5
|
||||||
sem := make(chan struct{}, maxConcurrency)
|
results := runAccountTestsConcurrently(accounts, maxConcurrency, func(_ int, account config.Account) map[string]any {
|
||||||
results := make([]map[string]any, len(accounts))
|
return h.testAccount(r.Context(), account, model, "")
|
||||||
var wg sync.WaitGroup
|
})
|
||||||
|
|
||||||
for i, acc := range accounts {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(idx int, account config.Account) {
|
|
||||||
defer wg.Done()
|
|
||||||
sem <- struct{}{} // acquire
|
|
||||||
defer func() { <-sem }() // release
|
|
||||||
results[idx] = h.testAccount(r.Context(), account, model, "")
|
|
||||||
}(i, acc)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
success := 0
|
success := 0
|
||||||
for _, res := range results {
|
for _, res := range results {
|
||||||
@@ -178,6 +167,26 @@ func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results})
|
writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func runAccountTestsConcurrently(accounts []config.Account, maxConcurrency int, testFn func(int, config.Account) map[string]any) []map[string]any {
|
||||||
|
if maxConcurrency <= 0 {
|
||||||
|
maxConcurrency = 1
|
||||||
|
}
|
||||||
|
sem := make(chan struct{}, maxConcurrency)
|
||||||
|
results := make([]map[string]any, len(accounts))
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i, acc := range accounts {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int, account config.Account) {
|
||||||
|
defer wg.Done()
|
||||||
|
sem <- struct{}{} // acquire
|
||||||
|
defer func() { <-sem }() // release
|
||||||
|
results[idx] = testFn(idx, account)
|
||||||
|
}(i, acc)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
|
func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
result := map[string]any{"account": acc.Identifier(), "success": false, "response_time": 0, "message": "", "model": model}
|
result := map[string]any{"account": acc.Identifier(), "success": false, "response_time": 0, "message": "", "model": model}
|
||||||
|
|||||||
@@ -1,6 +1,12 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
func TestToAccountMissingFieldsRemainEmpty(t *testing.T) {
|
func TestToAccountMissingFieldsRemainEmpty(t *testing.T) {
|
||||||
acc := toAccount(map[string]any{
|
acc := toAccount(map[string]any{
|
||||||
@@ -26,3 +32,62 @@ func TestFieldStringNilToEmpty(t *testing.T) {
|
|||||||
t.Fatalf("expected empty string for missing field, got %q", got)
|
t.Fatalf("expected empty string for missing field, got %q", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunAccountTestsConcurrentlyKeepsInputOrder(t *testing.T) {
|
||||||
|
accounts := []config.Account{
|
||||||
|
{Email: "a@example.com"},
|
||||||
|
{Email: "b@example.com"},
|
||||||
|
{Email: "c@example.com"},
|
||||||
|
}
|
||||||
|
results := runAccountTestsConcurrently(accounts, 2, func(idx int, acc config.Account) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"idx": idx,
|
||||||
|
"account": acc.Identifier(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if len(results) != len(accounts) {
|
||||||
|
t.Fatalf("unexpected result length: got %d want %d", len(results), len(accounts))
|
||||||
|
}
|
||||||
|
for i := range accounts {
|
||||||
|
gotIdx, _ := results[i]["idx"].(int)
|
||||||
|
if gotIdx != i {
|
||||||
|
t.Fatalf("result index mismatch at %d: got %d", i, gotIdx)
|
||||||
|
}
|
||||||
|
gotID, _ := results[i]["account"].(string)
|
||||||
|
if gotID != accounts[i].Identifier() {
|
||||||
|
t.Fatalf("result order mismatch at %d: got %q want %q", i, gotID, accounts[i].Identifier())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunAccountTestsConcurrentlyRespectsLimit(t *testing.T) {
|
||||||
|
const limit = 3
|
||||||
|
accounts := []config.Account{
|
||||||
|
{Email: "1@example.com"},
|
||||||
|
{Email: "2@example.com"},
|
||||||
|
{Email: "3@example.com"},
|
||||||
|
{Email: "4@example.com"},
|
||||||
|
{Email: "5@example.com"},
|
||||||
|
{Email: "6@example.com"},
|
||||||
|
}
|
||||||
|
var current int32
|
||||||
|
var maxSeen int32
|
||||||
|
_ = runAccountTestsConcurrently(accounts, limit, func(_ int, _ config.Account) map[string]any {
|
||||||
|
c := atomic.AddInt32(¤t, 1)
|
||||||
|
for {
|
||||||
|
m := atomic.LoadInt32(&maxSeen)
|
||||||
|
if c <= m || atomic.CompareAndSwapInt32(&maxSeen, m, c) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
atomic.AddInt32(¤t, -1)
|
||||||
|
return map[string]any{"success": true}
|
||||||
|
})
|
||||||
|
if maxSeen > limit {
|
||||||
|
t.Fatalf("concurrency exceeded limit: got %d > %d", maxSeen, limit)
|
||||||
|
}
|
||||||
|
if maxSeen < 2 {
|
||||||
|
t.Fatalf("expected concurrent execution, max seen %d", maxSeen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -286,20 +286,35 @@ func (s *Store) FindAccount(identifier string) (Account, bool) {
|
|||||||
identifier = strings.TrimSpace(identifier)
|
identifier = strings.TrimSpace(identifier)
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
if idx, ok := s.accMap[identifier]; ok && idx < len(s.cfg.Accounts) {
|
if idx, ok := s.findAccountIndexLocked(identifier); ok {
|
||||||
return s.cfg.Accounts[idx], true
|
return s.cfg.Accounts[idx], true
|
||||||
}
|
}
|
||||||
return Account{}, false
|
return Account{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) UpdateAccountToken(identifier, token string) error {
|
func (s *Store) UpdateAccountToken(identifier, token string) error {
|
||||||
|
identifier = strings.TrimSpace(identifier)
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if idx, ok := s.accMap[identifier]; ok && idx < len(s.cfg.Accounts) {
|
idx, ok := s.findAccountIndexLocked(identifier)
|
||||||
s.cfg.Accounts[idx].Token = token
|
if !ok {
|
||||||
return s.saveLocked()
|
return errors.New("account not found")
|
||||||
}
|
}
|
||||||
return errors.New("account not found")
|
oldID := s.cfg.Accounts[idx].Identifier()
|
||||||
|
s.cfg.Accounts[idx].Token = token
|
||||||
|
newID := s.cfg.Accounts[idx].Identifier()
|
||||||
|
// Keep historical aliases usable for long-lived queues while also adding
|
||||||
|
// the latest identifier after token refresh.
|
||||||
|
if identifier != "" {
|
||||||
|
s.accMap[identifier] = idx
|
||||||
|
}
|
||||||
|
if oldID != "" {
|
||||||
|
s.accMap[oldID] = idx
|
||||||
|
}
|
||||||
|
if newID != "" {
|
||||||
|
s.accMap[newID] = idx
|
||||||
|
}
|
||||||
|
return s.saveLocked()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) Replace(cfg Config) error {
|
func (s *Store) Replace(cfg Config) error {
|
||||||
@@ -348,6 +363,21 @@ func (s *Store) saveLocked() error {
|
|||||||
return os.WriteFile(s.path, b, 0o644)
|
return os.WriteFile(s.path, b, 0o644)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// findAccountIndexLocked expects the store lock to already be held.
|
||||||
|
func (s *Store) findAccountIndexLocked(identifier string) (int, bool) {
|
||||||
|
if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) {
|
||||||
|
return idx, true
|
||||||
|
}
|
||||||
|
// Fallback for token-only accounts whose derived identifier changed after
|
||||||
|
// a token refresh; this preserves correctness on map misses.
|
||||||
|
for i, acc := range s.cfg.Accounts {
|
||||||
|
if acc.Identifier() == identifier {
|
||||||
|
return i, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1, false
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) IsEnvBacked() bool {
|
func (s *Store) IsEnvBacked() bool {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|||||||
@@ -39,3 +39,34 @@ func TestStoreFindAccountWithTokenOnlyIdentifier(t *testing.T) {
|
|||||||
t.Fatalf("unexpected token value: %q", found.Token)
|
t.Fatalf("unexpected token value: %q", found.Token)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T) {
|
||||||
|
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||||
|
"accounts":[{"token":"old-token"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
store := LoadStore()
|
||||||
|
before := store.Accounts()
|
||||||
|
if len(before) != 1 {
|
||||||
|
t.Fatalf("expected 1 account, got %d", len(before))
|
||||||
|
}
|
||||||
|
oldID := before[0].Identifier()
|
||||||
|
if oldID == "" {
|
||||||
|
t.Fatal("expected old identifier")
|
||||||
|
}
|
||||||
|
if err := store.UpdateAccountToken(oldID, "new-token"); err != nil {
|
||||||
|
t.Fatalf("update token failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
after := store.Accounts()
|
||||||
|
newID := after[0].Identifier()
|
||||||
|
if newID == "" || newID == oldID {
|
||||||
|
t.Fatalf("expected changed identifier, old=%q new=%q", oldID, newID)
|
||||||
|
}
|
||||||
|
if got, ok := store.FindAccount(newID); !ok || got.Token != "new-token" {
|
||||||
|
t.Fatalf("expected find by new identifier")
|
||||||
|
}
|
||||||
|
if got, ok := store.FindAccount(oldID); !ok || got.Token != "new-token" {
|
||||||
|
t.Fatalf("expected find by old identifier alias")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -90,18 +90,7 @@ func timeout(d time.Duration) func(http.Handler) http.Handler {
|
|||||||
|
|
||||||
func cors(next http.Handler) http.Handler {
|
func cors(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
origin := r.Header.Get("Origin")
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
if origin != "" {
|
|
||||||
// Dynamically reflect the request origin to allow credentials.
|
|
||||||
// Using "*" with Access-Control-Allow-Credentials: true is
|
|
||||||
// invalid per the CORS spec and will be rejected by browsers.
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
w.Header().Set("Vary", "Origin")
|
|
||||||
} else {
|
|
||||||
// No Origin header (e.g. server-to-server requests); allow all.
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
||||||
}
|
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
|
|||||||
@@ -31,22 +31,15 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
|||||||
currentType = "thinking"
|
currentType = "thinking"
|
||||||
}
|
}
|
||||||
_ = deepseek.ScanSSELines(resp, func(line []byte) bool {
|
_ = deepseek.ScanSSELines(resp, func(line []byte) bool {
|
||||||
chunk, done, ok := ParseDeepSeekSSELine(line)
|
result := ParseDeepSeekContentLine(line, thinkingEnabled, currentType)
|
||||||
if !ok {
|
currentType = result.NextType
|
||||||
|
if !result.Parsed {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if done {
|
if result.Stop {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if _, hasErr := chunk["error"]; hasErr {
|
for _, p := range result.Parts {
|
||||||
return false
|
|
||||||
}
|
|
||||||
parts, finished, newType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
|
||||||
currentType = newType
|
|
||||||
if finished {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, p := range parts {
|
|
||||||
if p.Type == "thinking" {
|
if p.Type == "thinking" {
|
||||||
thinking.WriteString(p.Text)
|
thinking.WriteString(p.Text)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
49
internal/sse/line.go
Normal file
49
internal/sse/line.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package sse
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// LineResult is the normalized parse result for one DeepSeek SSE line.
|
||||||
|
type LineResult struct {
|
||||||
|
Parsed bool
|
||||||
|
Stop bool
|
||||||
|
ContentFilter bool
|
||||||
|
ErrorMessage string
|
||||||
|
Parts []ContentPart
|
||||||
|
NextType string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseDeepSeekContentLine centralizes one-line DeepSeek SSE parsing for both
|
||||||
|
// streaming and non-streaming handlers.
|
||||||
|
func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType string) LineResult {
|
||||||
|
chunk, done, parsed := ParseDeepSeekSSELine(raw)
|
||||||
|
if !parsed {
|
||||||
|
return LineResult{NextType: currentType}
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
return LineResult{Parsed: true, Stop: true, NextType: currentType}
|
||||||
|
}
|
||||||
|
if errObj, hasErr := chunk["error"]; hasErr {
|
||||||
|
return LineResult{
|
||||||
|
Parsed: true,
|
||||||
|
Stop: true,
|
||||||
|
ErrorMessage: fmt.Sprintf("%v", errObj),
|
||||||
|
NextType: currentType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if code, _ := chunk["code"].(string); code == "content_filter" {
|
||||||
|
return LineResult{
|
||||||
|
Parsed: true,
|
||||||
|
Stop: true,
|
||||||
|
ContentFilter: true,
|
||||||
|
ErrorMessage: "content filtered by upstream",
|
||||||
|
NextType: currentType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parts, finished, nextType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
||||||
|
return LineResult{
|
||||||
|
Parsed: true,
|
||||||
|
Stop: finished,
|
||||||
|
Parts: parts,
|
||||||
|
NextType: nextType,
|
||||||
|
}
|
||||||
|
}
|
||||||
37
internal/sse/line_test.go
Normal file
37
internal/sse/line_test.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package sse
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestParseDeepSeekContentLineDone(t *testing.T) {
|
||||||
|
res := ParseDeepSeekContentLine([]byte("data: [DONE]"), false, "text")
|
||||||
|
if !res.Parsed || !res.Stop {
|
||||||
|
t.Fatalf("expected parsed stop result: %#v", res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDeepSeekContentLineError(t *testing.T) {
|
||||||
|
res := ParseDeepSeekContentLine([]byte(`data: {"error":"boom"}`), false, "text")
|
||||||
|
if !res.Parsed || !res.Stop {
|
||||||
|
t.Fatalf("expected stop on error: %#v", res)
|
||||||
|
}
|
||||||
|
if res.ErrorMessage == "" {
|
||||||
|
t.Fatalf("expected non-empty error message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDeepSeekContentLineContentFilter(t *testing.T) {
|
||||||
|
res := ParseDeepSeekContentLine([]byte(`data: {"code":"content_filter"}`), false, "text")
|
||||||
|
if !res.Parsed || !res.Stop || !res.ContentFilter {
|
||||||
|
t.Fatalf("expected content-filter stop result: %#v", res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDeepSeekContentLineContent(t *testing.T) {
|
||||||
|
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"hi"}`), false, "text")
|
||||||
|
if !res.Parsed || res.Stop {
|
||||||
|
t.Fatalf("expected parsed non-stop result: %#v", res)
|
||||||
|
}
|
||||||
|
if len(res.Parts) != 1 || res.Parts[0].Text != "hi" || res.Parts[0].Type != "text" {
|
||||||
|
t.Fatalf("unexpected parts: %#v", res.Parts)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user