mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-16 06:05:07 +08:00
feat: Implement a waiting queue for account acquisition with configurable limits and updated status reporting.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -11,12 +12,14 @@ import (
|
||||
)
|
||||
|
||||
type Pool struct {
|
||||
store *config.Store
|
||||
mu sync.Mutex
|
||||
queue []string
|
||||
inUse map[string]int
|
||||
maxInflightPerAccount int
|
||||
store *config.Store
|
||||
mu sync.Mutex
|
||||
queue []string
|
||||
inUse map[string]int
|
||||
waiters []chan struct{}
|
||||
maxInflightPerAccount int
|
||||
recommendedConcurrency int
|
||||
maxQueueSize int
|
||||
}
|
||||
|
||||
func NewPool(store *config.Store) *Pool {
|
||||
@@ -47,25 +50,64 @@ func (p *Pool) Reset() {
|
||||
}
|
||||
}
|
||||
recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
|
||||
queueLimit := maxQueueFromEnv(recommended)
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.drainWaitersLocked()
|
||||
p.queue = ids
|
||||
p.inUse = map[string]int{}
|
||||
p.recommendedConcurrency = recommended
|
||||
p.maxQueueSize = queueLimit
|
||||
config.Logger.Info(
|
||||
"[init_account_queue] initialized",
|
||||
"total", len(ids),
|
||||
"max_inflight_per_account", p.maxInflightPerAccount,
|
||||
"recommended_concurrency", p.recommendedConcurrency,
|
||||
"max_queue_size", p.maxQueueSize,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if exclude == nil {
|
||||
exclude = map[string]bool{}
|
||||
return p.acquireLocked(target, normalizeExclude(exclude))
|
||||
}
|
||||
|
||||
func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
exclude = normalizeExclude(exclude)
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return config.Account{}, false
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
if acc, ok := p.acquireLocked(target, exclude); ok {
|
||||
p.mu.Unlock()
|
||||
return acc, true
|
||||
}
|
||||
if !p.canQueueLocked(target, exclude) {
|
||||
p.mu.Unlock()
|
||||
return config.Account{}, false
|
||||
}
|
||||
waiter := make(chan struct{})
|
||||
p.waiters = append(p.waiters, waiter)
|
||||
p.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
p.mu.Lock()
|
||||
p.removeWaiterLocked(waiter)
|
||||
p.mu.Unlock()
|
||||
return config.Account{}, false
|
||||
case <-waiter:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
|
||||
if target != "" {
|
||||
if exclude[target] || p.inUse[target] >= p.maxInflightPerAccount {
|
||||
return config.Account{}, false
|
||||
@@ -131,9 +173,11 @@ func (p *Pool) Release(accountID string) {
|
||||
}
|
||||
if count == 1 {
|
||||
delete(p.inUse, accountID)
|
||||
p.notifyWaiterLocked()
|
||||
return
|
||||
}
|
||||
p.inUse[accountID] = count - 1
|
||||
p.notifyWaiterLocked()
|
||||
}
|
||||
|
||||
func (p *Pool) Status() map[string]any {
|
||||
@@ -162,6 +206,8 @@ func (p *Pool) Status() map[string]any {
|
||||
"in_use_accounts": inUseAccounts,
|
||||
"max_inflight_per_account": p.maxInflightPerAccount,
|
||||
"recommended_concurrency": p.recommendedConcurrency,
|
||||
"waiting": len(p.waiters),
|
||||
"max_queue_size": p.maxQueueSize,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,3 +234,69 @@ func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int
|
||||
}
|
||||
return accountCount * maxInflightPerAccount
|
||||
}
|
||||
|
||||
func normalizeExclude(exclude map[string]bool) map[string]bool {
|
||||
if exclude == nil {
|
||||
return map[string]bool{}
|
||||
}
|
||||
return exclude
|
||||
}
|
||||
|
||||
func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool {
|
||||
if target != "" {
|
||||
if exclude[target] {
|
||||
return false
|
||||
}
|
||||
if _, ok := p.store.FindAccount(target); !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if p.maxQueueSize <= 0 {
|
||||
return false
|
||||
}
|
||||
return len(p.waiters) < p.maxQueueSize
|
||||
}
|
||||
|
||||
func (p *Pool) notifyWaiterLocked() {
|
||||
if len(p.waiters) == 0 {
|
||||
return
|
||||
}
|
||||
waiter := p.waiters[0]
|
||||
p.waiters = p.waiters[1:]
|
||||
close(waiter)
|
||||
}
|
||||
|
||||
func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool {
|
||||
for i, w := range p.waiters {
|
||||
if w != waiter {
|
||||
continue
|
||||
}
|
||||
p.waiters = append(p.waiters[:i], p.waiters[i+1:]...)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Pool) drainWaitersLocked() {
|
||||
for _, waiter := range p.waiters {
|
||||
close(waiter)
|
||||
}
|
||||
p.waiters = nil
|
||||
}
|
||||
|
||||
func maxQueueFromEnv(defaultSize int) int {
|
||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n >= 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
if defaultSize < 0 {
|
||||
return 0
|
||||
}
|
||||
return defaultSize
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
@@ -10,6 +12,9 @@ import (
|
||||
func newPoolForTest(t *testing.T, maxInflight string) *Pool {
|
||||
t.Helper()
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", maxInflight)
|
||||
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
|
||||
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
"accounts":[
|
||||
@@ -21,6 +26,33 @@ func newPoolForTest(t *testing.T, maxInflight string) *Pool {
|
||||
return NewPool(store)
|
||||
}
|
||||
|
||||
func newSingleAccountPoolForTest(t *testing.T, maxInflight string) *Pool {
|
||||
t.Helper()
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", maxInflight)
|
||||
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
|
||||
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
"accounts":[{"email":"acc1@example.com","token":"token1"}]
|
||||
}`)
|
||||
return NewPool(config.LoadStore())
|
||||
}
|
||||
|
||||
func waitForWaitingCount(t *testing.T, pool *Pool, want int) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(800 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
status := pool.Status()
|
||||
if got, ok := status["waiting"].(int); ok && got == want {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
status := pool.Status()
|
||||
t.Fatalf("waiting count did not reach %d, current status=%v", want, status)
|
||||
}
|
||||
|
||||
func TestPoolRoundRobinWithConcurrentSlots(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
|
||||
@@ -118,6 +150,9 @@ func TestPoolStatusRecommendedConcurrencyDefault(t *testing.T) {
|
||||
if got, ok := status["recommended_concurrency"].(int); !ok || got != 4 {
|
||||
t.Fatalf("unexpected recommended_concurrency: %#v", status["recommended_concurrency"])
|
||||
}
|
||||
if got, ok := status["max_queue_size"].(int); !ok || got != 4 {
|
||||
t.Fatalf("unexpected max_queue_size: %#v", status["max_queue_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolStatusRecommendedConcurrencyRespectsOverride(t *testing.T) {
|
||||
@@ -130,6 +165,9 @@ func TestPoolStatusRecommendedConcurrencyRespectsOverride(t *testing.T) {
|
||||
if got, ok := status["recommended_concurrency"].(int); !ok || got != 6 {
|
||||
t.Fatalf("unexpected recommended_concurrency: %#v", status["recommended_concurrency"])
|
||||
}
|
||||
if got, ok := status["max_queue_size"].(int); !ok || got != 6 {
|
||||
t.Fatalf("unexpected max_queue_size: %#v", status["max_queue_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolAccountConcurrencyAliasEnv(t *testing.T) {
|
||||
@@ -151,6 +189,9 @@ func TestPoolAccountConcurrencyAliasEnv(t *testing.T) {
|
||||
if got, ok := status["recommended_concurrency"].(int); !ok || got != 8 {
|
||||
t.Fatalf("unexpected recommended_concurrency: %#v", status["recommended_concurrency"])
|
||||
}
|
||||
if got, ok := status["max_queue_size"].(int); !ok || got != 8 {
|
||||
t.Fatalf("unexpected max_queue_size: %#v", status["max_queue_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolSupportsTokenOnlyAccount(t *testing.T) {
|
||||
@@ -177,3 +218,79 @@ func TestPoolSupportsTokenOnlyAccount(t *testing.T) {
|
||||
t.Fatalf("unexpected token on acquired account: %q", acc.Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolAcquireWaitQueuesAndSucceedsAfterRelease(t *testing.T) {
|
||||
pool := newSingleAccountPoolForTest(t, "1")
|
||||
first, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatal("expected first acquire to succeed")
|
||||
}
|
||||
|
||||
type result struct {
|
||||
id string
|
||||
ok bool
|
||||
}
|
||||
resCh := make(chan result, 1)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
go func() {
|
||||
acc, ok := pool.AcquireWait(ctx, "", nil)
|
||||
resCh <- result{id: acc.Identifier(), ok: ok}
|
||||
}()
|
||||
|
||||
waitForWaitingCount(t, pool, 1)
|
||||
pool.Release(first.Identifier())
|
||||
|
||||
select {
|
||||
case res := <-resCh:
|
||||
if !res.ok {
|
||||
t.Fatal("expected queued acquire to succeed after release")
|
||||
}
|
||||
if res.id != "acc1@example.com" {
|
||||
t.Fatalf("unexpected account id from queued acquire: %q", res.id)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for queued acquire result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolAcquireWaitQueueLimitReturnsFalse(t *testing.T) {
|
||||
pool := newSingleAccountPoolForTest(t, "1")
|
||||
first, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatal("expected first acquire to succeed")
|
||||
}
|
||||
|
||||
type result struct {
|
||||
id string
|
||||
ok bool
|
||||
}
|
||||
firstWaiter := make(chan result, 1)
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 1200*time.Millisecond)
|
||||
defer cancel1()
|
||||
go func() {
|
||||
acc, ok := pool.AcquireWait(ctx1, "", nil)
|
||||
firstWaiter <- result{id: acc.Identifier(), ok: ok}
|
||||
}()
|
||||
waitForWaitingCount(t, pool, 1)
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel2()
|
||||
start := time.Now()
|
||||
if _, ok := pool.AcquireWait(ctx2, "", nil); ok {
|
||||
t.Fatal("expected second queued acquire to fail when queue is full")
|
||||
}
|
||||
if time.Since(start) > 120*time.Millisecond {
|
||||
t.Fatalf("queue-full acquire should fail fast, took %s", time.Since(start))
|
||||
}
|
||||
|
||||
pool.Release(first.Identifier())
|
||||
select {
|
||||
case res := <-firstWaiter:
|
||||
if !res.ok {
|
||||
t.Fatal("expected first queued acquire to succeed after release")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for first queued acquire")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil
|
||||
}
|
||||
target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account"))
|
||||
acc, ok := r.Pool.Acquire(target, nil)
|
||||
acc, ok := r.Pool.AcquireWait(ctx, target, nil)
|
||||
if !ok {
|
||||
return nil, ErrNoAccount
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user