feat: Implement a waiting queue for account acquisition with configurable limits and updated status reporting.

This commit is contained in:
CJACK
2026-02-16 20:30:21 +08:00
parent 888a0e6bff
commit a6a87853d4
8 changed files with 265 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
package account
import (
"context"
"os"
"sort"
"strconv"
@@ -11,12 +12,14 @@ import (
)
type Pool struct {
store *config.Store
mu sync.Mutex
queue []string
inUse map[string]int
maxInflightPerAccount int
store *config.Store
mu sync.Mutex
queue []string
inUse map[string]int
waiters []chan struct{}
maxInflightPerAccount int
recommendedConcurrency int
maxQueueSize int
}
func NewPool(store *config.Store) *Pool {
@@ -47,25 +50,64 @@ func (p *Pool) Reset() {
}
}
recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
queueLimit := maxQueueFromEnv(recommended)
p.mu.Lock()
defer p.mu.Unlock()
p.drainWaitersLocked()
p.queue = ids
p.inUse = map[string]int{}
p.recommendedConcurrency = recommended
p.maxQueueSize = queueLimit
config.Logger.Info(
"[init_account_queue] initialized",
"total", len(ids),
"max_inflight_per_account", p.maxInflightPerAccount,
"recommended_concurrency", p.recommendedConcurrency,
"max_queue_size", p.maxQueueSize,
)
}
func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
p.mu.Lock()
defer p.mu.Unlock()
if exclude == nil {
exclude = map[string]bool{}
return p.acquireLocked(target, normalizeExclude(exclude))
}
func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) {
if ctx == nil {
ctx = context.Background()
}
exclude = normalizeExclude(exclude)
for {
if ctx.Err() != nil {
return config.Account{}, false
}
p.mu.Lock()
if acc, ok := p.acquireLocked(target, exclude); ok {
p.mu.Unlock()
return acc, true
}
if !p.canQueueLocked(target, exclude) {
p.mu.Unlock()
return config.Account{}, false
}
waiter := make(chan struct{})
p.waiters = append(p.waiters, waiter)
p.mu.Unlock()
select {
case <-ctx.Done():
p.mu.Lock()
p.removeWaiterLocked(waiter)
p.mu.Unlock()
return config.Account{}, false
case <-waiter:
}
}
}
func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
if target != "" {
if exclude[target] || p.inUse[target] >= p.maxInflightPerAccount {
return config.Account{}, false
@@ -131,9 +173,11 @@ func (p *Pool) Release(accountID string) {
}
if count == 1 {
delete(p.inUse, accountID)
p.notifyWaiterLocked()
return
}
p.inUse[accountID] = count - 1
p.notifyWaiterLocked()
}
func (p *Pool) Status() map[string]any {
@@ -162,6 +206,8 @@ func (p *Pool) Status() map[string]any {
"in_use_accounts": inUseAccounts,
"max_inflight_per_account": p.maxInflightPerAccount,
"recommended_concurrency": p.recommendedConcurrency,
"waiting": len(p.waiters),
"max_queue_size": p.maxQueueSize,
}
}
@@ -188,3 +234,69 @@ func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int
}
return accountCount * maxInflightPerAccount
}
func normalizeExclude(exclude map[string]bool) map[string]bool {
if exclude == nil {
return map[string]bool{}
}
return exclude
}
func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool {
if target != "" {
if exclude[target] {
return false
}
if _, ok := p.store.FindAccount(target); !ok {
return false
}
}
if p.maxQueueSize <= 0 {
return false
}
return len(p.waiters) < p.maxQueueSize
}
func (p *Pool) notifyWaiterLocked() {
if len(p.waiters) == 0 {
return
}
waiter := p.waiters[0]
p.waiters = p.waiters[1:]
close(waiter)
}
func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool {
for i, w := range p.waiters {
if w != waiter {
continue
}
p.waiters = append(p.waiters[:i], p.waiters[i+1:]...)
return true
}
return false
}
func (p *Pool) drainWaitersLocked() {
for _, waiter := range p.waiters {
close(waiter)
}
p.waiters = nil
}
func maxQueueFromEnv(defaultSize int) int {
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
continue
}
n, err := strconv.Atoi(raw)
if err == nil && n >= 0 {
return n
}
}
if defaultSize < 0 {
return 0
}
return defaultSize
}

View File

@@ -1,8 +1,10 @@
package account
import (
"context"
"sync"
"testing"
"time"
"ds2api/internal/config"
)
@@ -10,6 +12,9 @@ import (
func newPoolForTest(t *testing.T, maxInflight string) *Pool {
t.Helper()
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", maxInflight)
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["k1"],
"accounts":[
@@ -21,6 +26,33 @@ func newPoolForTest(t *testing.T, maxInflight string) *Pool {
return NewPool(store)
}
func newSingleAccountPoolForTest(t *testing.T, maxInflight string) *Pool {
t.Helper()
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", maxInflight)
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["k1"],
"accounts":[{"email":"acc1@example.com","token":"token1"}]
}`)
return NewPool(config.LoadStore())
}
func waitForWaitingCount(t *testing.T, pool *Pool, want int) {
t.Helper()
deadline := time.Now().Add(800 * time.Millisecond)
for time.Now().Before(deadline) {
status := pool.Status()
if got, ok := status["waiting"].(int); ok && got == want {
return
}
time.Sleep(10 * time.Millisecond)
}
status := pool.Status()
t.Fatalf("waiting count did not reach %d, current status=%v", want, status)
}
func TestPoolRoundRobinWithConcurrentSlots(t *testing.T) {
pool := newPoolForTest(t, "2")
@@ -118,6 +150,9 @@ func TestPoolStatusRecommendedConcurrencyDefault(t *testing.T) {
if got, ok := status["recommended_concurrency"].(int); !ok || got != 4 {
t.Fatalf("unexpected recommended_concurrency: %#v", status["recommended_concurrency"])
}
if got, ok := status["max_queue_size"].(int); !ok || got != 4 {
t.Fatalf("unexpected max_queue_size: %#v", status["max_queue_size"])
}
}
func TestPoolStatusRecommendedConcurrencyRespectsOverride(t *testing.T) {
@@ -130,6 +165,9 @@ func TestPoolStatusRecommendedConcurrencyRespectsOverride(t *testing.T) {
if got, ok := status["recommended_concurrency"].(int); !ok || got != 6 {
t.Fatalf("unexpected recommended_concurrency: %#v", status["recommended_concurrency"])
}
if got, ok := status["max_queue_size"].(int); !ok || got != 6 {
t.Fatalf("unexpected max_queue_size: %#v", status["max_queue_size"])
}
}
func TestPoolAccountConcurrencyAliasEnv(t *testing.T) {
@@ -151,6 +189,9 @@ func TestPoolAccountConcurrencyAliasEnv(t *testing.T) {
if got, ok := status["recommended_concurrency"].(int); !ok || got != 8 {
t.Fatalf("unexpected recommended_concurrency: %#v", status["recommended_concurrency"])
}
if got, ok := status["max_queue_size"].(int); !ok || got != 8 {
t.Fatalf("unexpected max_queue_size: %#v", status["max_queue_size"])
}
}
func TestPoolSupportsTokenOnlyAccount(t *testing.T) {
@@ -177,3 +218,79 @@ func TestPoolSupportsTokenOnlyAccount(t *testing.T) {
t.Fatalf("unexpected token on acquired account: %q", acc.Token)
}
}
func TestPoolAcquireWaitQueuesAndSucceedsAfterRelease(t *testing.T) {
pool := newSingleAccountPoolForTest(t, "1")
first, ok := pool.Acquire("", nil)
if !ok {
t.Fatal("expected first acquire to succeed")
}
type result struct {
id string
ok bool
}
resCh := make(chan result, 1)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
go func() {
acc, ok := pool.AcquireWait(ctx, "", nil)
resCh <- result{id: acc.Identifier(), ok: ok}
}()
waitForWaitingCount(t, pool, 1)
pool.Release(first.Identifier())
select {
case res := <-resCh:
if !res.ok {
t.Fatal("expected queued acquire to succeed after release")
}
if res.id != "acc1@example.com" {
t.Fatalf("unexpected account id from queued acquire: %q", res.id)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for queued acquire result")
}
}
func TestPoolAcquireWaitQueueLimitReturnsFalse(t *testing.T) {
pool := newSingleAccountPoolForTest(t, "1")
first, ok := pool.Acquire("", nil)
if !ok {
t.Fatal("expected first acquire to succeed")
}
type result struct {
id string
ok bool
}
firstWaiter := make(chan result, 1)
ctx1, cancel1 := context.WithTimeout(context.Background(), 1200*time.Millisecond)
defer cancel1()
go func() {
acc, ok := pool.AcquireWait(ctx1, "", nil)
firstWaiter <- result{id: acc.Identifier(), ok: ok}
}()
waitForWaitingCount(t, pool, 1)
ctx2, cancel2 := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel2()
start := time.Now()
if _, ok := pool.AcquireWait(ctx2, "", nil); ok {
t.Fatal("expected second queued acquire to fail when queue is full")
}
if time.Since(start) > 120*time.Millisecond {
t.Fatalf("queue-full acquire should fail fast, took %s", time.Since(start))
}
pool.Release(first.Identifier())
select {
case res := <-firstWaiter:
if !res.ok {
t.Fatal("expected first queued acquire to succeed after release")
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for first queued acquire")
}
}

View File

@@ -50,7 +50,7 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil
}
target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account"))
acc, ok := r.Pool.Acquire(target, nil)
acc, ok := r.Pool.AcquireWait(ctx, target, nil)
if !ok {
return nil, ErrNoAccount
}