mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-16 22:25:15 +08:00
feat: Implement DeepSeek integration, refactor model adapters for streaming and tool calls, enhance admin and account management, and introduce new UI features for settings, API testing, and Vercel sync.
This commit is contained in:
@@ -1,363 +0,0 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// Pool hands out upstream accounts for requests, bounding concurrency
// per account and globally, and parking callers (AcquireWait) when all
// slots are busy.
type Pool struct {
	store                  *config.Store   // backing account/config storage
	mu                     sync.Mutex      // guards every field below
	queue                  []string        // account ids in rotation order (least recently used first)
	inUse                  map[string]int  // account id -> number of inflight requests
	waiters                []chan struct{} // FIFO of parked AcquireWait callers
	maxInflightPerAccount  int             // per-account concurrency cap
	recommendedConcurrency int             // suggested total concurrency (accounts * per-account cap)
	maxQueueSize           int             // cap on parked waiters (0 disables queueing)
	globalMaxInflight      int             // total inflight cap across accounts (0 disables the check)
}

// NewPool builds a Pool over store and seeds it via Reset. The
// per-account cap defaults to 2 when store is nil, otherwise it comes
// from the store's runtime setting.
// NOTE(review): Reset calls p.store.Accounts() unconditionally, so a
// nil store would panic there — confirm callers never pass nil.
func NewPool(store *config.Store) *Pool {
	maxPer := 2
	if store != nil {
		maxPer = store.RuntimeAccountMaxInflight()
	}
	p := &Pool{
		store:                 store,
		inUse:                 map[string]int{},
		maxInflightPerAccount: maxPer,
	}
	p.Reset()
	return p
}
|
||||
|
||||
// Reset rebuilds the rotation queue from the store's current accounts,
// refreshes all runtime limits, and wakes every parked waiter so they
// re-evaluate against the new state. Accounts holding a token are
// ordered ahead of token-less ones; store order is kept within a group.
// NOTE(review): sort.SliceStable mutates the slice returned by
// p.store.Accounts() — confirm Accounts returns a copy.
// NOTE(review): p.maxInflightPerAccount is written below before p.mu is
// taken, which races with concurrent acquires — verify Reset is only
// called when that is acceptable.
func (p *Pool) Reset() {
	accounts := p.store.Accounts()
	sort.SliceStable(accounts, func(i, j int) bool {
		iHas := accounts[i].Token != ""
		jHas := accounts[j].Token != ""
		if iHas == jHas {
			return i < j // stable: preserve original ordering inside a group
		}
		return iHas // tokened accounts first
	})
	ids := make([]string, 0, len(accounts))
	for _, a := range accounts {
		id := a.Identifier()
		if id != "" {
			ids = append(ids, id)
		}
	}
	// Limits come from the store when available, otherwise from env vars.
	if p.store != nil {
		p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight()
	} else {
		p.maxInflightPerAccount = maxInflightFromEnv()
	}
	recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
	queueLimit := maxQueueFromEnv(recommended)
	globalLimit := recommended
	if p.store != nil {
		queueLimit = p.store.RuntimeAccountMaxQueue(recommended)
		globalLimit = p.store.RuntimeGlobalMaxInflight(recommended)
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	p.drainWaitersLocked() // wake everyone; inflight bookkeeping restarts below
	p.queue = ids
	p.inUse = map[string]int{}
	p.recommendedConcurrency = recommended
	p.maxQueueSize = queueLimit
	p.globalMaxInflight = globalLimit
	config.Logger.Info(
		"[init_account_queue] initialized",
		"total", len(ids),
		"max_inflight_per_account", p.maxInflightPerAccount,
		"global_max_inflight", p.globalMaxInflight,
		"recommended_concurrency", p.recommendedConcurrency,
		"max_queue_size", p.maxQueueSize,
	)
}
|
||||
|
||||
// Acquire makes one non-blocking attempt to reserve a request slot.
// target selects a specific account id ("" lets the pool rotate);
// exclude lists ids that must not be chosen. The second return is false
// when nothing is available right now.
func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.acquireLocked(target, normalizeExclude(exclude))
}

// AcquireWait behaves like Acquire but parks the caller until a slot
// frees up, ctx is cancelled, or queueing is impossible (waiter queue
// full, or an explicit target that is excluded/unknown) — the latter
// two both report a false result.
func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) {
	if ctx == nil {
		ctx = context.Background() // tolerate nil contexts from callers
	}
	exclude = normalizeExclude(exclude)
	for {
		if ctx.Err() != nil {
			return config.Account{}, false // cancelled or timed out
		}

		p.mu.Lock()
		if acc, ok := p.acquireLocked(target, exclude); ok {
			p.mu.Unlock()
			return acc, true
		}
		// Nothing free: only park if the waiter queue has room and the
		// target (when set) could ever be served.
		if !p.canQueueLocked(target, exclude) {
			p.mu.Unlock()
			return config.Account{}, false
		}
		waiter := make(chan struct{})
		p.waiters = append(p.waiters, waiter)
		p.mu.Unlock()

		select {
		case <-ctx.Done():
			// Remove ourselves so a later Release doesn't wake a dead waiter.
			p.mu.Lock()
			p.removeWaiterLocked(waiter)
			p.mu.Unlock()
			return config.Account{}, false
		case <-waiter:
			// Woken by Release/Reset; loop and retry the acquire.
		}
	}
}
|
||||
|
||||
// acquireLocked reserves a slot while p.mu is held. With a target id it
// only considers that account; otherwise it scans the rotation queue,
// preferring accounts that already have a token over token-less ones.
func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
	if target != "" {
		if exclude[target] || !p.canAcquireIDLocked(target) {
			return config.Account{}, false
		}
		acc, ok := p.store.FindAccount(target)
		if !ok {
			return config.Account{}, false
		}
		p.inUse[target]++
		p.bumpQueue(target) // rotate to the back so others get tried first next time
		return acc, true
	}

	// First pass: accounts that already hold a token.
	if acc, ok := p.tryAcquire(exclude, true); ok {
		return acc, true
	}
	// Second pass: any remaining account.
	if acc, ok := p.tryAcquire(exclude, false); ok {
		return acc, true
	}
	return config.Account{}, false
}

// tryAcquire scans the rotation queue for the first usable account,
// optionally requiring a non-empty token. Caller must hold p.mu.
func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) {
	for i := 0; i < len(p.queue); i++ {
		id := p.queue[i]
		if exclude[id] || !p.canAcquireIDLocked(id) {
			continue
		}
		acc, ok := p.store.FindAccount(id)
		if !ok {
			continue // stale id; account no longer in the store
		}
		if requireToken && acc.Token == "" {
			continue
		}
		p.inUse[id]++
		p.bumpQueue(id) // least-recently-used rotation
		return acc, true
	}
	return config.Account{}, false
}

// bumpQueue moves accountID to the tail of the rotation queue so the
// next scan tries other accounts first. No-op when the id is absent.
// Caller must hold p.mu.
func (p *Pool) bumpQueue(accountID string) {
	for i, id := range p.queue {
		if id != accountID {
			continue
		}
		p.queue = append(p.queue[:i], p.queue[i+1:]...)
		p.queue = append(p.queue, accountID)
		return
	}
}
|
||||
|
||||
// Release returns one inflight slot for accountID and wakes the oldest
// parked waiter, if any. Empty, unknown, or already-zero ids are
// ignored (tolerates double release).
func (p *Pool) Release(accountID string) {
	if accountID == "" {
		return
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	count := p.inUse[accountID]
	if count <= 0 {
		return // never acquired or already fully released; nothing to do
	}
	if count == 1 {
		delete(p.inUse, accountID)
		p.notifyWaiterLocked()
		return
	}
	p.inUse[accountID] = count - 1
	p.notifyWaiterLocked()
}

// Status reports a point-in-time snapshot for admin/monitoring:
// counts, configured limits, and account id lists. "available" means
// below the per-account cap; the global cap is not considered here.
func (p *Pool) Status() map[string]any {
	p.mu.Lock()
	defer p.mu.Unlock()
	available := make([]string, 0, len(p.queue))
	inUseAccounts := make([]string, 0, len(p.inUse))
	inUseSlots := 0
	for _, id := range p.queue {
		if p.inUse[id] < p.maxInflightPerAccount {
			available = append(available, id)
		}
	}
	for id, count := range p.inUse {
		if count > 0 {
			inUseAccounts = append(inUseAccounts, id)
			inUseSlots += count
		}
	}
	sort.Strings(inUseAccounts) // deterministic output (map iteration is random)
	return map[string]any{
		"available":                len(available),
		"in_use":                   inUseSlots,
		"total":                    len(p.store.Accounts()),
		"available_accounts":       available,
		"in_use_accounts":          inUseAccounts,
		"max_inflight_per_account": p.maxInflightPerAccount,
		"global_max_inflight":      p.globalMaxInflight,
		"recommended_concurrency":  p.recommendedConcurrency,
		"waiting":                  len(p.waiters),
		"max_queue_size":           p.maxQueueSize,
	}
}
|
||||
|
||||
// ApplyRuntimeLimits updates the pool's limits at runtime, sanitizing
// inputs: the per-account cap is at least 1, a negative queue size
// becomes 0, and a non-positive global cap is derived as per-account
// cap times account count (falling back to the per-account cap when
// there are no accounts). One parked waiter is woken so raised limits
// take effect promptly.
// NOTE(review): only a single waiter is notified even when the limits
// grew by several slots — confirm that is intentional.
func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) {
	if maxInflightPerAccount <= 0 {
		maxInflightPerAccount = 1
	}
	if maxQueueSize < 0 {
		maxQueueSize = 0
	}
	if globalMaxInflight <= 0 {
		globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts())
		if globalMaxInflight <= 0 {
			globalMaxInflight = maxInflightPerAccount
		}
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	p.maxInflightPerAccount = maxInflightPerAccount
	p.maxQueueSize = maxQueueSize
	p.globalMaxInflight = globalMaxInflight
	p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount)
	p.notifyWaiterLocked()
}
|
||||
|
||||
func maxInflightFromEnv() int {
|
||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int {
|
||||
if accountCount <= 0 {
|
||||
return 0
|
||||
}
|
||||
if maxInflightPerAccount <= 0 {
|
||||
maxInflightPerAccount = 2
|
||||
}
|
||||
return accountCount * maxInflightPerAccount
|
||||
}
|
||||
|
||||
func normalizeExclude(exclude map[string]bool) map[string]bool {
|
||||
if exclude == nil {
|
||||
return map[string]bool{}
|
||||
}
|
||||
return exclude
|
||||
}
|
||||
|
||||
// canQueueLocked decides whether an AcquireWait caller may park: the
// waiter queue must have room, and an explicit target must be neither
// excluded nor unknown to the store. Caller must hold p.mu.
func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool {
	if target != "" {
		if exclude[target] {
			return false
		}
		if _, ok := p.store.FindAccount(target); !ok {
			return false
		}
	}
	if p.maxQueueSize <= 0 {
		return false // queueing disabled
	}
	return len(p.waiters) < p.maxQueueSize
}

// notifyWaiterLocked wakes the oldest parked waiter (FIFO) by closing
// its channel. Caller must hold p.mu.
func (p *Pool) notifyWaiterLocked() {
	if len(p.waiters) == 0 {
		return
	}
	waiter := p.waiters[0]
	p.waiters = p.waiters[1:]
	close(waiter)
}

// removeWaiterLocked drops waiter from the parked list (used when its
// context is cancelled) and reports whether it was found. Caller must
// hold p.mu.
func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool {
	for i, w := range p.waiters {
		if w != waiter {
			continue
		}
		p.waiters = append(p.waiters[:i], p.waiters[i+1:]...)
		return true
	}
	return false
}

// drainWaitersLocked wakes every parked waiter; used when the queue is
// rebuilt by Reset. Caller must hold p.mu.
func (p *Pool) drainWaitersLocked() {
	for _, waiter := range p.waiters {
		close(waiter)
	}
	p.waiters = nil
}
|
||||
|
||||
func maxQueueFromEnv(defaultSize int) int {
|
||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n >= 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
if defaultSize < 0 {
|
||||
return 0
|
||||
}
|
||||
return defaultSize
|
||||
}
|
||||
|
||||
func (p *Pool) canAcquireIDLocked(accountID string) bool {
|
||||
if accountID == "" {
|
||||
return false
|
||||
}
|
||||
if p.inUse[accountID] >= p.maxInflightPerAccount {
|
||||
return false
|
||||
}
|
||||
if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Pool) currentInUseLocked() int {
|
||||
total := 0
|
||||
for _, n := range p.inUse {
|
||||
total += n
|
||||
}
|
||||
return total
|
||||
}
|
||||
108
internal/account/pool_acquire.go
Normal file
108
internal/account/pool_acquire.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// Acquire makes one non-blocking attempt to reserve a request slot.
// target selects a specific account id ("" lets the pool rotate);
// exclude lists ids that must not be chosen. The second return is false
// when nothing is available right now.
func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.acquireLocked(target, normalizeExclude(exclude))
}

// AcquireWait behaves like Acquire but parks the caller until a slot
// frees up, ctx is cancelled, or queueing is impossible (waiter queue
// full, or an explicit target that is excluded/unknown).
func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) {
	if ctx == nil {
		ctx = context.Background() // tolerate nil contexts from callers
	}
	exclude = normalizeExclude(exclude)
	for {
		if ctx.Err() != nil {
			return config.Account{}, false // cancelled or timed out
		}

		p.mu.Lock()
		if acc, ok := p.acquireLocked(target, exclude); ok {
			p.mu.Unlock()
			return acc, true
		}
		// Nothing free: only park if the waiter queue has room and the
		// target (when set) could ever be served.
		if !p.canQueueLocked(target, exclude) {
			p.mu.Unlock()
			return config.Account{}, false
		}
		waiter := make(chan struct{})
		p.waiters = append(p.waiters, waiter)
		p.mu.Unlock()

		select {
		case <-ctx.Done():
			// Remove ourselves so a later Release doesn't wake a dead waiter.
			p.mu.Lock()
			p.removeWaiterLocked(waiter)
			p.mu.Unlock()
			return config.Account{}, false
		case <-waiter:
			// Woken by Release/Reset; loop and retry the acquire.
		}
	}
}

// acquireLocked reserves a slot while p.mu is held. With a target id it
// only considers that account; otherwise it scans the rotation queue,
// preferring accounts that already have a token over token-less ones.
func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
	if target != "" {
		if exclude[target] || !p.canAcquireIDLocked(target) {
			return config.Account{}, false
		}
		acc, ok := p.store.FindAccount(target)
		if !ok {
			return config.Account{}, false
		}
		p.inUse[target]++
		p.bumpQueue(target) // rotate to the back of the queue
		return acc, true
	}

	// First pass: accounts that already hold a token.
	if acc, ok := p.tryAcquire(exclude, true); ok {
		return acc, true
	}
	// Second pass: any remaining account.
	if acc, ok := p.tryAcquire(exclude, false); ok {
		return acc, true
	}
	return config.Account{}, false
}

// tryAcquire scans the rotation queue for the first usable account,
// optionally requiring a non-empty token. Caller must hold p.mu.
func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) {
	for i := 0; i < len(p.queue); i++ {
		id := p.queue[i]
		if exclude[id] || !p.canAcquireIDLocked(id) {
			continue
		}
		acc, ok := p.store.FindAccount(id)
		if !ok {
			continue // stale id; account no longer in the store
		}
		if requireToken && acc.Token == "" {
			continue
		}
		p.inUse[id]++
		p.bumpQueue(id) // least-recently-used rotation
		return acc, true
	}
	return config.Account{}, false
}

// bumpQueue moves accountID to the tail of the rotation queue so the
// next scan tries other accounts first. No-op when the id is absent.
// Caller must hold p.mu.
func (p *Pool) bumpQueue(accountID string) {
	for i, id := range p.queue {
		if id != accountID {
			continue
		}
		p.queue = append(p.queue[:i], p.queue[i+1:]...)
		p.queue = append(p.queue, accountID)
		return
	}
}

// normalizeExclude guarantees callers a non-nil exclusion set so map
// lookups never need a nil check.
func normalizeExclude(exclude map[string]bool) map[string]bool {
	if exclude == nil {
		return map[string]bool{}
	}
	return exclude
}
|
||||
132
internal/account/pool_core.go
Normal file
132
internal/account/pool_core.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// Pool hands out upstream accounts for requests, bounding concurrency
// per account and globally, and parking callers (AcquireWait) when all
// slots are busy.
type Pool struct {
	store                  *config.Store   // backing account/config storage
	mu                     sync.Mutex      // guards every field below
	queue                  []string        // account ids in rotation order (least recently used first)
	inUse                  map[string]int  // account id -> number of inflight requests
	waiters                []chan struct{} // FIFO of parked AcquireWait callers
	maxInflightPerAccount  int             // per-account concurrency cap
	recommendedConcurrency int             // suggested total concurrency (accounts * per-account cap)
	maxQueueSize           int             // cap on parked waiters (0 disables queueing)
	globalMaxInflight      int             // total inflight cap across accounts (0 disables the check)
}

// NewPool builds a Pool over store and seeds it via Reset. The
// per-account cap defaults to 2 when store is nil, otherwise it comes
// from the store's runtime setting.
// NOTE(review): Reset calls p.store.Accounts() unconditionally, so a
// nil store would panic there — confirm callers never pass nil.
func NewPool(store *config.Store) *Pool {
	maxPer := 2
	if store != nil {
		maxPer = store.RuntimeAccountMaxInflight()
	}
	p := &Pool{
		store:                 store,
		inUse:                 map[string]int{},
		maxInflightPerAccount: maxPer,
	}
	p.Reset()
	return p
}

// Reset rebuilds the rotation queue from the store's current accounts,
// refreshes all runtime limits, and wakes every parked waiter so they
// re-evaluate against the new state. Accounts holding a token are
// ordered ahead of token-less ones; store order is kept within a group.
// NOTE(review): sort.SliceStable mutates the slice returned by
// p.store.Accounts(), and p.maxInflightPerAccount is assigned before
// p.mu is taken — verify both are safe in this codebase.
func (p *Pool) Reset() {
	accounts := p.store.Accounts()
	sort.SliceStable(accounts, func(i, j int) bool {
		iHas := accounts[i].Token != ""
		jHas := accounts[j].Token != ""
		if iHas == jHas {
			return i < j // stable: preserve original ordering inside a group
		}
		return iHas // tokened accounts first
	})
	ids := make([]string, 0, len(accounts))
	for _, a := range accounts {
		id := a.Identifier()
		if id != "" {
			ids = append(ids, id)
		}
	}
	// Limits come from the store when available, otherwise from env vars.
	if p.store != nil {
		p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight()
	} else {
		p.maxInflightPerAccount = maxInflightFromEnv()
	}
	recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
	queueLimit := maxQueueFromEnv(recommended)
	globalLimit := recommended
	if p.store != nil {
		queueLimit = p.store.RuntimeAccountMaxQueue(recommended)
		globalLimit = p.store.RuntimeGlobalMaxInflight(recommended)
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	p.drainWaitersLocked() // wake everyone; inflight bookkeeping restarts below
	p.queue = ids
	p.inUse = map[string]int{}
	p.recommendedConcurrency = recommended
	p.maxQueueSize = queueLimit
	p.globalMaxInflight = globalLimit
	config.Logger.Info(
		"[init_account_queue] initialized",
		"total", len(ids),
		"max_inflight_per_account", p.maxInflightPerAccount,
		"global_max_inflight", p.globalMaxInflight,
		"recommended_concurrency", p.recommendedConcurrency,
		"max_queue_size", p.maxQueueSize,
	)
}

// Release returns one inflight slot for accountID and wakes the oldest
// parked waiter, if any. Empty, unknown, or already-zero ids are
// ignored (tolerates double release).
func (p *Pool) Release(accountID string) {
	if accountID == "" {
		return
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	count := p.inUse[accountID]
	if count <= 0 {
		return // never acquired or already fully released; nothing to do
	}
	if count == 1 {
		delete(p.inUse, accountID)
		p.notifyWaiterLocked()
		return
	}
	p.inUse[accountID] = count - 1
	p.notifyWaiterLocked()
}

// Status reports a point-in-time snapshot for admin/monitoring:
// counts, configured limits, and account id lists. "available" means
// below the per-account cap; the global cap is not considered here.
func (p *Pool) Status() map[string]any {
	p.mu.Lock()
	defer p.mu.Unlock()
	available := make([]string, 0, len(p.queue))
	inUseAccounts := make([]string, 0, len(p.inUse))
	inUseSlots := 0
	for _, id := range p.queue {
		if p.inUse[id] < p.maxInflightPerAccount {
			available = append(available, id)
		}
	}
	for id, count := range p.inUse {
		if count > 0 {
			inUseAccounts = append(inUseAccounts, id)
			inUseSlots += count
		}
	}
	sort.Strings(inUseAccounts) // deterministic output (map iteration is random)
	return map[string]any{
		"available":                len(available),
		"in_use":                   inUseSlots,
		"total":                    len(p.store.Accounts()),
		"available_accounts":       available,
		"in_use_accounts":          inUseAccounts,
		"max_inflight_per_account": p.maxInflightPerAccount,
		"global_max_inflight":      p.globalMaxInflight,
		"recommended_concurrency":  p.recommendedConcurrency,
		"waiting":                  len(p.waiters),
		"max_queue_size":           p.maxQueueSize,
	}
}
|
||||
91
internal/account/pool_limits.go
Normal file
91
internal/account/pool_limits.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ApplyRuntimeLimits updates the pool's limits at runtime, sanitizing
// inputs: the per-account cap is at least 1, a negative queue size
// becomes 0, and a non-positive global cap is derived as per-account
// cap times account count (falling back to the per-account cap when
// there are no accounts). One parked waiter is woken so raised limits
// take effect promptly.
func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) {
	if maxInflightPerAccount <= 0 {
		maxInflightPerAccount = 1
	}
	if maxQueueSize < 0 {
		maxQueueSize = 0
	}
	if globalMaxInflight <= 0 {
		globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts())
		if globalMaxInflight <= 0 {
			globalMaxInflight = maxInflightPerAccount
		}
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	p.maxInflightPerAccount = maxInflightPerAccount
	p.maxQueueSize = maxQueueSize
	p.globalMaxInflight = globalMaxInflight
	p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount)
	p.notifyWaiterLocked()
}

// maxInflightFromEnv reads the per-account inflight cap from the first
// env var that parses as a positive integer, defaulting to 2.
func maxInflightFromEnv() int {
	for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
		raw := strings.TrimSpace(os.Getenv(key))
		if raw == "" {
			continue
		}
		n, err := strconv.Atoi(raw)
		if err == nil && n > 0 {
			return n
		}
	}
	return 2
}

// defaultRecommendedConcurrency derives a suggested overall concurrency
// level: account count times the per-account cap (cap treated as 2
// when non-positive). Zero accounts yields zero.
func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int {
	if accountCount <= 0 {
		return 0
	}
	if maxInflightPerAccount <= 0 {
		maxInflightPerAccount = 2
	}
	return accountCount * maxInflightPerAccount
}

// maxQueueFromEnv resolves the waiter-queue capacity: the first env var
// holding a non-negative integer wins; otherwise defaultSize is used,
// clamped up to zero.
func maxQueueFromEnv(defaultSize int) int {
	for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
		raw := strings.TrimSpace(os.Getenv(key))
		if raw == "" {
			continue
		}
		n, err := strconv.Atoi(raw)
		if err == nil && n >= 0 {
			return n
		}
	}
	if defaultSize < 0 {
		return 0
	}
	return defaultSize
}

// canAcquireIDLocked reports whether one more request may start on
// accountID: non-empty id, below the per-account cap, and below the
// global cap when one is set. Caller must hold p.mu.
func (p *Pool) canAcquireIDLocked(accountID string) bool {
	if accountID == "" {
		return false
	}
	if p.inUse[accountID] >= p.maxInflightPerAccount {
		return false
	}
	if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight {
		return false
	}
	return true
}

// currentInUseLocked sums inflight slot counts across all accounts.
// Caller must hold p.mu.
func (p *Pool) currentInUseLocked() int {
	total := 0
	for _, n := range p.inUse {
		total += n
	}
	return total
}
|
||||
43
internal/account/pool_waiters.go
Normal file
43
internal/account/pool_waiters.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package account
|
||||
|
||||
// canQueueLocked decides whether an AcquireWait caller may park: the
// waiter queue must have room, and an explicit target must be neither
// excluded nor unknown to the store. Caller must hold p.mu.
func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool {
	if target != "" {
		if exclude[target] {
			return false
		}
		if _, ok := p.store.FindAccount(target); !ok {
			return false
		}
	}
	if p.maxQueueSize <= 0 {
		return false // queueing disabled
	}
	return len(p.waiters) < p.maxQueueSize
}

// notifyWaiterLocked wakes the oldest parked waiter (FIFO) by closing
// its channel. Caller must hold p.mu.
func (p *Pool) notifyWaiterLocked() {
	if len(p.waiters) == 0 {
		return
	}
	waiter := p.waiters[0]
	p.waiters = p.waiters[1:]
	close(waiter)
}

// removeWaiterLocked drops waiter from the parked list (used when its
// context is cancelled) and reports whether it was found. Caller must
// hold p.mu.
func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool {
	for i, w := range p.waiters {
		if w != waiter {
			continue
		}
		p.waiters = append(p.waiters[:i], p.waiters[i+1:]...)
		return true
	}
	return false
}

// drainWaitersLocked wakes every parked waiter; used when the queue is
// rebuilt by Reset. Caller must hold p.mu.
func (p *Pool) drainWaitersLocked() {
	for _, waiter := range p.waiters {
		close(waiter)
	}
	p.waiters = nil
}
|
||||
@@ -32,4 +32,3 @@ func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) {
|
||||
t.Fatal("expected param field")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,368 +0,0 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
claudefmt "ds2api/internal/format/claude"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
var writeJSON = util.WriteJSON

// Handler serves Anthropic (Claude) compatible endpoints backed by a
// DeepSeek upstream.
type Handler struct {
	Store ConfigReader   // runtime configuration access
	Auth  AuthResolver   // resolves/releases an upstream account per request
	DS    DeepSeekCaller // upstream DeepSeek API client
}

// Stream pacing knobs derived from the deepseek package constants;
// they are fed into the stream consumer's ConsumeConfig.
var (
	claudeStreamPingInterval    = time.Duration(deepseek.KeepAliveTimeout) * time.Second  // keep-alive ping cadence
	claudeStreamIdleTimeout     = time.Duration(deepseek.StreamIdleTimeout) * time.Second // idle cutoff passed to the consumer
	claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount                              // cap on keep-alives with no upstream input
)
|
||||
|
||||
// RegisterRoutes mounts the Anthropic-compatible endpoints on r, both
// under the /anthropic prefix and at the bare paths some clients use.
func RegisterRoutes(r chi.Router, h *Handler) {
	r.Get("/anthropic/v1/models", h.ListModels)
	r.Post("/anthropic/v1/messages", h.Messages)
	r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
	r.Post("/v1/messages", h.Messages)
	r.Post("/messages", h.Messages)
	r.Post("/v1/messages/count_tokens", h.CountTokens)
	r.Post("/messages/count_tokens", h.CountTokens)
}

// ListModels returns the static Claude-compatible model catalogue.
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
	writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
}
|
||||
|
||||
// Messages implements the Anthropic /v1/messages endpoint: it resolves
// an upstream account, normalizes the Claude request into the internal
// standard form, drives a DeepSeek completion, and answers either as a
// Claude SSE stream or as a single JSON message.
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
	// Some clients omit the version header; default it to a known value.
	if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
		r.Header.Set("anthropic-version", "2023-06-01")
	}
	a, err := h.Auth.Determine(r)
	if err != nil {
		status := http.StatusUnauthorized
		detail := err.Error()
		if err == auth.ErrNoAccount {
			// No free account is a capacity problem, not an auth failure.
			status = http.StatusTooManyRequests
		}
		writeClaudeError(w, status, detail)
		return
	}
	defer h.Auth.Release(a)

	var req map[string]any
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeClaudeError(w, http.StatusBadRequest, "invalid json")
		return
	}
	norm, err := normalizeClaudeRequest(h.Store, req)
	if err != nil {
		writeClaudeError(w, http.StatusBadRequest, err.Error())
		return
	}
	stdReq := norm.Standard

	// Upstream handshake: session, proof-of-work, then the completion
	// call. The literal 3 is the retry budget passed to each step.
	sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
	if err != nil {
		writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
		return
	}
	pow, err := h.DS.GetPow(r.Context(), a, 3)
	if err != nil {
		writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
		return
	}
	requestPayload := stdReq.CompletionPayload(sessionID)
	resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
	if err != nil {
		writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
		return
	}
	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		body, _ := io.ReadAll(resp.Body)
		writeClaudeError(w, http.StatusInternalServerError, string(body))
		return
	}

	if stdReq.Stream {
		// Streaming path takes ownership of resp and closes its body.
		h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
		return
	}
	// Non-streaming: drain the whole SSE stream, then answer once.
	// NOTE(review): resp.Body is presumably closed inside CollectStream —
	// confirm, otherwise this path leaks the connection.
	result := sse.CollectStream(resp, stdReq.Thinking, true)
	respBody := claudefmt.BuildMessageResponse(
		fmt.Sprintf("msg_%d", time.Now().UnixNano()),
		stdReq.ResponseModel,
		norm.NormalizedMessages,
		result.Thinking,
		result.Text,
		stdReq.ToolNames,
	)
	writeJSON(w, http.StatusOK, respBody)
}
|
||||
|
||||
// CountTokens implements /v1/messages/count_tokens with a local
// heuristic (util.EstimateTokens) rather than a real tokenizer:
// system prompt + each message (plus a flat 2-token per-message
// overhead) + the JSON-serialized tool definitions, floored at 1.
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
	a, err := h.Auth.Determine(r)
	if err != nil {
		writeClaudeError(w, http.StatusUnauthorized, err.Error())
		return
	}
	defer h.Auth.Release(a)

	var req map[string]any
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeClaudeError(w, http.StatusBadRequest, "invalid json")
		return
	}
	model, _ := req["model"].(string)
	messages, _ := req["messages"].([]any)
	if model == "" || len(messages) == 0 {
		writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
		return
	}
	inputTokens := 0
	if sys, ok := req["system"].(string); ok {
		inputTokens += util.EstimateTokens(sys)
	}
	for _, item := range messages {
		msg, ok := item.(map[string]any)
		if !ok {
			continue // malformed entries contribute nothing
		}
		inputTokens += 2 // rough per-message framing overhead
		inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
	}
	if tools, ok := req["tools"].([]any); ok {
		for _, t := range tools {
			b, _ := json.Marshal(t)
			inputTokens += util.EstimateTokens(string(b))
		}
	}
	if inputTokens < 1 {
		inputTokens = 1 // never report zero tokens
	}
	writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
}
|
||||
|
||||
// handleClaudeStreamRealtime relays the upstream SSE stream to the
// client as Anthropic-style events, emitting pings while the upstream
// is quiet. It takes ownership of resp and always closes its body.
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		writeClaudeError(w, http.StatusInternalServerError, string(body))
		return
	}

	// Standard SSE headers; X-Accel-Buffering asks proxies not to buffer.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache, no-transform")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")
	rc := http.NewResponseController(w)
	_, canFlush := w.(http.Flusher)
	if !canFlush {
		config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
	}

	streamRuntime := newClaudeStreamRuntime(
		w,
		rc,
		canFlush,
		model,
		messages,
		thinkingEnabled,
		searchEnabled,
		toolNames,
	)
	streamRuntime.sendMessageStart()

	// The first content block is "thinking" when extended thinking is on.
	initialType := "text"
	if thinkingEnabled {
		initialType = "thinking"
	}
	streamengine.ConsumeSSE(streamengine.ConsumeConfig{
		Context:             r.Context(),
		Body:                resp.Body,
		ThinkingEnabled:     thinkingEnabled,
		InitialType:         initialType,
		KeepAliveInterval:   claudeStreamPingInterval,
		IdleTimeout:         claudeStreamIdleTimeout,
		MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt,
	}, streamengine.ConsumeHooks{
		OnKeepAlive: func() {
			streamRuntime.sendPing()
		},
		OnParsed:   streamRuntime.onParsed,
		OnFinalize: streamRuntime.onFinalize,
	})
}
|
||||
|
||||
func writeClaudeError(w http.ResponseWriter, status int, message string) {
|
||||
code := "invalid_request"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
code = "authentication_failed"
|
||||
case http.StatusTooManyRequests:
|
||||
code = "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
code = "not_found"
|
||||
case http.StatusInternalServerError:
|
||||
code = "internal_error"
|
||||
}
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"type": "invalid_request_error",
|
||||
"message": message,
|
||||
"code": code,
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeClaudeMessages(messages []any) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
msg, ok := m.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
copied := cloneMap(msg)
|
||||
switch content := msg["content"].(type) {
|
||||
case []any:
|
||||
parts := make([]string, 0, len(content))
|
||||
for _, block := range content {
|
||||
b, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeStr, _ := b["type"].(string)
|
||||
if typeStr == "text" {
|
||||
if t, ok := b["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
}
|
||||
if typeStr == "tool_result" {
|
||||
parts = append(parts, formatClaudeToolResultForPrompt(b))
|
||||
}
|
||||
}
|
||||
copied["content"] = strings.Join(parts, "\n")
|
||||
}
|
||||
out = append(out, copied)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// buildClaudeToolPrompt renders the system prompt that teaches the
// upstream model about the available tools: a header, one section per
// well-formed tool (name, description, JSON schema), and trailing
// instructions about the tool-call JSON format and history markers.
func buildClaudeToolPrompt(tools []any) string {
	sections := make([]string, 0, len(tools)+4)
	sections = append(sections, "You are Claude, a helpful AI assistant. You have access to these tools:")
	for _, raw := range tools {
		tool, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		name, _ := tool["name"].(string)
		desc, _ := tool["description"].(string)
		schema, _ := json.Marshal(tool["input_schema"])
		sections = append(sections, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
	}
	sections = append(sections,
		"When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
		"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
		"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
	)
	return strings.Join(sections, "\n\n")
}
|
||||
|
||||
func formatClaudeToolResultForPrompt(block map[string]any) string {
|
||||
if block == nil {
|
||||
return ""
|
||||
}
|
||||
toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
|
||||
if toolCallID == "" {
|
||||
toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
|
||||
}
|
||||
if toolCallID == "" {
|
||||
toolCallID = "unknown"
|
||||
}
|
||||
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
content := strings.TrimSpace(fmt.Sprintf("%v", block["content"]))
|
||||
if content == "" {
|
||||
content = "null"
|
||||
}
|
||||
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
|
||||
}
|
||||
|
||||
// hasSystemMessage reports whether any entry in messages is a map whose
// "role" value equals "system". Non-map entries are ignored.
func hasSystemMessage(messages []any) bool {
	for _, raw := range messages {
		if entry, isMap := raw.(map[string]any); isMap && entry["role"] == "system" {
			return true
		}
	}
	return false
}
|
||||
|
||||
// extractClaudeToolNames collects the non-empty "name" field of every
// well-formed tool declaration, preserving order. Malformed entries and
// blank names are dropped.
func extractClaudeToolNames(tools []any) []string {
	names := make([]string, 0, len(tools))
	for _, raw := range tools {
		spec, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		name, _ := spec["name"].(string)
		if name != "" {
			names = append(names, name)
		}
	}
	return names
}
|
||||
|
||||
// toMessageMaps narrows a decoded JSON value to the slice-of-objects shape,
// keeping only map elements. Anything that is not a []any yields nil.
func toMessageMaps(v any) []map[string]any {
	items, isSlice := v.([]any)
	if !isSlice {
		return nil
	}
	msgs := make([]map[string]any, 0, len(items))
	for _, raw := range items {
		if entry, isMap := raw.(map[string]any); isMap {
			msgs = append(msgs, entry)
		}
	}
	return msgs
}
|
||||
|
||||
// extractMessageContent flattens a message "content" value into one string:
// strings pass through, []any elements are %v-formatted and newline-joined,
// and any other value is %v-formatted directly.
func extractMessageContent(v any) string {
	if text, isString := v.(string); isString {
		return text
	}
	blocks, isSlice := v.([]any)
	if !isSlice {
		return fmt.Sprintf("%v", v)
	}
	lines := make([]string, 0, len(blocks))
	for _, block := range blocks {
		lines = append(lines, fmt.Sprintf("%v", block))
	}
	return strings.Join(lines, "\n")
}
|
||||
|
||||
// cloneMap returns a shallow copy of in. Values are shared; only the map
// itself is new. A nil input yields an empty, non-nil map.
func cloneMap(in map[string]any) map[string]any {
	cloned := make(map[string]any, len(in))
	for key, value := range in {
		cloned[key] = value
	}
	return cloned
}
|
||||
25
internal/adapter/claude/handler_errors.go
Normal file
25
internal/adapter/claude/handler_errors.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package claude
|
||||
|
||||
import "net/http"
|
||||
|
||||
func writeClaudeError(w http.ResponseWriter, status int, message string) {
|
||||
code := "invalid_request"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
code = "authentication_failed"
|
||||
case http.StatusTooManyRequests:
|
||||
code = "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
code = "not_found"
|
||||
case http.StatusInternalServerError:
|
||||
code = "internal_error"
|
||||
}
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"type": "invalid_request_error",
|
||||
"message": message,
|
||||
"code": code,
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
134
internal/adapter/claude/handler_messages.go
Normal file
134
internal/adapter/claude/handler_messages.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
claudefmt "ds2api/internal/format/claude"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
)
|
||||
|
||||
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
|
||||
r.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeClaudeError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
norm, err := normalizeClaudeRequest(h.Store, req)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
stdReq := norm.Standard
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
|
||||
return
|
||||
}
|
||||
requestPayload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
|
||||
return
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeClaudeError(w, http.StatusInternalServerError, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
if stdReq.Stream {
|
||||
h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
return
|
||||
}
|
||||
result := sse.CollectStream(resp, stdReq.Thinking, true)
|
||||
respBody := claudefmt.BuildMessageResponse(
|
||||
fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
stdReq.ResponseModel,
|
||||
norm.NormalizedMessages,
|
||||
result.Thinking,
|
||||
result.Text,
|
||||
stdReq.ToolNames,
|
||||
)
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeClaudeError(w, http.StatusInternalServerError, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
rc := http.NewResponseController(w)
|
||||
_, canFlush := w.(http.Flusher)
|
||||
if !canFlush {
|
||||
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
|
||||
}
|
||||
|
||||
streamRuntime := newClaudeStreamRuntime(
|
||||
w,
|
||||
rc,
|
||||
canFlush,
|
||||
model,
|
||||
messages,
|
||||
thinkingEnabled,
|
||||
searchEnabled,
|
||||
toolNames,
|
||||
)
|
||||
streamRuntime.sendMessageStart()
|
||||
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: r.Context(),
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: claudeStreamPingInterval,
|
||||
IdleTimeout: claudeStreamIdleTimeout,
|
||||
MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnKeepAlive: func() {
|
||||
streamRuntime.sendPing()
|
||||
},
|
||||
OnParsed: streamRuntime.onParsed,
|
||||
OnFinalize: streamRuntime.onFinalize,
|
||||
})
|
||||
}
|
||||
41
internal/adapter/claude/handler_routes.go
Normal file
41
internal/adapter/claude/handler_routes.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
}
|
||||
|
||||
var (
|
||||
claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second
|
||||
claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second
|
||||
claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount
|
||||
)
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Get("/anthropic/v1/models", h.ListModels)
|
||||
r.Post("/anthropic/v1/messages", h.Messages)
|
||||
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
|
||||
r.Post("/v1/messages", h.Messages)
|
||||
r.Post("/messages", h.Messages)
|
||||
r.Post("/v1/messages/count_tokens", h.CountTokens)
|
||||
r.Post("/messages/count_tokens", h.CountTokens)
|
||||
}
|
||||
|
||||
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
|
||||
}
|
||||
51
internal/adapter/claude/handler_tokens.go
Normal file
51
internal/adapter/claude/handler_tokens.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messages, _ := req["messages"].([]any)
|
||||
if model == "" || len(messages) == 0 {
|
||||
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||
return
|
||||
}
|
||||
inputTokens := 0
|
||||
if sys, ok := req["system"].(string); ok {
|
||||
inputTokens += util.EstimateTokens(sys)
|
||||
}
|
||||
for _, item := range messages {
|
||||
msg, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
inputTokens += 2
|
||||
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
|
||||
}
|
||||
if tools, ok := req["tools"].([]any); ok {
|
||||
for _, t := range tools {
|
||||
b, _ := json.Marshal(t)
|
||||
inputTokens += util.EstimateTokens(string(b))
|
||||
}
|
||||
}
|
||||
if inputTokens < 1 {
|
||||
inputTokens = 1
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
|
||||
}
|
||||
143
internal/adapter/claude/handler_utils.go
Normal file
143
internal/adapter/claude/handler_utils.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func normalizeClaudeMessages(messages []any) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
msg, ok := m.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
copied := cloneMap(msg)
|
||||
switch content := msg["content"].(type) {
|
||||
case []any:
|
||||
parts := make([]string, 0, len(content))
|
||||
for _, block := range content {
|
||||
b, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeStr, _ := b["type"].(string)
|
||||
if typeStr == "text" {
|
||||
if t, ok := b["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
}
|
||||
if typeStr == "tool_result" {
|
||||
parts = append(parts, formatClaudeToolResultForPrompt(b))
|
||||
}
|
||||
}
|
||||
copied["content"] = strings.Join(parts, "\n")
|
||||
}
|
||||
out = append(out, copied)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildClaudeToolPrompt(tools []any) string {
|
||||
parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"}
|
||||
for _, t := range tools {
|
||||
m, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _ := m["name"].(string)
|
||||
desc, _ := m["description"].(string)
|
||||
schema, _ := json.Marshal(m["input_schema"])
|
||||
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
|
||||
}
|
||||
parts = append(parts,
|
||||
"When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
|
||||
"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
|
||||
"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
|
||||
)
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
func formatClaudeToolResultForPrompt(block map[string]any) string {
|
||||
if block == nil {
|
||||
return ""
|
||||
}
|
||||
toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
|
||||
if toolCallID == "" {
|
||||
toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
|
||||
}
|
||||
if toolCallID == "" {
|
||||
toolCallID = "unknown"
|
||||
}
|
||||
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
content := strings.TrimSpace(fmt.Sprintf("%v", block["content"]))
|
||||
if content == "" {
|
||||
content = "null"
|
||||
}
|
||||
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
|
||||
}
|
||||
|
||||
func hasSystemMessage(messages []any) bool {
|
||||
for _, m := range messages {
|
||||
msg, ok := m.(map[string]any)
|
||||
if ok && msg["role"] == "system" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractClaudeToolNames(tools []any) []string {
|
||||
out := make([]string, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
m, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if name, ok := m["name"].(string); ok && name != "" {
|
||||
out = append(out, name)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toMessageMaps(v any) []map[string]any {
|
||||
arr, ok := v.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
out = append(out, m)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func extractMessageContent(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return x
|
||||
case []any:
|
||||
parts := make([]string, 0, len(x))
|
||||
for _, it := range x {
|
||||
parts = append(parts, fmt.Sprintf("%v", it))
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
default:
|
||||
return fmt.Sprintf("%v", x)
|
||||
}
|
||||
}
|
||||
|
||||
func cloneMap(in map[string]any) map[string]any {
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,308 +0,0 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type claudeStreamRuntime struct {
|
||||
w http.ResponseWriter
|
||||
rc *http.ResponseController
|
||||
canFlush bool
|
||||
|
||||
model string
|
||||
toolNames []string
|
||||
messages []any
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
bufferToolContent bool
|
||||
|
||||
messageID string
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
|
||||
nextBlockIndex int
|
||||
thinkingBlockOpen bool
|
||||
thinkingBlockIndex int
|
||||
textBlockOpen bool
|
||||
textBlockIndex int
|
||||
ended bool
|
||||
upstreamErr string
|
||||
}
|
||||
|
||||
func newClaudeStreamRuntime(
|
||||
w http.ResponseWriter,
|
||||
rc *http.ResponseController,
|
||||
canFlush bool,
|
||||
model string,
|
||||
messages []any,
|
||||
thinkingEnabled bool,
|
||||
searchEnabled bool,
|
||||
toolNames []string,
|
||||
) *claudeStreamRuntime {
|
||||
return &claudeStreamRuntime{
|
||||
w: w,
|
||||
rc: rc,
|
||||
canFlush: canFlush,
|
||||
model: model,
|
||||
messages: messages,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
bufferToolContent: len(toolNames) > 0,
|
||||
toolNames: toolNames,
|
||||
messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
thinkingBlockIndex: -1,
|
||||
textBlockIndex: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) send(event string, v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = s.w.Write([]byte("event: "))
|
||||
_, _ = s.w.Write([]byte(event))
|
||||
_, _ = s.w.Write([]byte("\n"))
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendError(message string) {
|
||||
msg := strings.TrimSpace(message)
|
||||
if msg == "" {
|
||||
msg = "upstream stream error"
|
||||
}
|
||||
s.send("error", map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": "api_error",
|
||||
"message": msg,
|
||||
"code": "internal_error",
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendPing() {
|
||||
s.send("ping", map[string]any{"type": "ping"})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendMessageStart() {
|
||||
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
|
||||
s.send("message_start", map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": s.messageID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": s.model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) closeThinkingBlock() {
|
||||
if !s.thinkingBlockOpen {
|
||||
return
|
||||
}
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": s.thinkingBlockIndex,
|
||||
})
|
||||
s.thinkingBlockOpen = false
|
||||
s.thinkingBlockIndex = -1
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) closeTextBlock() {
|
||||
if !s.textBlockOpen {
|
||||
return
|
||||
}
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": s.textBlockIndex,
|
||||
})
|
||||
s.textBlockOpen = false
|
||||
s.textBlockIndex = -1
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
if s.ended {
|
||||
return
|
||||
}
|
||||
s.ended = true
|
||||
|
||||
s.closeThinkingBlock()
|
||||
s.closeTextBlock()
|
||||
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
|
||||
if s.bufferToolContent {
|
||||
detected := util.ParseToolCalls(finalText, s.toolNames)
|
||||
if len(detected) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for i, tc := range detected {
|
||||
idx := s.nextBlockIndex + i
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
|
||||
"name": tc.Name,
|
||||
"input": tc.Input,
|
||||
},
|
||||
})
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
})
|
||||
}
|
||||
s.nextBlockIndex += len(detected)
|
||||
} else if finalText != "" {
|
||||
idx := s.nextBlockIndex
|
||||
s.nextBlockIndex++
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
})
|
||||
s.send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": idx,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": finalText,
|
||||
},
|
||||
})
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
|
||||
s.send("message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
})
|
||||
s.send("message_stop", map[string]any{"type": "message_stop"})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.ErrorMessage != "" {
|
||||
s.upstreamErr = parsed.ErrorMessage
|
||||
return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}
|
||||
}
|
||||
if parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
|
||||
contentSeen := false
|
||||
for _, p := range parsed.Parts {
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
contentSeen = true
|
||||
|
||||
if p.Type == "thinking" {
|
||||
if !s.thinkingEnabled {
|
||||
continue
|
||||
}
|
||||
s.thinking.WriteString(p.Text)
|
||||
s.closeTextBlock()
|
||||
if !s.thinkingBlockOpen {
|
||||
s.thinkingBlockIndex = s.nextBlockIndex
|
||||
s.nextBlockIndex++
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": s.thinkingBlockIndex,
|
||||
"content_block": map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
},
|
||||
})
|
||||
s.thinkingBlockOpen = true
|
||||
}
|
||||
s.send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": s.thinkingBlockIndex,
|
||||
"delta": map[string]any{
|
||||
"type": "thinking_delta",
|
||||
"thinking": p.Text,
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
s.text.WriteString(p.Text)
|
||||
if s.bufferToolContent {
|
||||
continue
|
||||
}
|
||||
s.closeThinkingBlock()
|
||||
if !s.textBlockOpen {
|
||||
s.textBlockIndex = s.nextBlockIndex
|
||||
s.nextBlockIndex++
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": s.textBlockIndex,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
})
|
||||
s.textBlockOpen = true
|
||||
}
|
||||
s.send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": s.textBlockIndex,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": p.Text,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) {
|
||||
if string(reason) == "upstream_error" {
|
||||
s.sendError(s.upstreamErr)
|
||||
return
|
||||
}
|
||||
if scannerErr != nil {
|
||||
s.sendError(scannerErr.Error())
|
||||
return
|
||||
}
|
||||
s.finalize("end_turn")
|
||||
}
|
||||
146
internal/adapter/claude/stream_runtime_core.go
Normal file
146
internal/adapter/claude/stream_runtime_core.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
)
|
||||
|
||||
// claudeStreamRuntime carries all per-request state needed to translate an
// upstream SSE completion into Anthropic-style streaming events.
type claudeStreamRuntime struct {
	w        http.ResponseWriter
	rc       *http.ResponseController
	canFlush bool

	model     string
	toolNames []string
	messages  []any

	thinkingEnabled   bool
	searchEnabled     bool
	bufferToolContent bool // tools declared: buffer text for tool-call parsing at finalize

	messageID string
	thinking  strings.Builder // accumulated reasoning output
	text      strings.Builder // accumulated answer text

	nextBlockIndex     int
	thinkingBlockOpen  bool
	thinkingBlockIndex int
	textBlockOpen      bool
	textBlockIndex     int
	ended              bool
	upstreamErr        string
}

// newClaudeStreamRuntime builds the runtime for one streaming response.
// Block indices start at -1 (no block open); the message ID is derived
// from the current wall clock in nanoseconds.
func newClaudeStreamRuntime(
	w http.ResponseWriter,
	rc *http.ResponseController,
	canFlush bool,
	model string,
	messages []any,
	thinkingEnabled bool,
	searchEnabled bool,
	toolNames []string,
) *claudeStreamRuntime {
	return &claudeStreamRuntime{
		w:                  w,
		rc:                 rc,
		canFlush:           canFlush,
		model:              model,
		messages:           messages,
		thinkingEnabled:    thinkingEnabled,
		searchEnabled:      searchEnabled,
		toolNames:          toolNames,
		bufferToolContent:  len(toolNames) > 0,
		messageID:          fmt.Sprintf("msg_%d", time.Now().UnixNano()),
		thinkingBlockIndex: -1,
		textBlockIndex:     -1,
	}
}
|
||||
|
||||
func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.ErrorMessage != "" {
|
||||
s.upstreamErr = parsed.ErrorMessage
|
||||
return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}
|
||||
}
|
||||
if parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
|
||||
contentSeen := false
|
||||
for _, p := range parsed.Parts {
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
contentSeen = true
|
||||
|
||||
if p.Type == "thinking" {
|
||||
if !s.thinkingEnabled {
|
||||
continue
|
||||
}
|
||||
s.thinking.WriteString(p.Text)
|
||||
s.closeTextBlock()
|
||||
if !s.thinkingBlockOpen {
|
||||
s.thinkingBlockIndex = s.nextBlockIndex
|
||||
s.nextBlockIndex++
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": s.thinkingBlockIndex,
|
||||
"content_block": map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
},
|
||||
})
|
||||
s.thinkingBlockOpen = true
|
||||
}
|
||||
s.send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": s.thinkingBlockIndex,
|
||||
"delta": map[string]any{
|
||||
"type": "thinking_delta",
|
||||
"thinking": p.Text,
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
s.text.WriteString(p.Text)
|
||||
if s.bufferToolContent {
|
||||
continue
|
||||
}
|
||||
s.closeThinkingBlock()
|
||||
if !s.textBlockOpen {
|
||||
s.textBlockIndex = s.nextBlockIndex
|
||||
s.nextBlockIndex++
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": s.textBlockIndex,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
})
|
||||
s.textBlockOpen = true
|
||||
}
|
||||
s.send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": s.textBlockIndex,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": p.Text,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
59
internal/adapter/claude/stream_runtime_emit.go
Normal file
59
internal/adapter/claude/stream_runtime_emit.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (s *claudeStreamRuntime) send(event string, v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = s.w.Write([]byte("event: "))
|
||||
_, _ = s.w.Write([]byte(event))
|
||||
_, _ = s.w.Write([]byte("\n"))
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendError(message string) {
|
||||
msg := strings.TrimSpace(message)
|
||||
if msg == "" {
|
||||
msg = "upstream stream error"
|
||||
}
|
||||
s.send("error", map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": "api_error",
|
||||
"message": msg,
|
||||
"code": "internal_error",
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendPing() {
|
||||
s.send("ping", map[string]any{"type": "ping"})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendMessageStart() {
|
||||
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
|
||||
s.send("message_start", map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": s.messageID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": s.model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
|
||||
},
|
||||
})
|
||||
}
|
||||
119
internal/adapter/claude/stream_runtime_finalize.go
Normal file
119
internal/adapter/claude/stream_runtime_finalize.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (s *claudeStreamRuntime) closeThinkingBlock() {
|
||||
if !s.thinkingBlockOpen {
|
||||
return
|
||||
}
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": s.thinkingBlockIndex,
|
||||
})
|
||||
s.thinkingBlockOpen = false
|
||||
s.thinkingBlockIndex = -1
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) closeTextBlock() {
|
||||
if !s.textBlockOpen {
|
||||
return
|
||||
}
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": s.textBlockIndex,
|
||||
})
|
||||
s.textBlockOpen = false
|
||||
s.textBlockIndex = -1
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
if s.ended {
|
||||
return
|
||||
}
|
||||
s.ended = true
|
||||
|
||||
s.closeThinkingBlock()
|
||||
s.closeTextBlock()
|
||||
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
|
||||
if s.bufferToolContent {
|
||||
detected := util.ParseToolCalls(finalText, s.toolNames)
|
||||
if len(detected) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for i, tc := range detected {
|
||||
idx := s.nextBlockIndex + i
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
|
||||
"name": tc.Name,
|
||||
"input": tc.Input,
|
||||
},
|
||||
})
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
})
|
||||
}
|
||||
s.nextBlockIndex += len(detected)
|
||||
} else if finalText != "" {
|
||||
idx := s.nextBlockIndex
|
||||
s.nextBlockIndex++
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
})
|
||||
s.send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": idx,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": finalText,
|
||||
},
|
||||
})
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
|
||||
s.send("message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
})
|
||||
s.send("message_stop", map[string]any{"type": "message_stop"})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) {
|
||||
if string(reason) == "upstream_error" {
|
||||
s.sendError(s.upstreamErr)
|
||||
return
|
||||
}
|
||||
if scannerErr != nil {
|
||||
s.sendError(scannerErr.Error())
|
||||
return
|
||||
}
|
||||
s.finalize("end_turn")
|
||||
}
|
||||
@@ -1,313 +0,0 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/adapter/openai"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) {
|
||||
requestedModel := strings.TrimSpace(routeModel)
|
||||
if requestedModel == "" {
|
||||
return util.StandardRequest{}, fmt.Errorf("model is required in request path")
|
||||
}
|
||||
|
||||
resolvedModel, ok := config.ResolveModel(store, requestedModel)
|
||||
if !ok {
|
||||
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel)
|
||||
}
|
||||
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
|
||||
|
||||
messagesRaw := geminiMessagesFromRequest(req)
|
||||
if len(messagesRaw) == 0 {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.")
|
||||
}
|
||||
|
||||
toolsRaw := convertGeminiTools(req["tools"])
|
||||
finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "")
|
||||
passThrough := collectGeminiPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
Surface: "google_gemini",
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: resolvedModel,
|
||||
ResponseModel: requestedModel,
|
||||
Messages: messagesRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
Stream: stream,
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
PassThrough: passThrough,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func geminiMessagesFromRequest(req map[string]any) []any {
|
||||
out := make([]any, 0, 8)
|
||||
if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
|
||||
out = append(out, map[string]any{
|
||||
"role": "system",
|
||||
"content": sys,
|
||||
})
|
||||
}
|
||||
|
||||
contents, _ := req["contents"].([]any)
|
||||
for _, item := range contents {
|
||||
content, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := mapGeminiRole(content["role"])
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
parts, _ := content["parts"].([]any)
|
||||
if len(parts) == 0 {
|
||||
if text := strings.TrimSpace(asString(content["text"])); text != "" {
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"content": text,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
textParts := make([]string, 0, len(parts))
|
||||
flushText := func() {
|
||||
if len(textParts) == 0 {
|
||||
return
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"content": strings.Join(textParts, "\n"),
|
||||
})
|
||||
textParts = textParts[:0]
|
||||
}
|
||||
|
||||
for _, rawPart := range parts {
|
||||
part, ok := rawPart.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if text := strings.TrimSpace(asString(part["text"])); text != "" {
|
||||
textParts = append(textParts, text)
|
||||
continue
|
||||
}
|
||||
|
||||
if fnCall, ok := part["functionCall"].(map[string]any); ok {
|
||||
flushText()
|
||||
if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
|
||||
callID := strings.TrimSpace(asString(fnCall["id"]))
|
||||
if callID == "" {
|
||||
callID = "call_gemini"
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": name,
|
||||
"arguments": stringifyJSON(fnCall["args"]),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if fnResp, ok := part["functionResponse"].(map[string]any); ok {
|
||||
flushText()
|
||||
name := strings.TrimSpace(asString(fnResp["name"]))
|
||||
callID := strings.TrimSpace(asString(fnResp["id"]))
|
||||
if callID == "" {
|
||||
callID = strings.TrimSpace(asString(fnResp["callId"]))
|
||||
}
|
||||
if callID == "" {
|
||||
callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
|
||||
}
|
||||
if callID == "" {
|
||||
callID = "call_gemini"
|
||||
}
|
||||
content := fnResp["response"]
|
||||
if content == nil {
|
||||
content = fnResp["output"]
|
||||
}
|
||||
if content == nil {
|
||||
content = ""
|
||||
}
|
||||
msg := map[string]any{
|
||||
"role": "tool",
|
||||
"tool_call_id": callID,
|
||||
"content": content,
|
||||
}
|
||||
if name != "" {
|
||||
msg["name"] = name
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
}
|
||||
flushText()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeGeminiSystemInstruction(raw any) string {
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if parts, ok := v["parts"].([]any); ok {
|
||||
texts := make([]string, 0, len(parts))
|
||||
for _, item := range parts {
|
||||
part, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if text := strings.TrimSpace(asString(part["text"])); text != "" {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n")
|
||||
}
|
||||
if text := strings.TrimSpace(asString(v["text"])); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func mapGeminiRole(v any) string {
|
||||
switch strings.ToLower(strings.TrimSpace(asString(v))) {
|
||||
case "user":
|
||||
return "user"
|
||||
case "model", "assistant":
|
||||
return "assistant"
|
||||
case "system":
|
||||
return "system"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func convertGeminiTools(raw any) []any {
|
||||
tools, _ := raw.([]any)
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(tools))
|
||||
for _, item := range tools {
|
||||
tool, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 {
|
||||
for _, declRaw := range fnDecls {
|
||||
decl, ok := declRaw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(asString(decl["name"]))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
function := map[string]any{
|
||||
"name": name,
|
||||
}
|
||||
if desc := strings.TrimSpace(asString(decl["description"])); desc != "" {
|
||||
function["description"] = desc
|
||||
}
|
||||
if params, ok := decl["parameters"].(map[string]any); ok {
|
||||
function["parameters"] = params
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"type": "function",
|
||||
"function": function,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// OpenAI-style passthrough fallback.
|
||||
if _, ok := tool["function"].(map[string]any); ok {
|
||||
out = append(out, tool)
|
||||
continue
|
||||
}
|
||||
|
||||
// Loose fallback for flattened function schema objects.
|
||||
name := strings.TrimSpace(asString(tool["name"]))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
fn := map[string]any{"name": name}
|
||||
if desc := strings.TrimSpace(asString(tool["description"])); desc != "" {
|
||||
fn["description"] = desc
|
||||
}
|
||||
if params, ok := tool["parameters"].(map[string]any); ok {
|
||||
fn["parameters"] = params
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"type": "function",
|
||||
"function": fn,
|
||||
})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func collectGeminiPassThrough(req map[string]any) map[string]any {
|
||||
cfg, _ := req["generationConfig"].(map[string]any)
|
||||
if len(cfg) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := map[string]any{}
|
||||
if v, ok := cfg["temperature"]; ok {
|
||||
out["temperature"] = v
|
||||
}
|
||||
if v, ok := cfg["topP"]; ok {
|
||||
out["top_p"] = v
|
||||
}
|
||||
if v, ok := cfg["maxOutputTokens"]; ok {
|
||||
out["max_tokens"] = v
|
||||
}
|
||||
if v, ok := cfg["stopSequences"]; ok {
|
||||
out["stop"] = v
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func asString(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
|
||||
func stringifyJSON(v any) string {
|
||||
switch x := v.(type) {
|
||||
case nil:
|
||||
return "{}"
|
||||
case string:
|
||||
s := strings.TrimSpace(x)
|
||||
if s == "" {
|
||||
return "{}"
|
||||
}
|
||||
return s
|
||||
default:
|
||||
b, err := json.Marshal(x)
|
||||
if err != nil || len(b) == 0 {
|
||||
return "{}"
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
}
|
||||
153
internal/adapter/gemini/convert_messages.go
Normal file
153
internal/adapter/gemini/convert_messages.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package gemini
|
||||
|
||||
import "strings"
|
||||
|
||||
// geminiMessagesFromRequest flattens a Gemini request body into the
// OpenAI-style message list ([{role, content}, ...]) consumed by the shared
// prompt builder.
//
// Mapping rules visible in this function:
//   - systemInstruction (string or {parts:[{text}]}) becomes one leading
//     "system" message;
//   - consecutive text parts inside a single contents entry are buffered and
//     joined with "\n" into one message for that entry's role;
//   - a functionCall part becomes an "assistant" message carrying a single
//     tool_calls entry (args re-serialized to a JSON string);
//   - a functionResponse part becomes a "tool" message whose tool_call_id is
//     taken from id/callId/tool_call_id, defaulting to "call_gemini".
//
// Entries that are not maps, and parts with no usable payload, are skipped.
func geminiMessagesFromRequest(req map[string]any) []any {
	out := make([]any, 0, 8)
	if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
		out = append(out, map[string]any{
			"role":    "system",
			"content": sys,
		})
	}

	contents, _ := req["contents"].([]any)
	for _, item := range contents {
		content, ok := item.(map[string]any)
		if !ok {
			continue
		}
		role := mapGeminiRole(content["role"])
		if role == "" {
			// Unknown or missing role defaults to the user side.
			role = "user"
		}
		parts, _ := content["parts"].([]any)
		if len(parts) == 0 {
			// Loose fallback: some clients send {role, text} with no parts.
			if text := strings.TrimSpace(asString(content["text"])); text != "" {
				out = append(out, map[string]any{
					"role":    role,
					"content": text,
				})
			}
			continue
		}

		// Text parts are accumulated and flushed as one message whenever a
		// non-text part (function call/response) interrupts the run, so
		// message order mirrors part order.
		textParts := make([]string, 0, len(parts))
		flushText := func() {
			if len(textParts) == 0 {
				return
			}
			out = append(out, map[string]any{
				"role":    role,
				"content": strings.Join(textParts, "\n"),
			})
			textParts = textParts[:0]
		}

		for _, rawPart := range parts {
			part, ok := rawPart.(map[string]any)
			if !ok {
				continue
			}
			if text := strings.TrimSpace(asString(part["text"])); text != "" {
				textParts = append(textParts, text)
				continue
			}

			if fnCall, ok := part["functionCall"].(map[string]any); ok {
				flushText()
				// A call without a name is dropped entirely.
				if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
					callID := strings.TrimSpace(asString(fnCall["id"]))
					if callID == "" {
						callID = "call_gemini"
					}
					out = append(out, map[string]any{
						"role": "assistant",
						"tool_calls": []any{
							map[string]any{
								"id":   callID,
								"type": "function",
								"function": map[string]any{
									"name":      name,
									"arguments": stringifyJSON(fnCall["args"]),
								},
							},
						},
					})
				}
				continue
			}

			if fnResp, ok := part["functionResponse"].(map[string]any); ok {
				flushText()
				name := strings.TrimSpace(asString(fnResp["name"]))
				// Accept the correlation id under any of the aliases
				// different clients use.
				callID := strings.TrimSpace(asString(fnResp["id"]))
				if callID == "" {
					callID = strings.TrimSpace(asString(fnResp["callId"]))
				}
				if callID == "" {
					callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
				}
				if callID == "" {
					callID = "call_gemini"
				}
				content := fnResp["response"]
				if content == nil {
					content = fnResp["output"]
				}
				if content == nil {
					content = ""
				}
				// NOTE(review): content may remain a map or array here —
				// presumably the downstream prompt builder serializes
				// non-string tool content; confirm against callers.
				msg := map[string]any{
					"role":         "tool",
					"tool_call_id": callID,
					"content":      content,
				}
				if name != "" {
					msg["name"] = name
				}
				out = append(out, msg)
			}
		}
		flushText()
	}
	return out
}
|
||||
|
||||
func normalizeGeminiSystemInstruction(raw any) string {
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if parts, ok := v["parts"].([]any); ok {
|
||||
texts := make([]string, 0, len(parts))
|
||||
for _, item := range parts {
|
||||
part, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if text := strings.TrimSpace(asString(part["text"])); text != "" {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n")
|
||||
}
|
||||
if text := strings.TrimSpace(asString(v["text"])); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func mapGeminiRole(v any) string {
|
||||
switch strings.ToLower(strings.TrimSpace(asString(v))) {
|
||||
case "user":
|
||||
return "user"
|
||||
case "model", "assistant":
|
||||
return "assistant"
|
||||
case "system":
|
||||
return "system"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
54
internal/adapter/gemini/convert_passthrough.go
Normal file
54
internal/adapter/gemini/convert_passthrough.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// collectGeminiPassThrough maps the supported generationConfig fields onto
// the OpenAI-style sampling keys forwarded upstream. It returns nil when the
// request carries no recognized tuning options, so callers can omit the
// pass-through section entirely.
func collectGeminiPassThrough(req map[string]any) map[string]any {
	cfg, _ := req["generationConfig"].(map[string]any)
	if len(cfg) == 0 {
		return nil
	}
	// Gemini field name -> internal pass-through key.
	renames := map[string]string{
		"temperature":     "temperature",
		"topP":            "top_p",
		"maxOutputTokens": "max_tokens",
		"stopSequences":   "stop",
	}
	out := map[string]any{}
	for src, dst := range renames {
		if v, ok := cfg[src]; ok {
			out[dst] = v
		}
	}
	if len(out) == 0 {
		return nil
	}
	return out
}
|
||||
|
||||
// asString returns v when it holds a string and "" for any other type,
// including nil.
func asString(v any) string {
	if s, ok := v.(string); ok {
		return s
	}
	return ""
}
|
||||
|
||||
// stringifyJSON renders a decoded JSON value back to compact JSON text for
// use as a tool-call arguments string. nil, blank strings, and marshal
// failures all degrade to "{}" so the result is always a usable object.
func stringifyJSON(v any) string {
	if v == nil {
		return "{}"
	}
	if raw, ok := v.(string); ok {
		// Strings are assumed to already be JSON; pass them through trimmed.
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			return trimmed
		}
		return "{}"
	}
	encoded, err := json.Marshal(v)
	if err != nil || len(encoded) == 0 {
		return "{}"
	}
	return string(encoded)
}
||||
46
internal/adapter/gemini/convert_request.go
Normal file
46
internal/adapter/gemini/convert_request.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/adapter/openai"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) {
|
||||
requestedModel := strings.TrimSpace(routeModel)
|
||||
if requestedModel == "" {
|
||||
return util.StandardRequest{}, fmt.Errorf("model is required in request path")
|
||||
}
|
||||
|
||||
resolvedModel, ok := config.ResolveModel(store, requestedModel)
|
||||
if !ok {
|
||||
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel)
|
||||
}
|
||||
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
|
||||
|
||||
messagesRaw := geminiMessagesFromRequest(req)
|
||||
if len(messagesRaw) == 0 {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.")
|
||||
}
|
||||
|
||||
toolsRaw := convertGeminiTools(req["tools"])
|
||||
finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "")
|
||||
passThrough := collectGeminiPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
Surface: "google_gemini",
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: resolvedModel,
|
||||
ResponseModel: requestedModel,
|
||||
Messages: messagesRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
Stream: stream,
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
PassThrough: passThrough,
|
||||
}, nil
|
||||
}
|
||||
71
internal/adapter/gemini/convert_tools.go
Normal file
71
internal/adapter/gemini/convert_tools.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package gemini
|
||||
|
||||
import "strings"
|
||||
|
||||
func convertGeminiTools(raw any) []any {
|
||||
tools, _ := raw.([]any)
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(tools))
|
||||
for _, item := range tools {
|
||||
tool, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 {
|
||||
for _, declRaw := range fnDecls {
|
||||
decl, ok := declRaw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(asString(decl["name"]))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
function := map[string]any{
|
||||
"name": name,
|
||||
}
|
||||
if desc := strings.TrimSpace(asString(decl["description"])); desc != "" {
|
||||
function["description"] = desc
|
||||
}
|
||||
if params, ok := decl["parameters"].(map[string]any); ok {
|
||||
function["parameters"] = params
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"type": "function",
|
||||
"function": function,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// OpenAI-style passthrough fallback.
|
||||
if _, ok := tool["function"].(map[string]any); ok {
|
||||
out = append(out, tool)
|
||||
continue
|
||||
}
|
||||
|
||||
// Loose fallback for flattened function schema objects.
|
||||
name := strings.TrimSpace(asString(tool["name"]))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
fn := map[string]any{"name": name}
|
||||
if desc := strings.TrimSpace(asString(tool["description"])); desc != "" {
|
||||
fn["description"] = desc
|
||||
}
|
||||
if params, ok := tool["parameters"].(map[string]any); ok {
|
||||
fn["parameters"] = params
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"type": "function",
|
||||
"function": fn,
|
||||
})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,348 +0,0 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent)
|
||||
r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent)
|
||||
r.Post("/v1/models/{model}:generateContent", h.GenerateContent)
|
||||
r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent)
|
||||
}
|
||||
|
||||
func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleGenerateContent(w, r, false)
|
||||
}
|
||||
|
||||
func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleGenerateContent(w, r, true)
|
||||
}
|
||||
|
||||
func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeGeminiError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeGeminiError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
|
||||
routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
|
||||
stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
|
||||
if err != nil {
|
||||
writeGeminiError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
writeGeminiError(w, http.StatusUnauthorized, "Invalid token.")
|
||||
}
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
|
||||
if stream {
|
||||
h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
return
|
||||
}
|
||||
h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
return
|
||||
}
|
||||
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
|
||||
}
|
||||
|
||||
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
rc := http.NewResponseController(w)
|
||||
_, canFlush := w.(http.Flusher)
|
||||
runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: r.Context(),
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
|
||||
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
|
||||
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnParsed: runtime.onParsed,
|
||||
OnFinalize: func(_ streamengine.StopReason, _ error) {
|
||||
runtime.finalize()
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
|
||||
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
|
||||
return map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"content": map[string]any{
|
||||
"role": "model",
|
||||
"parts": parts,
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
},
|
||||
},
|
||||
"modelVersion": model,
|
||||
"usageMetadata": usage,
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
return map[string]any{
|
||||
"promptTokenCount": promptTokens,
|
||||
"candidatesTokenCount": reasoningTokens + completionTokens,
|
||||
"totalTokenCount": promptTokens + reasoningTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any {
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
|
||||
detected = util.ParseToolCalls(finalThinking, toolNames)
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
parts := make([]map[string]any, 0, len(detected))
|
||||
for _, tc := range detected {
|
||||
parts = append(parts, map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"name": tc.Name,
|
||||
"args": tc.Input,
|
||||
},
|
||||
})
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
text := finalText
|
||||
if strings.TrimSpace(text) == "" {
|
||||
text = finalThinking
|
||||
}
|
||||
return []map[string]any{{"text": text}}
|
||||
}
|
||||
|
||||
type geminiStreamRuntime struct {
|
||||
w http.ResponseWriter
|
||||
rc *http.ResponseController
|
||||
canFlush bool
|
||||
|
||||
model string
|
||||
finalPrompt string
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
bufferContent bool
|
||||
toolNames []string
|
||||
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
}
|
||||
|
||||
func newGeminiStreamRuntime(
|
||||
w http.ResponseWriter,
|
||||
rc *http.ResponseController,
|
||||
canFlush bool,
|
||||
model string,
|
||||
finalPrompt string,
|
||||
thinkingEnabled bool,
|
||||
searchEnabled bool,
|
||||
toolNames []string,
|
||||
) *geminiStreamRuntime {
|
||||
return &geminiStreamRuntime{
|
||||
w: w,
|
||||
rc: rc,
|
||||
canFlush: canFlush,
|
||||
model: model,
|
||||
finalPrompt: finalPrompt,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
bufferContent: len(toolNames) > 0,
|
||||
toolNames: toolNames,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
|
||||
b, _ := json.Marshal(payload)
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
|
||||
contentSeen := false
|
||||
for _, p := range parsed.Parts {
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
contentSeen = true
|
||||
if p.Type == "thinking" {
|
||||
if s.thinkingEnabled {
|
||||
s.thinking.WriteString(p.Text)
|
||||
}
|
||||
continue
|
||||
}
|
||||
s.text.WriteString(p.Text)
|
||||
if s.bufferContent {
|
||||
continue
|
||||
}
|
||||
s.sendChunk(map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"content": map[string]any{
|
||||
"role": "model",
|
||||
"parts": []map[string]any{{"text": p.Text}},
|
||||
},
|
||||
},
|
||||
},
|
||||
"modelVersion": s.model,
|
||||
})
|
||||
}
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
|
||||
func (s *geminiStreamRuntime) finalize() {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
|
||||
if s.bufferContent {
|
||||
parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
|
||||
s.sendChunk(map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"content": map[string]any{
|
||||
"role": "model",
|
||||
"parts": parts,
|
||||
},
|
||||
},
|
||||
},
|
||||
"modelVersion": s.model,
|
||||
})
|
||||
}
|
||||
|
||||
s.sendChunk(map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"finishReason": "STOP",
|
||||
},
|
||||
},
|
||||
"modelVersion": s.model,
|
||||
"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
|
||||
})
|
||||
}
|
||||
|
||||
func writeGeminiError(w http.ResponseWriter, status int, message string) {
|
||||
errorStatus := "INVALID_ARGUMENT"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
errorStatus = "UNAUTHENTICATED"
|
||||
case http.StatusForbidden:
|
||||
errorStatus = "PERMISSION_DENIED"
|
||||
case http.StatusTooManyRequests:
|
||||
errorStatus = "RESOURCE_EXHAUSTED"
|
||||
case http.StatusNotFound:
|
||||
errorStatus = "NOT_FOUND"
|
||||
default:
|
||||
if status >= 500 {
|
||||
errorStatus = "INTERNAL"
|
||||
}
|
||||
}
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": errorStatus,
|
||||
},
|
||||
})
|
||||
}
|
||||
28
internal/adapter/gemini/handler_errors.go
Normal file
28
internal/adapter/gemini/handler_errors.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package gemini
|
||||
|
||||
import "net/http"
|
||||
|
||||
func writeGeminiError(w http.ResponseWriter, status int, message string) {
|
||||
errorStatus := "INVALID_ARGUMENT"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
errorStatus = "UNAUTHENTICATED"
|
||||
case http.StatusForbidden:
|
||||
errorStatus = "PERMISSION_DENIED"
|
||||
case http.StatusTooManyRequests:
|
||||
errorStatus = "RESOURCE_EXHAUSTED"
|
||||
case http.StatusNotFound:
|
||||
errorStatus = "NOT_FOUND"
|
||||
default:
|
||||
if status >= 500 {
|
||||
errorStatus = "INTERNAL"
|
||||
}
|
||||
}
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": errorStatus,
|
||||
},
|
||||
})
|
||||
}
|
||||
135
internal/adapter/gemini/handler_generate.go
Normal file
135
internal/adapter/gemini/handler_generate.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// handleGenerateContent is the shared implementation behind both the
// blocking and streaming Gemini generateContent endpoints. It resolves an
// account, decodes and normalizes the request, drives the DeepSeek
// session/PoW/completion sequence, and dispatches to the streaming or
// non-streaming response writer. All failures are reported in the Gemini
// error envelope.
func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
	a, err := h.Auth.Determine(r)
	if err != nil {
		status := http.StatusUnauthorized
		detail := err.Error()
		// No free account maps to 429 rather than 401.
		if err == auth.ErrNoAccount {
			status = http.StatusTooManyRequests
		}
		writeGeminiError(w, status, detail)
		return
	}
	// Return the account to the pool when the request finishes.
	defer h.Auth.Release(a)

	var req map[string]any
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeGeminiError(w, http.StatusBadRequest, "invalid json")
		return
	}

	// Model name comes from the URL path ({model}:generateContent).
	routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
	stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
	if err != nil {
		writeGeminiError(w, http.StatusBadRequest, err.Error())
		return
	}

	// NOTE(review): the literal 3 passed to CreateSession/GetPow/
	// CallCompletion is presumably a retry/attempt count — confirm against
	// the DeepSeekCaller implementation.
	sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
	if err != nil {
		// A session failure is treated as an invalid account token; the
		// message differs for config-provided vs. pool-managed tokens.
		if a.UseConfigToken {
			writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
		} else {
			writeGeminiError(w, http.StatusUnauthorized, "Invalid token.")
		}
		return
	}
	pow, err := h.DS.GetPow(r.Context(), a, 3)
	if err != nil {
		writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
		return
	}
	payload := stdReq.CompletionPayload(sessionID)
	resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
	if err != nil {
		writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.")
		return
	}

	// Hand the upstream response to the matching writer; both variants
	// take ownership of resp and close its body.
	if stream {
		h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
		return
	}
	h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
|
||||
|
||||
func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
return
|
||||
}
|
||||
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
|
||||
}
|
||||
|
||||
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
|
||||
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
|
||||
return map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"content": map[string]any{
|
||||
"role": "model",
|
||||
"parts": parts,
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
},
|
||||
},
|
||||
"modelVersion": model,
|
||||
"usageMetadata": usage,
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
return map[string]any{
|
||||
"promptTokenCount": promptTokens,
|
||||
"candidatesTokenCount": reasoningTokens + completionTokens,
|
||||
"totalTokenCount": promptTokens + reasoningTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any {
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
|
||||
detected = util.ParseToolCalls(finalThinking, toolNames)
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
parts := make([]map[string]any, 0, len(detected))
|
||||
for _, tc := range detected {
|
||||
parts = append(parts, map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"name": tc.Name,
|
||||
"args": tc.Input,
|
||||
},
|
||||
})
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
text := finalText
|
||||
if strings.TrimSpace(text) == "" {
|
||||
text = finalThinking
|
||||
}
|
||||
return []map[string]any{{"text": text}}
|
||||
}
|
||||
32
internal/adapter/gemini/handler_routes.go
Normal file
32
internal/adapter/gemini/handler_routes.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent)
|
||||
r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent)
|
||||
r.Post("/v1/models/{model}:generateContent", h.GenerateContent)
|
||||
r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent)
|
||||
}
|
||||
|
||||
func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleGenerateContent(w, r, false)
|
||||
}
|
||||
|
||||
func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleGenerateContent(w, r, true)
|
||||
}
|
||||
175
internal/adapter/gemini/handler_stream_runtime.go
Normal file
175
internal/adapter/gemini/handler_stream_runtime.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
)
|
||||
|
||||
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
rc := http.NewResponseController(w)
|
||||
_, canFlush := w.(http.Flusher)
|
||||
runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: r.Context(),
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
|
||||
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
|
||||
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnParsed: runtime.onParsed,
|
||||
OnFinalize: func(_ streamengine.StopReason, _ error) {
|
||||
runtime.finalize()
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
type geminiStreamRuntime struct {
|
||||
w http.ResponseWriter
|
||||
rc *http.ResponseController
|
||||
canFlush bool
|
||||
|
||||
model string
|
||||
finalPrompt string
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
bufferContent bool
|
||||
toolNames []string
|
||||
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
}
|
||||
|
||||
func newGeminiStreamRuntime(
|
||||
w http.ResponseWriter,
|
||||
rc *http.ResponseController,
|
||||
canFlush bool,
|
||||
model string,
|
||||
finalPrompt string,
|
||||
thinkingEnabled bool,
|
||||
searchEnabled bool,
|
||||
toolNames []string,
|
||||
) *geminiStreamRuntime {
|
||||
return &geminiStreamRuntime{
|
||||
w: w,
|
||||
rc: rc,
|
||||
canFlush: canFlush,
|
||||
model: model,
|
||||
finalPrompt: finalPrompt,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
bufferContent: len(toolNames) > 0,
|
||||
toolNames: toolNames,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
|
||||
b, _ := json.Marshal(payload)
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
|
||||
contentSeen := false
|
||||
for _, p := range parsed.Parts {
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
contentSeen = true
|
||||
if p.Type == "thinking" {
|
||||
if s.thinkingEnabled {
|
||||
s.thinking.WriteString(p.Text)
|
||||
}
|
||||
continue
|
||||
}
|
||||
s.text.WriteString(p.Text)
|
||||
if s.bufferContent {
|
||||
continue
|
||||
}
|
||||
s.sendChunk(map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"content": map[string]any{
|
||||
"role": "model",
|
||||
"parts": []map[string]any{{"text": p.Text}},
|
||||
},
|
||||
},
|
||||
},
|
||||
"modelVersion": s.model,
|
||||
})
|
||||
}
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
|
||||
func (s *geminiStreamRuntime) finalize() {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
|
||||
if s.bufferContent {
|
||||
parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
|
||||
s.sendChunk(map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"content": map[string]any{
|
||||
"role": "model",
|
||||
"parts": parts,
|
||||
},
|
||||
},
|
||||
},
|
||||
"modelVersion": s.model,
|
||||
})
|
||||
}
|
||||
|
||||
s.sendChunk(map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"finishReason": "STOP",
|
||||
},
|
||||
},
|
||||
"modelVersion": s.model,
|
||||
"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
|
||||
})
|
||||
}
|
||||
@@ -32,4 +32,3 @@ func TestWriteOpenAIErrorIncludesUnifiedFields(t *testing.T) {
|
||||
t.Fatal("expected param field")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,386 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// writeJSON is a package-internal alias kept to avoid mass-renaming across
|
||||
// every call-site in this file. It delegates to the shared util version.
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
|
||||
leaseMu sync.Mutex
|
||||
streamLeases map[string]streamLease
|
||||
responsesMu sync.Mutex
|
||||
responses *responseStore
|
||||
}
|
||||
|
||||
type streamLease struct {
|
||||
Auth *auth.RequestAuth
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Get("/v1/models", h.ListModels)
|
||||
r.Get("/v1/models/{model_id}", h.GetModel)
|
||||
r.Post("/v1/chat/completions", h.ChatCompletions)
|
||||
r.Post("/v1/responses", h.Responses)
|
||||
r.Get("/v1/responses/{response_id}", h.GetResponseByID)
|
||||
r.Post("/v1/embeddings", h.Embeddings)
|
||||
}
|
||||
|
||||
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
|
||||
}
|
||||
|
||||
func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) {
|
||||
modelID := strings.TrimSpace(chi.URLParam(r, "model_id"))
|
||||
model, ok := config.OpenAIModelByID(h.Store, modelID)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusNotFound, "Model not found.")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, model)
|
||||
}
|
||||
|
||||
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
if isVercelStreamReleaseRequest(r) {
|
||||
h.handleVercelStreamRelease(w, r)
|
||||
return
|
||||
}
|
||||
if isVercelStreamPrepareRequest(r) {
|
||||
h.handleVercelStreamPrepare(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeOpenAIError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
}
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
if stdReq.Stream {
|
||||
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
return
|
||||
}
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
_ = ctx
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
|
||||
finalThinking := result.Thinking
|
||||
finalText := result.Text
|
||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
rc := http.NewResponseController(w)
|
||||
_, canFlush := w.(http.Flusher)
|
||||
if !canFlush {
|
||||
config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
|
||||
}
|
||||
|
||||
created := time.Now().Unix()
|
||||
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
|
||||
streamRuntime := newChatStreamRuntime(
|
||||
w,
|
||||
rc,
|
||||
canFlush,
|
||||
completionID,
|
||||
created,
|
||||
model,
|
||||
finalPrompt,
|
||||
thinkingEnabled,
|
||||
searchEnabled,
|
||||
toolNames,
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
)
|
||||
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: r.Context(),
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
|
||||
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
|
||||
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnKeepAlive: func() {
|
||||
streamRuntime.sendKeepAlive()
|
||||
},
|
||||
OnParsed: streamRuntime.onParsed,
|
||||
OnFinalize: func(reason streamengine.StopReason, _ error) {
|
||||
if string(reason) == "content_filter" {
|
||||
streamRuntime.finalize("content_filter")
|
||||
return
|
||||
}
|
||||
streamRuntime.finalize("stop")
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
|
||||
toolSchemas := make([]string, 0, len(tools))
|
||||
names := make([]string, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
tool, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fn, _ := tool["function"].(map[string]any)
|
||||
if len(fn) == 0 {
|
||||
fn = tool
|
||||
}
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
schema, _ := fn["parameters"].(map[string]any)
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
names = append(names, name)
|
||||
if desc == "" {
|
||||
desc = "No description available"
|
||||
}
|
||||
b, _ := json.Marshal(schema)
|
||||
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
|
||||
}
|
||||
if len(toolSchemas) == 0 {
|
||||
return messages, names
|
||||
}
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
|
||||
|
||||
for i := range messages {
|
||||
if messages[i]["role"] == "system" {
|
||||
old, _ := messages[i]["content"].(string)
|
||||
messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
|
||||
return messages, names
|
||||
}
|
||||
}
|
||||
messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
|
||||
return messages, names
|
||||
}
|
||||
|
||||
func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(deltas))
|
||||
for _, d := range deltas {
|
||||
if d.Name == "" && d.Arguments == "" {
|
||||
continue
|
||||
}
|
||||
callID, ok := ids[d.Index]
|
||||
if !ok || callID == "" {
|
||||
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
ids[d.Index] = callID
|
||||
}
|
||||
item := map[string]any{
|
||||
"index": d.Index,
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
}
|
||||
fn := map[string]any{}
|
||||
if d.Name != "" {
|
||||
fn["name"] = d.Name
|
||||
}
|
||||
if d.Arguments != "" {
|
||||
fn["arguments"] = d.Arguments
|
||||
}
|
||||
if len(fn) > 0 {
|
||||
item["function"] = fn
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for i, c := range calls {
|
||||
callID := ""
|
||||
if ids != nil {
|
||||
callID = strings.TrimSpace(ids[i])
|
||||
}
|
||||
if callID == "" {
|
||||
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
if ids != nil {
|
||||
ids[i] = callID
|
||||
}
|
||||
}
|
||||
args, _ := json.Marshal(c.Input)
|
||||
out = append(out, map[string]any{
|
||||
"index": i,
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": c.Name,
|
||||
"arguments": string(args),
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": openAIErrorType(status),
|
||||
"code": openAIErrorCode(status),
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func openAIErrorType(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "invalid_request_error"
|
||||
case http.StatusUnauthorized:
|
||||
return "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
return "permission_error"
|
||||
case http.StatusTooManyRequests:
|
||||
return "rate_limit_error"
|
||||
case http.StatusServiceUnavailable:
|
||||
return "service_unavailable_error"
|
||||
default:
|
||||
if status >= 500 {
|
||||
return "api_error"
|
||||
}
|
||||
return "invalid_request_error"
|
||||
}
|
||||
}
|
||||
|
||||
func openAIErrorCode(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "invalid_request"
|
||||
case http.StatusUnauthorized:
|
||||
return "authentication_failed"
|
||||
case http.StatusForbidden:
|
||||
return "forbidden"
|
||||
case http.StatusTooManyRequests:
|
||||
return "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
return "not_found"
|
||||
case http.StatusServiceUnavailable:
|
||||
return "service_unavailable"
|
||||
default:
|
||||
if status >= 500 {
|
||||
return "internal_error"
|
||||
}
|
||||
return "invalid_request"
|
||||
}
|
||||
}
|
||||
|
||||
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
|
||||
for k, v := range collectOpenAIChatPassThrough(req) {
|
||||
payload[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) toolcallFeatureMatchEnabled() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode()))
|
||||
return mode == "" || mode == "feature_match"
|
||||
}
|
||||
|
||||
func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence()))
|
||||
return level == "" || level == "high"
|
||||
}
|
||||
156
internal/adapter/openai/handler_chat.go
Normal file
156
internal/adapter/openai/handler_chat.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
)
|
||||
|
||||
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
if isVercelStreamReleaseRequest(r) {
|
||||
h.handleVercelStreamRelease(w, r)
|
||||
return
|
||||
}
|
||||
if isVercelStreamPrepareRequest(r) {
|
||||
h.handleVercelStreamPrepare(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeOpenAIError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
}
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
if stdReq.Stream {
|
||||
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
return
|
||||
}
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
_ = ctx
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
|
||||
finalThinking := result.Thinking
|
||||
finalText := result.Text
|
||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
rc := http.NewResponseController(w)
|
||||
_, canFlush := w.(http.Flusher)
|
||||
if !canFlush {
|
||||
config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
|
||||
}
|
||||
|
||||
created := time.Now().Unix()
|
||||
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
|
||||
streamRuntime := newChatStreamRuntime(
|
||||
w,
|
||||
rc,
|
||||
canFlush,
|
||||
completionID,
|
||||
created,
|
||||
model,
|
||||
finalPrompt,
|
||||
thinkingEnabled,
|
||||
searchEnabled,
|
||||
toolNames,
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
)
|
||||
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: r.Context(),
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
|
||||
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
|
||||
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnKeepAlive: func() {
|
||||
streamRuntime.sendKeepAlive()
|
||||
},
|
||||
OnParsed: streamRuntime.onParsed,
|
||||
OnFinalize: func(reason streamengine.StopReason, _ error) {
|
||||
if string(reason) == "content_filter" {
|
||||
streamRuntime.finalize("content_filter")
|
||||
return
|
||||
}
|
||||
streamRuntime.finalize("stop")
|
||||
},
|
||||
})
|
||||
}
|
||||
56
internal/adapter/openai/handler_errors.go
Normal file
56
internal/adapter/openai/handler_errors.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package openai
|
||||
|
||||
import "net/http"
|
||||
|
||||
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": openAIErrorType(status),
|
||||
"code": openAIErrorCode(status),
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func openAIErrorType(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "invalid_request_error"
|
||||
case http.StatusUnauthorized:
|
||||
return "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
return "permission_error"
|
||||
case http.StatusTooManyRequests:
|
||||
return "rate_limit_error"
|
||||
case http.StatusServiceUnavailable:
|
||||
return "service_unavailable_error"
|
||||
default:
|
||||
if status >= 500 {
|
||||
return "api_error"
|
||||
}
|
||||
return "invalid_request_error"
|
||||
}
|
||||
}
|
||||
|
||||
func openAIErrorCode(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "invalid_request"
|
||||
case http.StatusUnauthorized:
|
||||
return "authentication_failed"
|
||||
case http.StatusForbidden:
|
||||
return "forbidden"
|
||||
case http.StatusTooManyRequests:
|
||||
return "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
return "not_found"
|
||||
case http.StatusServiceUnavailable:
|
||||
return "service_unavailable"
|
||||
default:
|
||||
if status >= 500 {
|
||||
return "internal_error"
|
||||
}
|
||||
return "invalid_request"
|
||||
}
|
||||
}
|
||||
57
internal/adapter/openai/handler_routes.go
Normal file
57
internal/adapter/openai/handler_routes.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package openai

import (
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/go-chi/chi/v5"

	"ds2api/internal/auth"
	"ds2api/internal/config"
	"ds2api/internal/util"
)

// writeJSON is a package-internal alias kept to avoid mass-renaming across
// every call-site in this package.
var writeJSON = util.WriteJSON

// Handler serves the OpenAI-compatible API surface backed by DeepSeek.
type Handler struct {
	Store ConfigReader   // runtime configuration source
	Auth  AuthResolver   // per-request credential resolution
	DS    DeepSeekCaller // upstream DeepSeek client

	leaseMu      sync.Mutex
	streamLeases map[string]streamLease // auth leased to in-flight streams
	responsesMu  sync.Mutex
	responses    *responseStore // stored /v1/responses objects
}

// streamLease pins resolved request auth to a stream until it expires.
type streamLease struct {
	Auth      *auth.RequestAuth
	ExpiresAt time.Time
}

// RegisterRoutes mounts every OpenAI-compatible endpoint on the router.
func RegisterRoutes(r chi.Router, h *Handler) {
	r.Get("/v1/models", h.ListModels)
	r.Get("/v1/models/{model_id}", h.GetModel)
	r.Post("/v1/chat/completions", h.ChatCompletions)
	r.Post("/v1/responses", h.Responses)
	r.Get("/v1/responses/{response_id}", h.GetResponseByID)
	r.Post("/v1/embeddings", h.Embeddings)
}

// ListModels responds with the static OpenAI-style model catalogue.
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
	writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
}

// GetModel looks up one model by its path parameter and returns it, or an
// OpenAI-shaped 404 error when unknown.
func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) {
	requested := strings.TrimSpace(chi.URLParam(r, "model_id"))
	if model, ok := config.OpenAIModelByID(h.Store, requested); ok {
		writeJSON(w, http.StatusOK, model)
		return
	}
	writeOpenAIError(w, http.StatusNotFound, "Model not found.")
}
|
||||
116
internal/adapter/openai/handler_toolcall_format.go
Normal file
116
internal/adapter/openai/handler_toolcall_format.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// injectToolPrompt appends a tool-usage instruction block to the existing
// system message (or prepends a new one) and returns the updated messages
// together with the declared tool names. With no usable tools, messages are
// returned unchanged.
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
	names := make([]string, 0, len(tools))
	schemas := make([]string, 0, len(tools))
	for _, raw := range tools {
		tool, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		// Accept both {"function": {...}} wrappers and bare function objects.
		fn, _ := tool["function"].(map[string]any)
		if len(fn) == 0 {
			fn = tool
		}
		name, _ := fn["name"].(string)
		if name == "" {
			name = "unknown"
		}
		names = append(names, name)
		desc, _ := fn["description"].(string)
		if desc == "" {
			desc = "No description available"
		}
		schema, _ := fn["parameters"].(map[string]any)
		encoded, _ := json.Marshal(schema)
		schemas = append(schemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(encoded)))
	}
	if len(schemas) == 0 {
		return messages, names
	}

	toolPrompt := "You have access to these tools:\n\n" + strings.Join(schemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."

	// Fold the prompt into the first system message when one exists.
	for i := range messages {
		if messages[i]["role"] == "system" {
			existing, _ := messages[i]["content"].(string)
			messages[i]["content"] = strings.TrimSpace(existing + "\n\n" + toolPrompt)
			return messages, names
		}
	}
	withPrompt := append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
	return withPrompt, names
}
|
||||
|
||||
func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(deltas))
|
||||
for _, d := range deltas {
|
||||
if d.Name == "" && d.Arguments == "" {
|
||||
continue
|
||||
}
|
||||
callID, ok := ids[d.Index]
|
||||
if !ok || callID == "" {
|
||||
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
ids[d.Index] = callID
|
||||
}
|
||||
item := map[string]any{
|
||||
"index": d.Index,
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
}
|
||||
fn := map[string]any{}
|
||||
if d.Name != "" {
|
||||
fn["name"] = d.Name
|
||||
}
|
||||
if d.Arguments != "" {
|
||||
fn["arguments"] = d.Arguments
|
||||
}
|
||||
if len(fn) > 0 {
|
||||
item["function"] = fn
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for i, c := range calls {
|
||||
callID := ""
|
||||
if ids != nil {
|
||||
callID = strings.TrimSpace(ids[i])
|
||||
}
|
||||
if callID == "" {
|
||||
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
if ids != nil {
|
||||
ids[i] = callID
|
||||
}
|
||||
}
|
||||
args, _ := json.Marshal(c.Input)
|
||||
out = append(out, map[string]any{
|
||||
"index": i,
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": c.Name,
|
||||
"arguments": string(args),
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
25
internal/adapter/openai/handler_toolcall_policy.go
Normal file
25
internal/adapter/openai/handler_toolcall_policy.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
|
||||
for k, v := range collectOpenAIChatPassThrough(req) {
|
||||
payload[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) toolcallFeatureMatchEnabled() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode()))
|
||||
return mode == "" || mode == "feature_match"
|
||||
}
|
||||
|
||||
func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence()))
|
||||
return level == "" || level == "high"
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -170,264 +169,3 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func responsesMessagesFromRequest(req map[string]any) []any {
|
||||
if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
|
||||
return prependInstructionMessage(msgs, req["instructions"])
|
||||
}
|
||||
if rawInput, ok := req["input"]; ok {
|
||||
if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 {
|
||||
return prependInstructionMessage(msgs, req["instructions"])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// prependInstructionMessage inserts a system message built from instructions
// (when it is a non-blank string) ahead of messages; otherwise messages are
// returned unchanged.
func prependInstructionMessage(messages []any, instructions any) []any {
	text, _ := instructions.(string)
	text = strings.TrimSpace(text)
	if text == "" {
		return messages
	}
	combined := make([]any, 0, len(messages)+1)
	combined = append(combined, map[string]any{"role": "system", "content": text})
	return append(combined, messages...)
}
|
||||
|
||||
func normalizeResponsesInputAsMessages(input any) []any {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return nil
|
||||
}
|
||||
return []any{map[string]any{"role": "user", "content": v}}
|
||||
case []any:
|
||||
return normalizeResponsesInputArray(v)
|
||||
case map[string]any:
|
||||
if msg := normalizeResponsesInputItem(v); msg != nil {
|
||||
return []any{msg}
|
||||
}
|
||||
if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return []any{map[string]any{"role": "user", "content": txt}}
|
||||
}
|
||||
if content, ok := v["content"]; ok {
|
||||
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
|
||||
return []any{map[string]any{"role": "user", "content": content}}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeResponsesInputArray(items []any) []any {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(items))
|
||||
fallbackParts := make([]string, 0, len(items))
|
||||
flushFallback := func() {
|
||||
if len(fallbackParts) == 0 {
|
||||
return
|
||||
}
|
||||
out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
|
||||
fallbackParts = fallbackParts[:0]
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
switch x := item.(type) {
|
||||
case map[string]any:
|
||||
if msg := normalizeResponsesInputItem(x); msg != nil {
|
||||
flushFallback()
|
||||
out = append(out, msg)
|
||||
continue
|
||||
}
|
||||
if s := normalizeResponsesFallbackPart(x); s != "" {
|
||||
fallbackParts = append(fallbackParts, s)
|
||||
}
|
||||
default:
|
||||
if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
|
||||
fallbackParts = append(fallbackParts, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
flushFallback()
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeResponsesInputItem(m map[string]any) map[string]any {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
|
||||
if role != "" {
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
content = txt
|
||||
}
|
||||
}
|
||||
if content == nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"content": content,
|
||||
}
|
||||
}
|
||||
|
||||
itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
|
||||
switch itemType {
|
||||
case "message", "input_message":
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
content = txt
|
||||
}
|
||||
}
|
||||
if content == nil {
|
||||
return nil
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"content": content,
|
||||
}
|
||||
case "function_call_output", "tool_result":
|
||||
content := m["output"]
|
||||
if content == nil {
|
||||
content = m["content"]
|
||||
}
|
||||
if content == nil {
|
||||
content = ""
|
||||
}
|
||||
out := map[string]any{
|
||||
"role": "tool",
|
||||
"content": content,
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
} else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
}
|
||||
if name := strings.TrimSpace(asString(m["name"])); name != "" {
|
||||
out["name"] = name
|
||||
} else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
|
||||
out["name"] = name
|
||||
}
|
||||
return out
|
||||
case "function_call", "tool_call":
|
||||
name := strings.TrimSpace(asString(m["name"]))
|
||||
var fn map[string]any
|
||||
if rawFn, ok := m["function"].(map[string]any); ok {
|
||||
fn = rawFn
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(asString(fn["name"]))
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var argsRaw any
|
||||
if v, ok := m["arguments"]; ok {
|
||||
argsRaw = v
|
||||
} else if v, ok := m["input"]; ok {
|
||||
argsRaw = v
|
||||
}
|
||||
if argsRaw == nil && fn != nil {
|
||||
if v, ok := fn["arguments"]; ok {
|
||||
argsRaw = v
|
||||
} else if v, ok := fn["input"]; ok {
|
||||
argsRaw = v
|
||||
}
|
||||
}
|
||||
|
||||
functionPayload := map[string]any{
|
||||
"name": name,
|
||||
"arguments": stringifyToolCallArguments(argsRaw),
|
||||
}
|
||||
call := map[string]any{
|
||||
"type": "function",
|
||||
"function": functionPayload,
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||
call["id"] = callID
|
||||
} else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
|
||||
call["id"] = callID
|
||||
}
|
||||
return map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{call},
|
||||
}
|
||||
case "input_text":
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": txt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": txt,
|
||||
}
|
||||
}
|
||||
if content, ok := m["content"]; ok {
|
||||
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeResponsesFallbackPart(m map[string]any) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return txt
|
||||
}
|
||||
}
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return txt
|
||||
}
|
||||
if content, ok := m["content"]; ok {
|
||||
if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", m))
|
||||
}
|
||||
|
||||
// stringifyToolCallArguments renders tool-call arguments as a JSON string,
// defaulting to "{}" for nil, blank, or unmarshalable input. String input is
// passed through trimmed; everything else is JSON-encoded.
func stringifyToolCallArguments(v any) string {
	if v == nil {
		return "{}"
	}
	if s, ok := v.(string); ok {
		if trimmed := strings.TrimSpace(s); trimmed != "" {
			return trimmed
		}
		return "{}"
	}
	encoded, err := json.Marshal(v)
	if err != nil || len(encoded) == 0 {
		return "{}"
	}
	return string(encoded)
}
|
||||
|
||||
181
internal/adapter/openai/responses_input_items.go
Normal file
181
internal/adapter/openai/responses_input_items.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func normalizeResponsesInputItem(m map[string]any) map[string]any {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
|
||||
if role != "" {
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
content = txt
|
||||
}
|
||||
}
|
||||
if content == nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"content": content,
|
||||
}
|
||||
}
|
||||
|
||||
itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
|
||||
switch itemType {
|
||||
case "message", "input_message":
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
content = txt
|
||||
}
|
||||
}
|
||||
if content == nil {
|
||||
return nil
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"content": content,
|
||||
}
|
||||
case "function_call_output", "tool_result":
|
||||
content := m["output"]
|
||||
if content == nil {
|
||||
content = m["content"]
|
||||
}
|
||||
if content == nil {
|
||||
content = ""
|
||||
}
|
||||
out := map[string]any{
|
||||
"role": "tool",
|
||||
"content": content,
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
} else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
}
|
||||
if name := strings.TrimSpace(asString(m["name"])); name != "" {
|
||||
out["name"] = name
|
||||
} else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
|
||||
out["name"] = name
|
||||
}
|
||||
return out
|
||||
case "function_call", "tool_call":
|
||||
name := strings.TrimSpace(asString(m["name"]))
|
||||
var fn map[string]any
|
||||
if rawFn, ok := m["function"].(map[string]any); ok {
|
||||
fn = rawFn
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(asString(fn["name"]))
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var argsRaw any
|
||||
if v, ok := m["arguments"]; ok {
|
||||
argsRaw = v
|
||||
} else if v, ok := m["input"]; ok {
|
||||
argsRaw = v
|
||||
}
|
||||
if argsRaw == nil && fn != nil {
|
||||
if v, ok := fn["arguments"]; ok {
|
||||
argsRaw = v
|
||||
} else if v, ok := fn["input"]; ok {
|
||||
argsRaw = v
|
||||
}
|
||||
}
|
||||
|
||||
functionPayload := map[string]any{
|
||||
"name": name,
|
||||
"arguments": stringifyToolCallArguments(argsRaw),
|
||||
}
|
||||
call := map[string]any{
|
||||
"type": "function",
|
||||
"function": functionPayload,
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||
call["id"] = callID
|
||||
} else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
|
||||
call["id"] = callID
|
||||
}
|
||||
return map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{call},
|
||||
}
|
||||
case "input_text":
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": txt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": txt,
|
||||
}
|
||||
}
|
||||
if content, ok := m["content"]; ok {
|
||||
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeResponsesFallbackPart(m map[string]any) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return txt
|
||||
}
|
||||
}
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return txt
|
||||
}
|
||||
if content, ok := m["content"]; ok {
|
||||
if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", m))
|
||||
}
|
||||
|
||||
// stringifyToolCallArguments renders tool-call arguments as a JSON string,
// defaulting to "{}" for nil, blank, or unmarshalable input. String input is
// passed through trimmed; everything else is JSON-encoded.
func stringifyToolCallArguments(v any) string {
	if v == nil {
		return "{}"
	}
	if s, ok := v.(string); ok {
		if trimmed := strings.TrimSpace(s); trimmed != "" {
			return trimmed
		}
		return "{}"
	}
	encoded, err := json.Marshal(v)
	if err != nil || len(encoded) == 0 {
		return "{}"
	}
	return string(encoded)
}
|
||||
93
internal/adapter/openai/responses_input_normalize.go
Normal file
93
internal/adapter/openai/responses_input_normalize.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func responsesMessagesFromRequest(req map[string]any) []any {
|
||||
if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
|
||||
return prependInstructionMessage(msgs, req["instructions"])
|
||||
}
|
||||
if rawInput, ok := req["input"]; ok {
|
||||
if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 {
|
||||
return prependInstructionMessage(msgs, req["instructions"])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// prependInstructionMessage inserts a system message built from instructions
// (when it is a non-blank string) ahead of messages; otherwise messages are
// returned unchanged.
func prependInstructionMessage(messages []any, instructions any) []any {
	text, _ := instructions.(string)
	text = strings.TrimSpace(text)
	if text == "" {
		return messages
	}
	combined := make([]any, 0, len(messages)+1)
	combined = append(combined, map[string]any{"role": "system", "content": text})
	return append(combined, messages...)
}
|
||||
|
||||
func normalizeResponsesInputAsMessages(input any) []any {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return nil
|
||||
}
|
||||
return []any{map[string]any{"role": "user", "content": v}}
|
||||
case []any:
|
||||
return normalizeResponsesInputArray(v)
|
||||
case map[string]any:
|
||||
if msg := normalizeResponsesInputItem(v); msg != nil {
|
||||
return []any{msg}
|
||||
}
|
||||
if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return []any{map[string]any{"role": "user", "content": txt}}
|
||||
}
|
||||
if content, ok := v["content"]; ok {
|
||||
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
|
||||
return []any{map[string]any{"role": "user", "content": content}}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeResponsesInputArray(items []any) []any {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(items))
|
||||
fallbackParts := make([]string, 0, len(items))
|
||||
flushFallback := func() {
|
||||
if len(fallbackParts) == 0 {
|
||||
return
|
||||
}
|
||||
out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
|
||||
fallbackParts = fallbackParts[:0]
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
switch x := item.(type) {
|
||||
case map[string]any:
|
||||
if msg := normalizeResponsesInputItem(x); msg != nil {
|
||||
flushFallback()
|
||||
out = append(out, msg)
|
||||
continue
|
||||
}
|
||||
if s := normalizeResponsesFallbackPart(x); s != "" {
|
||||
fallbackParts = append(fallbackParts, s)
|
||||
}
|
||||
default:
|
||||
if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
|
||||
fallbackParts = append(fallbackParts, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
flushFallback()
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,366 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type responsesStreamRuntime struct {
|
||||
w http.ResponseWriter
|
||||
rc *http.ResponseController
|
||||
canFlush bool
|
||||
|
||||
responseID string
|
||||
model string
|
||||
finalPrompt string
|
||||
toolNames []string
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
|
||||
bufferToolContent bool
|
||||
emitEarlyToolDeltas bool
|
||||
toolCallsEmitted bool
|
||||
toolCallsDoneEmitted bool
|
||||
|
||||
sieve toolStreamSieveState
|
||||
thinkingSieve toolStreamSieveState
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
streamToolCallIDs map[int]string
|
||||
streamFunctionIDs map[int]string
|
||||
functionDone map[int]bool
|
||||
toolCallsDoneSigs map[string]bool
|
||||
reasoningItemID string
|
||||
|
||||
persistResponse func(obj map[string]any)
|
||||
}
|
||||
|
||||
func newResponsesStreamRuntime(
|
||||
w http.ResponseWriter,
|
||||
rc *http.ResponseController,
|
||||
canFlush bool,
|
||||
responseID string,
|
||||
model string,
|
||||
finalPrompt string,
|
||||
thinkingEnabled bool,
|
||||
searchEnabled bool,
|
||||
toolNames []string,
|
||||
bufferToolContent bool,
|
||||
emitEarlyToolDeltas bool,
|
||||
persistResponse func(obj map[string]any),
|
||||
) *responsesStreamRuntime {
|
||||
return &responsesStreamRuntime{
|
||||
w: w,
|
||||
rc: rc,
|
||||
canFlush: canFlush,
|
||||
responseID: responseID,
|
||||
model: model,
|
||||
finalPrompt: finalPrompt,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
toolNames: toolNames,
|
||||
bufferToolContent: bufferToolContent,
|
||||
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
||||
streamToolCallIDs: map[int]string{},
|
||||
streamFunctionIDs: map[int]string{},
|
||||
functionDone: map[int]bool{},
|
||||
toolCallsDoneSigs: map[string]bool{},
|
||||
persistResponse: persistResponse,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) {
|
||||
b, _ := json.Marshal(payload)
|
||||
_, _ = s.w.Write([]byte("event: " + event + "\n"))
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendCreated() {
|
||||
s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model))
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendDone() {
|
||||
_, _ = s.w.Write([]byte("data: [DONE]\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) finalize() {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
|
||||
}
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
|
||||
}
|
||||
// Compatibility fallback: some streams only emit incremental tool deltas.
|
||||
// Ensure final function_call_arguments.done is emitted at least once.
|
||||
if s.toolCallsEmitted {
|
||||
detected := util.ParseToolCalls(finalText, s.toolNames)
|
||||
if len(detected) == 0 {
|
||||
detected = util.ParseToolCalls(finalThinking, s.toolNames)
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
if !s.toolCallsDoneEmitted {
|
||||
s.emitToolCallsDone(detected)
|
||||
} else {
|
||||
s.emitFunctionCallDoneEvents(detected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
|
||||
if s.toolCallsEmitted {
|
||||
s.alignCompletedOutputCallIDs(obj)
|
||||
}
|
||||
if s.toolCallsEmitted {
|
||||
obj["status"] = "completed"
|
||||
}
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(obj)
|
||||
}
|
||||
s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj))
|
||||
s.sendDone()
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
|
||||
contentSeen := false
|
||||
for _, p := range parsed.Parts {
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
contentSeen = true
|
||||
if p.Type == "thinking" {
|
||||
if !s.thinkingEnabled {
|
||||
continue
|
||||
}
|
||||
s.thinking.WriteString(p.Text)
|
||||
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
|
||||
s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
s.text.WriteString(p.Text)
|
||||
if !s.bufferToolContent {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
|
||||
continue
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
|
||||
for _, evt := range events {
|
||||
if emitContent && evt.Content != "" {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
|
||||
}
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !s.emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
continue
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
|
||||
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.emitToolCallsDone(evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
|
||||
if len(calls) == 0 {
|
||||
return
|
||||
}
|
||||
sig := toolCallListSignature(calls)
|
||||
if sig != "" && s.toolCallsDoneSigs[sig] {
|
||||
return
|
||||
}
|
||||
if sig != "" {
|
||||
s.toolCallsDoneSigs[sig] = true
|
||||
}
|
||||
formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
return
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
|
||||
s.emitFunctionCallDoneEvents(calls)
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureReasoningItemID() string {
|
||||
if strings.TrimSpace(s.reasoningItemID) != "" {
|
||||
return s.reasoningItemID
|
||||
}
|
||||
s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
return s.reasoningItemID
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string {
|
||||
if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" {
|
||||
return id
|
||||
}
|
||||
id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
s.streamFunctionIDs[index] = id
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureToolCallID(index int) string {
|
||||
if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" {
|
||||
return id
|
||||
}
|
||||
id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
s.streamToolCallIDs[index] = id
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
|
||||
if strings.TrimSpace(s.thinking.String()) != "" {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
|
||||
for _, d := range deltas {
|
||||
if strings.TrimSpace(d.Arguments) == "" {
|
||||
continue
|
||||
}
|
||||
outputIndex := s.functionOutputBaseIndex() + d.Index
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
callID := s.ensureToolCallID(d.Index)
|
||||
s.sendEvent(
|
||||
"response.function_call_arguments.delta",
|
||||
openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) {
|
||||
base := s.functionOutputBaseIndex()
|
||||
for idx, tc := range calls {
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
outputIndex := base + idx
|
||||
if s.functionDone[outputIndex] {
|
||||
continue
|
||||
}
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
callID := s.ensureToolCallID(idx)
|
||||
argsBytes, _ := json.Marshal(tc.Input)
|
||||
s.sendEvent(
|
||||
"response.function_call_arguments.done",
|
||||
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)),
|
||||
)
|
||||
s.functionDone[outputIndex] = true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) {
|
||||
if obj == nil || len(s.streamToolCallIDs) == 0 {
|
||||
return
|
||||
}
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) == 0 {
|
||||
return
|
||||
}
|
||||
indices := make([]int, 0, len(s.streamToolCallIDs))
|
||||
for idx := range s.streamToolCallIDs {
|
||||
indices = append(indices, idx)
|
||||
}
|
||||
sort.Ints(indices)
|
||||
ordered := make([]string, 0, len(indices))
|
||||
for _, idx := range indices {
|
||||
id := strings.TrimSpace(s.streamToolCallIDs[idx])
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
ordered = append(ordered, id)
|
||||
}
|
||||
if len(ordered) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
functionIdx := 0
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
typ, _ := m["type"].(string)
|
||||
switch typ {
|
||||
case "function_call":
|
||||
if functionIdx < len(ordered) {
|
||||
m["call_id"] = ordered[functionIdx]
|
||||
functionIdx++
|
||||
}
|
||||
case "tool_calls":
|
||||
tcArr, _ := m["tool_calls"].([]any)
|
||||
for i, raw := range tcArr {
|
||||
tc, _ := raw.(map[string]any)
|
||||
if tc == nil {
|
||||
continue
|
||||
}
|
||||
if i < len(ordered) {
|
||||
tc["id"] = ordered[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toolCallListSignature(calls []util.ParsedToolCall) string {
|
||||
if len(calls) == 0 {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for i, tc := range calls {
|
||||
if i > 0 {
|
||||
b.WriteString("|")
|
||||
}
|
||||
b.WriteString(strings.TrimSpace(tc.Name))
|
||||
b.WriteString(":")
|
||||
args, _ := json.Marshal(tc.Input)
|
||||
b.Write(args)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
157
internal/adapter/openai/responses_stream_runtime_core.go
Normal file
157
internal/adapter/openai/responses_stream_runtime_core.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type responsesStreamRuntime struct {
|
||||
w http.ResponseWriter
|
||||
rc *http.ResponseController
|
||||
canFlush bool
|
||||
|
||||
responseID string
|
||||
model string
|
||||
finalPrompt string
|
||||
toolNames []string
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
|
||||
bufferToolContent bool
|
||||
emitEarlyToolDeltas bool
|
||||
toolCallsEmitted bool
|
||||
toolCallsDoneEmitted bool
|
||||
|
||||
sieve toolStreamSieveState
|
||||
thinkingSieve toolStreamSieveState
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
streamToolCallIDs map[int]string
|
||||
streamFunctionIDs map[int]string
|
||||
functionDone map[int]bool
|
||||
toolCallsDoneSigs map[string]bool
|
||||
reasoningItemID string
|
||||
|
||||
persistResponse func(obj map[string]any)
|
||||
}
|
||||
|
||||
func newResponsesStreamRuntime(
|
||||
w http.ResponseWriter,
|
||||
rc *http.ResponseController,
|
||||
canFlush bool,
|
||||
responseID string,
|
||||
model string,
|
||||
finalPrompt string,
|
||||
thinkingEnabled bool,
|
||||
searchEnabled bool,
|
||||
toolNames []string,
|
||||
bufferToolContent bool,
|
||||
emitEarlyToolDeltas bool,
|
||||
persistResponse func(obj map[string]any),
|
||||
) *responsesStreamRuntime {
|
||||
return &responsesStreamRuntime{
|
||||
w: w,
|
||||
rc: rc,
|
||||
canFlush: canFlush,
|
||||
responseID: responseID,
|
||||
model: model,
|
||||
finalPrompt: finalPrompt,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
toolNames: toolNames,
|
||||
bufferToolContent: bufferToolContent,
|
||||
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
||||
streamToolCallIDs: map[int]string{},
|
||||
streamFunctionIDs: map[int]string{},
|
||||
functionDone: map[int]bool{},
|
||||
toolCallsDoneSigs: map[string]bool{},
|
||||
persistResponse: persistResponse,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) finalize() {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
|
||||
}
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
|
||||
}
|
||||
// Compatibility fallback: some streams only emit incremental tool deltas.
|
||||
// Ensure final function_call_arguments.done is emitted at least once.
|
||||
if s.toolCallsEmitted {
|
||||
detected := util.ParseToolCalls(finalText, s.toolNames)
|
||||
if len(detected) == 0 {
|
||||
detected = util.ParseToolCalls(finalThinking, s.toolNames)
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
if !s.toolCallsDoneEmitted {
|
||||
s.emitToolCallsDone(detected)
|
||||
} else {
|
||||
s.emitFunctionCallDoneEvents(detected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
|
||||
if s.toolCallsEmitted {
|
||||
s.alignCompletedOutputCallIDs(obj)
|
||||
}
|
||||
if s.toolCallsEmitted {
|
||||
obj["status"] = "completed"
|
||||
}
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(obj)
|
||||
}
|
||||
s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj))
|
||||
s.sendDone()
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
|
||||
contentSeen := false
|
||||
for _, p := range parsed.Parts {
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
contentSeen = true
|
||||
if p.Type == "thinking" {
|
||||
if !s.thinkingEnabled {
|
||||
continue
|
||||
}
|
||||
s.thinking.WriteString(p.Text)
|
||||
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
|
||||
s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
s.text.WriteString(p.Text)
|
||||
if !s.bufferToolContent {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
|
||||
continue
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
52
internal/adapter/openai/responses_stream_runtime_events.go
Normal file
52
internal/adapter/openai/responses_stream_runtime_events.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
)
|
||||
|
||||
func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) {
|
||||
b, _ := json.Marshal(payload)
|
||||
_, _ = s.w.Write([]byte("event: " + event + "\n"))
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendCreated() {
|
||||
s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model))
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendDone() {
|
||||
_, _ = s.w.Write([]byte("data: [DONE]\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
|
||||
for _, evt := range events {
|
||||
if emitContent && evt.Content != "" {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
|
||||
}
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !s.emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
continue
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
|
||||
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.emitToolCallsDone(evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
}
|
||||
172
internal/adapter/openai/responses_stream_runtime_toolcalls.go
Normal file
172
internal/adapter/openai/responses_stream_runtime_toolcalls.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/util"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
|
||||
if len(calls) == 0 {
|
||||
return
|
||||
}
|
||||
sig := toolCallListSignature(calls)
|
||||
if sig != "" && s.toolCallsDoneSigs[sig] {
|
||||
return
|
||||
}
|
||||
if sig != "" {
|
||||
s.toolCallsDoneSigs[sig] = true
|
||||
}
|
||||
formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
return
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
|
||||
s.emitFunctionCallDoneEvents(calls)
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureReasoningItemID() string {
|
||||
if strings.TrimSpace(s.reasoningItemID) != "" {
|
||||
return s.reasoningItemID
|
||||
}
|
||||
s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
return s.reasoningItemID
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string {
|
||||
if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" {
|
||||
return id
|
||||
}
|
||||
id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
s.streamFunctionIDs[index] = id
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureToolCallID(index int) string {
|
||||
if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" {
|
||||
return id
|
||||
}
|
||||
id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
s.streamToolCallIDs[index] = id
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
|
||||
if strings.TrimSpace(s.thinking.String()) != "" {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
|
||||
for _, d := range deltas {
|
||||
if strings.TrimSpace(d.Arguments) == "" {
|
||||
continue
|
||||
}
|
||||
outputIndex := s.functionOutputBaseIndex() + d.Index
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
callID := s.ensureToolCallID(d.Index)
|
||||
s.sendEvent(
|
||||
"response.function_call_arguments.delta",
|
||||
openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) {
|
||||
base := s.functionOutputBaseIndex()
|
||||
for idx, tc := range calls {
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
outputIndex := base + idx
|
||||
if s.functionDone[outputIndex] {
|
||||
continue
|
||||
}
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
callID := s.ensureToolCallID(idx)
|
||||
argsBytes, _ := json.Marshal(tc.Input)
|
||||
s.sendEvent(
|
||||
"response.function_call_arguments.done",
|
||||
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)),
|
||||
)
|
||||
s.functionDone[outputIndex] = true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) {
|
||||
if obj == nil || len(s.streamToolCallIDs) == 0 {
|
||||
return
|
||||
}
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) == 0 {
|
||||
return
|
||||
}
|
||||
indices := make([]int, 0, len(s.streamToolCallIDs))
|
||||
for idx := range s.streamToolCallIDs {
|
||||
indices = append(indices, idx)
|
||||
}
|
||||
sort.Ints(indices)
|
||||
ordered := make([]string, 0, len(indices))
|
||||
for _, idx := range indices {
|
||||
id := strings.TrimSpace(s.streamToolCallIDs[idx])
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
ordered = append(ordered, id)
|
||||
}
|
||||
if len(ordered) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
functionIdx := 0
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
typ, _ := m["type"].(string)
|
||||
switch typ {
|
||||
case "function_call":
|
||||
if functionIdx < len(ordered) {
|
||||
m["call_id"] = ordered[functionIdx]
|
||||
functionIdx++
|
||||
}
|
||||
case "tool_calls":
|
||||
tcArr, _ := m["tool_calls"].([]any)
|
||||
for i, raw := range tcArr {
|
||||
tc, _ := raw.(map[string]any)
|
||||
if tc == nil {
|
||||
continue
|
||||
}
|
||||
if i < len(ordered) {
|
||||
tc["id"] = ordered[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toolCallListSignature(calls []util.ParsedToolCall) string {
|
||||
if len(calls) == 0 {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for i, tc := range calls {
|
||||
if i > 0 {
|
||||
b.WriteString("|")
|
||||
}
|
||||
b.WriteString(strings.TrimSpace(tc.Name))
|
||||
b.WriteString(":")
|
||||
args, _ := json.Marshal(tc.Input)
|
||||
b.Write(args)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
@@ -1,713 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type toolStreamSieveState struct {
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
recentTextTail string
|
||||
disableDeltas bool
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
toolArgsSent int
|
||||
toolArgsString bool
|
||||
toolArgsDone bool
|
||||
}
|
||||
|
||||
type toolStreamEvent struct {
|
||||
Content string
|
||||
ToolCalls []util.ParsedToolCall
|
||||
ToolCallDeltas []toolCallDelta
|
||||
}
|
||||
|
||||
type toolCallDelta struct {
|
||||
Index int
|
||||
Name string
|
||||
Arguments string
|
||||
}
|
||||
|
||||
const toolSieveCaptureLimit = 8 * 1024
|
||||
const toolSieveContextTailLimit = 256
|
||||
|
||||
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
||||
s.disableDeltas = false
|
||||
s.toolNameSent = false
|
||||
s.toolName = ""
|
||||
s.toolArgsStart = -1
|
||||
s.toolArgsSent = -1
|
||||
s.toolArgsString = false
|
||||
s.toolArgsDone = false
|
||||
}
|
||||
|
||||
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
if chunk != "" {
|
||||
state.pending.WriteString(chunk)
|
||||
}
|
||||
events := make([]toolStreamEvent, 0, 2)
|
||||
|
||||
for {
|
||||
if state.capturing {
|
||||
if state.pending.Len() > 0 {
|
||||
state.capture.WriteString(state.pending.String())
|
||||
state.pending.Reset()
|
||||
}
|
||||
if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
|
||||
}
|
||||
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
|
||||
if !ready {
|
||||
if state.capture.Len() > toolSieveCaptureLimit {
|
||||
content := state.capture.String()
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
if prefix != "" {
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: calls})
|
||||
}
|
||||
if suffix != "" {
|
||||
state.pending.WriteString(suffix)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
pending := state.pending.String()
|
||||
if pending == "" {
|
||||
break
|
||||
}
|
||||
start := findToolSegmentStart(pending)
|
||||
if start >= 0 {
|
||||
prefix := pending[:start]
|
||||
if prefix != "" {
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.capture.WriteString(pending[start:])
|
||||
state.capturing = true
|
||||
state.resetIncrementalToolState()
|
||||
continue
|
||||
}
|
||||
|
||||
safe, hold := splitSafeContentForToolDetection(pending)
|
||||
if safe == "" {
|
||||
break
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.pending.WriteString(hold)
|
||||
state.noteText(safe)
|
||||
events = append(events, toolStreamEvent{Content: safe})
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
events := processToolSieveChunk(state, "", toolNames)
|
||||
if state.capturing {
|
||||
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
|
||||
if ready {
|
||||
if consumedPrefix != "" {
|
||||
state.noteText(consumedPrefix)
|
||||
events = append(events, toolStreamEvent{Content: consumedPrefix})
|
||||
}
|
||||
if len(consumedCalls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
|
||||
}
|
||||
if consumedSuffix != "" {
|
||||
state.noteText(consumedSuffix)
|
||||
events = append(events, toolStreamEvent{Content: consumedSuffix})
|
||||
}
|
||||
} else {
|
||||
content := state.capture.String()
|
||||
if content != "" {
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
}
|
||||
}
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
}
|
||||
if state.pending.Len() > 0 {
|
||||
content := state.pending.String()
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
state.pending.Reset()
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
func splitSafeContentForToolDetection(s string) (safe, hold string) {
|
||||
if s == "" {
|
||||
return "", ""
|
||||
}
|
||||
suspiciousStart := findSuspiciousPrefixStart(s)
|
||||
if suspiciousStart < 0 {
|
||||
return s, ""
|
||||
}
|
||||
if suspiciousStart > 0 {
|
||||
return s[:suspiciousStart], s[suspiciousStart:]
|
||||
}
|
||||
// If suspicious content starts at position 0, keep holding until we can
|
||||
// parse a complete tool JSON block or reach stream flush.
|
||||
return "", s
|
||||
}
|
||||
|
||||
func findSuspiciousPrefixStart(s string) int {
|
||||
start := -1
|
||||
indices := []int{
|
||||
strings.LastIndex(s, "{"),
|
||||
strings.LastIndex(s, "["),
|
||||
strings.LastIndex(s, "```"),
|
||||
}
|
||||
for _, idx := range indices {
|
||||
if idx > start {
|
||||
start = idx
|
||||
}
|
||||
}
|
||||
return start
|
||||
}
|
||||
|
||||
func findToolSegmentStart(s string) int {
|
||||
if s == "" {
|
||||
return -1
|
||||
}
|
||||
lower := strings.ToLower(s)
|
||||
offset := 0
|
||||
for {
|
||||
keyRel := strings.Index(lower[offset:], "tool_calls")
|
||||
if keyRel < 0 {
|
||||
return -1
|
||||
}
|
||||
keyIdx := offset + keyRel
|
||||
start := strings.LastIndex(s[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
start = keyIdx
|
||||
}
|
||||
if !insideCodeFence(s[:start]) {
|
||||
return start
|
||||
}
|
||||
offset = keyIdx + len("tool_calls")
|
||||
}
|
||||
}
|
||||
|
||||
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
||||
captured := state.capture.String()
|
||||
if captured == "" {
|
||||
return "", nil, "", false
|
||||
}
|
||||
lower := strings.ToLower(captured)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
if keyIdx < 0 {
|
||||
return "", nil, "", false
|
||||
}
|
||||
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
return "", nil, "", false
|
||||
}
|
||||
obj, end, ok := extractJSONObjectFrom(captured, start)
|
||||
if !ok {
|
||||
return "", nil, "", false
|
||||
}
|
||||
prefixPart := captured[:start]
|
||||
suffixPart := captured[end:]
|
||||
if insideCodeFence(state.recentTextTail + prefixPart) {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
|
||||
if len(parsed) == 0 {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
return prefixPart, parsed, suffixPart, true
|
||||
}
|
||||
|
||||
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
if start < 0 || start >= len(text) || text[start] != '{' {
|
||||
return "", 0, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
end := i + 1
|
||||
return text[start:end], end, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
|
||||
if state.disableDeltas {
|
||||
return nil
|
||||
}
|
||||
captured := state.capture.String()
|
||||
if captured == "" {
|
||||
return nil
|
||||
}
|
||||
lower := strings.ToLower(captured)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
if keyIdx < 0 {
|
||||
return nil
|
||||
}
|
||||
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
return nil
|
||||
}
|
||||
if insideCodeFence(state.recentTextTail + captured[:start]) {
|
||||
return nil
|
||||
}
|
||||
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
|
||||
if hasMultiple {
|
||||
state.disableDeltas = true
|
||||
return nil
|
||||
}
|
||||
if !certainSingle {
|
||||
// In uncertain phases (e.g. first call arrived but array not closed yet),
|
||||
// avoid speculative deltas and wait for final parsed tool_calls payload.
|
||||
return nil
|
||||
}
|
||||
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
deltas := make([]toolCallDelta, 0, 2)
|
||||
if state.toolName == "" {
|
||||
name, ok := extractToolCallName(captured, callStart)
|
||||
if !ok || name == "" {
|
||||
return nil
|
||||
}
|
||||
state.toolName = name
|
||||
}
|
||||
if state.toolArgsStart < 0 {
|
||||
argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
|
||||
if ok {
|
||||
state.toolArgsString = stringMode
|
||||
if stringMode {
|
||||
state.toolArgsStart = argsStart + 1
|
||||
} else {
|
||||
state.toolArgsStart = argsStart
|
||||
}
|
||||
state.toolArgsSent = state.toolArgsStart
|
||||
}
|
||||
}
|
||||
if !state.toolNameSent {
|
||||
if state.toolArgsStart < 0 {
|
||||
return nil
|
||||
}
|
||||
state.toolNameSent = true
|
||||
deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
|
||||
}
|
||||
if state.toolArgsStart < 0 || state.toolArgsDone {
|
||||
return deltas
|
||||
}
|
||||
end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
|
||||
if !ok {
|
||||
return deltas
|
||||
}
|
||||
if end > state.toolArgsSent {
|
||||
deltas = append(deltas, toolCallDelta{
|
||||
Index: 0,
|
||||
Arguments: captured[state.toolArgsSent:end],
|
||||
})
|
||||
state.toolArgsSent = end
|
||||
}
|
||||
if complete {
|
||||
state.toolArgsDone = true
|
||||
}
|
||||
return deltas
|
||||
}
|
||||
|
||||
func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
|
||||
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
i := skipSpaces(text, arrStart+1)
|
||||
if i >= len(text) || text[i] != '{' {
|
||||
return false, false
|
||||
}
|
||||
count := 0
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for ; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
if depth == 0 {
|
||||
count++
|
||||
if count > 1 {
|
||||
return false, true
|
||||
}
|
||||
}
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
if depth > 0 {
|
||||
depth--
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == ',' && depth == 0 {
|
||||
// top-level separator means at least one more tool call exists
|
||||
// (or is expected). Treat as multi-call and stop incremental deltas.
|
||||
return false, true
|
||||
}
|
||||
if ch == ']' && depth == 0 {
|
||||
return count == 1, false
|
||||
}
|
||||
}
|
||||
// array not closed yet: still uncertain whether more calls will appear
|
||||
return false, false
|
||||
}
|
||||
|
||||
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
|
||||
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
||||
if !ok {
|
||||
return -1, false
|
||||
}
|
||||
i := skipSpaces(text, arrStart+1)
|
||||
if i >= len(text) || text[i] != '{' {
|
||||
return -1, false
|
||||
}
|
||||
return i, true
|
||||
}
|
||||
|
||||
func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
|
||||
i := keyIdx + len("tool_calls")
|
||||
for i < len(text) && text[i] != ':' {
|
||||
i++
|
||||
}
|
||||
if i >= len(text) {
|
||||
return -1, false
|
||||
}
|
||||
i = skipSpaces(text, i+1)
|
||||
if i >= len(text) || text[i] != '[' {
|
||||
return -1, false
|
||||
}
|
||||
return i, true
|
||||
}
|
||||
|
||||
func extractToolCallName(text string, callStart int) (string, bool) {
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
||||
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
||||
if !fnOK {
|
||||
return "", false
|
||||
}
|
||||
valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
name, _, ok := parseJSONStringLiteral(text, valueStart)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return name, true
|
||||
}
|
||||
|
||||
func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
|
||||
keys := []string{"input", "arguments", "args", "parameters", "params"}
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, keys)
|
||||
if !ok {
|
||||
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
||||
if !fnOK {
|
||||
return -1, false, false
|
||||
}
|
||||
valueStart, ok = findObjectFieldValueStart(text, fnStart, keys)
|
||||
if !ok {
|
||||
return -1, false, false
|
||||
}
|
||||
}
|
||||
if valueStart >= len(text) {
|
||||
return -1, false, false
|
||||
}
|
||||
ch := text[valueStart]
|
||||
if ch == '{' || ch == '[' {
|
||||
return valueStart, false, true
|
||||
}
|
||||
if ch == '"' {
|
||||
return valueStart, true, true
|
||||
}
|
||||
return -1, false, false
|
||||
}
|
||||
|
||||
func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
|
||||
if start < 0 || start > len(text) {
|
||||
return 0, false, false
|
||||
}
|
||||
if stringMode {
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
return i, true, true
|
||||
}
|
||||
}
|
||||
return len(text), false, true
|
||||
}
|
||||
if start >= len(text) {
|
||||
return start, false, false
|
||||
}
|
||||
if text[start] != '{' && text[start] != '[' {
|
||||
return 0, false, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' || ch == '[' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' || ch == ']' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return i + 1, true, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return len(text), false, true
|
||||
}
|
||||
|
||||
// findObjectFieldValueStart scans the JSON object starting at objStart (which
// must be '{') and returns the index of the value belonging to the first
// top-level field whose name is listed in keys. Nested objects are skipped;
// quoted sections honor backslash escapes. Returns false when no such field
// is visible yet — the object may still be streaming in.
func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
	if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
		return 0, false
	}
	depth := 0
	quote := byte(0)
	escaped := false
	for i := objStart; i < len(text); i++ {
		ch := text[i]
		if quote != 0 {
			// Inside a quoted value: only track escapes and the closing quote.
			if escaped {
				escaped = false
				continue
			}
			if ch == '\\' {
				escaped = true
				continue
			}
			if ch == quote {
				quote = 0
			}
			continue
		}
		if ch == '"' || ch == '\'' {
			if depth == 1 {
				// Top level of the target object: this string may be a field name.
				key, end, ok := parseJSONStringLiteral(text, i)
				if !ok {
					return 0, false
				}
				j := skipSpaces(text, end)
				if j >= len(text) || text[j] != ':' {
					// No colon follows — it was an ordinary string value.
					i = end - 1
					continue
				}
				j = skipSpaces(text, j+1)
				if j >= len(text) {
					// Colon seen but the value has not streamed in yet.
					return 0, false
				}
				if containsKey(keys, key) {
					return j, true
				}
				// Wrong key: resume scanning at its value.
				i = j - 1
				continue
			}
			quote = ch
			continue
		}
		if ch == '{' {
			depth++
			continue
		}
		if ch == '}' {
			depth--
			if depth == 0 {
				// Target object closed without a matching field.
				break
			}
		}
	}
	return 0, false
}
|
||||
|
||||
// findFunctionObjectStart locates the OpenAI-style nested "function" object
// inside the tool-call object starting at callStart, returning the index of
// its opening '{'.
func findFunctionObjectStart(text string, callStart int) (int, bool) {
	valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"})
	if !ok || valueStart >= len(text) || text[valueStart] != '{' {
		return -1, false
	}
	return valueStart, true
}
|
||||
|
||||
// parseJSONStringLiteral reads a double-quoted string starting at start and
// returns its content, the index just past the closing quote, and whether a
// complete literal was found. Escapes are handled by dropping the backslash
// and keeping the following byte verbatim, which is sufficient for
// identifier-style keys.
func parseJSONStringLiteral(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '"' {
		return "", 0, false
	}
	var out strings.Builder
	pos := start + 1
	for pos < len(text) {
		switch c := text[pos]; {
		case c == '\\' && pos+1 < len(text):
			// Keep the escaped byte, drop the backslash.
			out.WriteByte(text[pos+1])
			pos += 2
		case c == '"':
			return out.String(), pos + 1, true
		default:
			out.WriteByte(c)
			pos++
		}
	}
	// Unterminated literal.
	return "", 0, false
}
|
||||
|
||||
// containsKey reports whether value appears in keys.
func containsKey(keys []string, value string) bool {
	for i := range keys {
		if keys[i] == value {
			return true
		}
	}
	return false
}
|
||||
|
||||
// skipSpaces returns the index of the first byte at or after i that is not
// an ASCII space, tab, CR, or LF (len(text) when the rest is whitespace).
func skipSpaces(text string, i int) int {
	for i < len(text) && (text[i] == ' ' || text[i] == '\t' || text[i] == '\n' || text[i] == '\r') {
		i++
	}
	return i
}
|
||||
|
||||
// noteText records emitted plain-text output in the rolling context tail so
// later code-fence checks can see text that was already flushed.
// Whitespace-only chunks are ignored.
func (s *toolStreamSieveState) noteText(content string) {
	if strings.TrimSpace(content) == "" {
		return
	}
	s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
}
|
||||
|
||||
// appendTail concatenates prev and next, keeping only the last max bytes.
// A non-positive max yields the empty string.
func appendTail(prev, next string, max int) string {
	if max <= 0 {
		return ""
	}
	joined := prev + next
	if overflow := len(joined) - max; overflow > 0 {
		joined = joined[overflow:]
	}
	return joined
}
|
||||
|
||||
// looksLikeToolExampleContext reports whether the preceding text leaves us
// inside a Markdown code fence, in which case tool-call-shaped JSON is
// treated as an illustrative example rather than a real invocation.
func looksLikeToolExampleContext(text string) bool {
	return insideCodeFence(text)
}
|
||||
|
||||
// insideCodeFence reports whether text ends inside an unterminated ``` code
// fence, i.e. contains an odd number of fence markers.
func insideCodeFence(text string) bool {
	return text != "" && strings.Count(text, "```")%2 != 0
}
|
||||
208
internal/adapter/openai/tool_sieve_core.go
Normal file
208
internal/adapter/openai/tool_sieve_core.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// processToolSieveChunk feeds one chunk of streamed model output through the
// tool-call sieve. Plain text is surfaced as Content events, while anything
// that might be a tool_calls JSON payload is buffered (and, when provably a
// single call, surfaced early as ToolCallDeltas) until it can be parsed or
// ruled out. Call with an empty chunk to re-drive buffered state.
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
	if state == nil {
		return nil
	}
	if chunk != "" {
		state.pending.WriteString(chunk)
	}
	events := make([]toolStreamEvent, 0, 2)

	for {
		if state.capturing {
			// Capture mode: accumulate everything until the suspected
			// tool-call JSON either parses or exceeds the size cap.
			if state.pending.Len() > 0 {
				state.capture.WriteString(state.pending.String())
				state.pending.Reset()
			}
			if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
				events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
			}
			prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
			if !ready {
				if state.capture.Len() > toolSieveCaptureLimit {
					// Oversized capture: give up and flush it as plain text.
					content := state.capture.String()
					state.capture.Reset()
					state.capturing = false
					state.resetIncrementalToolState()
					state.noteText(content)
					events = append(events, toolStreamEvent{Content: content})
					continue
				}
				break
			}
			// A complete block resolved: emit its pieces and push trailing
			// text back through the pending buffer for reclassification.
			state.capture.Reset()
			state.capturing = false
			state.resetIncrementalToolState()
			if prefix != "" {
				state.noteText(prefix)
				events = append(events, toolStreamEvent{Content: prefix})
			}
			if len(calls) > 0 {
				events = append(events, toolStreamEvent{ToolCalls: calls})
			}
			if suffix != "" {
				state.pending.WriteString(suffix)
			}
			continue
		}

		pending := state.pending.String()
		if pending == "" {
			break
		}
		start := findToolSegmentStart(pending)
		if start >= 0 {
			// "tool_calls" spotted: emit the text before it and switch to
			// capture mode from that point on.
			prefix := pending[:start]
			if prefix != "" {
				state.noteText(prefix)
				events = append(events, toolStreamEvent{Content: prefix})
			}
			state.pending.Reset()
			state.capture.WriteString(pending[start:])
			state.capturing = true
			state.resetIncrementalToolState()
			continue
		}

		// No marker yet: release the provably-safe prefix and hold back any
		// trailing bytes that could still grow into a tool-call payload.
		safe, hold := splitSafeContentForToolDetection(pending)
		if safe == "" {
			break
		}
		state.pending.Reset()
		state.pending.WriteString(hold)
		state.noteText(safe)
		events = append(events, toolStreamEvent{Content: safe})
	}

	return events
}
|
||||
|
||||
// flushToolSieve drains the sieve at end of stream: it drives buffered data
// one last time, resolves any in-progress capture (emitting parsed tool
// calls when possible, raw text otherwise), and flushes leftover pending
// bytes as plain content.
func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent {
	if state == nil {
		return nil
	}
	events := processToolSieveChunk(state, "", toolNames)
	if state.capturing {
		consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
		if ready {
			if consumedPrefix != "" {
				state.noteText(consumedPrefix)
				events = append(events, toolStreamEvent{Content: consumedPrefix})
			}
			if len(consumedCalls) > 0 {
				events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
			}
			if consumedSuffix != "" {
				state.noteText(consumedSuffix)
				events = append(events, toolStreamEvent{Content: consumedSuffix})
			}
		} else {
			// Capture never parsed: surface the buffered bytes verbatim so
			// no model output is silently dropped.
			content := state.capture.String()
			if content != "" {
				state.noteText(content)
				events = append(events, toolStreamEvent{Content: content})
			}
		}
		state.capture.Reset()
		state.capturing = false
		state.resetIncrementalToolState()
	}
	if state.pending.Len() > 0 {
		content := state.pending.String()
		state.noteText(content)
		events = append(events, toolStreamEvent{Content: content})
		state.pending.Reset()
	}
	return events
}
|
||||
|
||||
// splitSafeContentForToolDetection partitions s into a prefix that can be
// emitted immediately and a suffix that must be held because it may be the
// beginning of a tool-call payload (an unfinished '{', '[', or ``` fence).
func splitSafeContentForToolDetection(s string) (safe, hold string) {
	if s == "" {
		return "", ""
	}
	suspiciousStart := findSuspiciousPrefixStart(s)
	if suspiciousStart < 0 {
		return s, ""
	}
	if suspiciousStart > 0 {
		return s[:suspiciousStart], s[suspiciousStart:]
	}
	// If suspicious content starts at position 0, keep holding until we can
	// parse a complete tool JSON block or reach stream flush.
	return "", s
}
|
||||
|
||||
// findSuspiciousPrefixStart returns the largest index at which a '{', '[',
// or ``` fence begins in s, or -1 when none is present. Content from that
// point on may be the start of a tool-call payload and must be held back.
func findSuspiciousPrefixStart(s string) int {
	best := strings.LastIndex(s, "{")
	if idx := strings.LastIndex(s, "["); idx > best {
		best = idx
	}
	if idx := strings.LastIndex(s, "```"); idx > best {
		best = idx
	}
	return best
}
|
||||
|
||||
// findToolSegmentStart returns the index at which a probable tool_calls
// payload begins in s, or -1. The match is case-insensitive and backs up to
// the nearest preceding '{'; occurrences inside an open ``` code fence are
// skipped as documentation examples.
func findToolSegmentStart(s string) int {
	if s == "" {
		return -1
	}
	lower := strings.ToLower(s)
	offset := 0
	for {
		keyRel := strings.Index(lower[offset:], "tool_calls")
		if keyRel < 0 {
			return -1
		}
		keyIdx := offset + keyRel
		// Back up to the object that should contain the key.
		start := strings.LastIndex(s[:keyIdx], "{")
		if start < 0 {
			start = keyIdx
		}
		if !insideCodeFence(s[:start]) {
			return start
		}
		// Fenced occurrence: keep searching past this marker.
		offset = keyIdx + len("tool_calls")
	}
}
|
||||
|
||||
// consumeToolCapture tries to resolve the captured buffer into parsed tool
// calls. ready=false means the JSON object is still incomplete; ready=true
// with nil calls means the block was rejected (code-fenced context or no
// matching tools) and the whole capture should be emitted as plain text.
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
	captured := state.capture.String()
	if captured == "" {
		return "", nil, "", false
	}
	lower := strings.ToLower(captured)
	keyIdx := strings.Index(lower, "tool_calls")
	if keyIdx < 0 {
		return "", nil, "", false
	}
	start := strings.LastIndex(captured[:keyIdx], "{")
	if start < 0 {
		return "", nil, "", false
	}
	obj, end, ok := extractJSONObjectFrom(captured, start)
	if !ok {
		// Braces not balanced yet: wait for more stream data.
		return "", nil, "", false
	}
	prefixPart := captured[:start]
	suffixPart := captured[end:]
	if insideCodeFence(state.recentTextTail + prefixPart) {
		// Fenced example, not a real invocation: hand everything back as text.
		return captured, nil, "", true
	}
	parsed := util.ParseStandaloneToolCalls(obj, toolNames)
	if len(parsed) == 0 {
		// Complete JSON but no recognized tool call: emit as text.
		return captured, nil, "", true
	}
	return prefixPart, parsed, suffixPart, true
}
|
||||
291
internal/adapter/openai/tool_sieve_incremental.go
Normal file
291
internal/adapter/openai/tool_sieve_incremental.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
// buildIncrementalToolDeltas emits streaming tool-call deltas (name first,
// then argument fragments) while a single tool call is still being captured.
// Deltas are only produced once the tool_calls array is provably a single
// call; multi-call payloads disable deltas for the rest of the capture.
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
	if state.disableDeltas {
		return nil
	}
	captured := state.capture.String()
	if captured == "" {
		return nil
	}
	lower := strings.ToLower(captured)
	keyIdx := strings.Index(lower, "tool_calls")
	if keyIdx < 0 {
		return nil
	}
	start := strings.LastIndex(captured[:keyIdx], "{")
	if start < 0 {
		return nil
	}
	if insideCodeFence(state.recentTextTail + captured[:start]) {
		// Fenced example: never stream deltas for documentation snippets.
		return nil
	}
	certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
	if hasMultiple {
		state.disableDeltas = true
		return nil
	}
	if !certainSingle {
		// In uncertain phases (e.g. first call arrived but array not closed yet),
		// avoid speculative deltas and wait for final parsed tool_calls payload.
		return nil
	}
	callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
	if !ok {
		return nil
	}
	deltas := make([]toolCallDelta, 0, 2)
	if state.toolName == "" {
		name, ok := extractToolCallName(captured, callStart)
		if !ok || name == "" {
			return nil
		}
		state.toolName = name
	}
	if state.toolArgsStart < 0 {
		argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
		if ok {
			state.toolArgsString = stringMode
			if stringMode {
				// Skip the opening quote so only the string body streams out.
				state.toolArgsStart = argsStart + 1
			} else {
				state.toolArgsStart = argsStart
			}
			state.toolArgsSent = state.toolArgsStart
		}
	}
	if !state.toolNameSent {
		if state.toolArgsStart < 0 {
			// Hold the name delta until the arguments field is located too.
			return nil
		}
		state.toolNameSent = true
		deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
	}
	if state.toolArgsStart < 0 || state.toolArgsDone {
		return deltas
	}
	end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
	if !ok {
		return deltas
	}
	if end > state.toolArgsSent {
		// Stream only the not-yet-sent slice of the arguments value.
		deltas = append(deltas, toolCallDelta{
			Index:     0,
			Arguments: captured[state.toolArgsSent:end],
		})
		state.toolArgsSent = end
	}
	if complete {
		state.toolArgsDone = true
	}
	return deltas
}
|
||||
|
||||
// classifyToolCallsIncrementalSafety inspects the tool_calls array following
// keyIdx and reports whether it certainly contains exactly one call
// (certainSingle) or definitely more than one (hasMultiple). Both false
// means the array is still streaming and the outcome is undecided.
func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
	arrStart, ok := findToolCallsArrayStart(text, keyIdx)
	if !ok {
		return false, false
	}
	i := skipSpaces(text, arrStart+1)
	if i >= len(text) || text[i] != '{' {
		return false, false
	}
	count := 0
	depth := 0
	quote := byte(0)
	escaped := false
	for ; i < len(text); i++ {
		ch := text[i]
		if quote != 0 {
			// Quoted section: track escapes and the closing quote only.
			if escaped {
				escaped = false
				continue
			}
			if ch == '\\' {
				escaped = true
				continue
			}
			if ch == quote {
				quote = 0
			}
			continue
		}
		if ch == '"' || ch == '\'' {
			quote = ch
			continue
		}
		if ch == '{' {
			if depth == 0 {
				// New top-level element of the array.
				count++
				if count > 1 {
					return false, true
				}
			}
			depth++
			continue
		}
		if ch == '}' {
			if depth > 0 {
				depth--
			}
			continue
		}
		if ch == ',' && depth == 0 {
			// top-level separator means at least one more tool call exists
			// (or is expected). Treat as multi-call and stop incremental deltas.
			return false, true
		}
		if ch == ']' && depth == 0 {
			return count == 1, false
		}
	}
	// array not closed yet: still uncertain whether more calls will appear
	return false, false
}
|
||||
|
||||
// findFirstToolCallObjectStart returns the index of the '{' opening the
// first element of the tool_calls array that follows keyIdx.
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
	arrStart, ok := findToolCallsArrayStart(text, keyIdx)
	if !ok {
		return -1, false
	}
	i := skipSpaces(text, arrStart+1)
	if i >= len(text) || text[i] != '{' {
		return -1, false
	}
	return i, true
}
|
||||
|
||||
// findToolCallsArrayStart returns the index of the '[' that opens the
// tool_calls array: it advances from the key at keyIdx to the next ':' and
// then past whitespace, expecting the array bracket.
func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
	i := keyIdx + len("tool_calls")
	for i < len(text) && text[i] != ':' {
		i++
	}
	if i >= len(text) {
		return -1, false
	}
	i = skipSpaces(text, i+1)
	if i >= len(text) || text[i] != '[' {
		return -1, false
	}
	return i, true
}
|
||||
|
||||
// extractToolCallName reads the tool name from the call object at callStart,
// looking first for a top-level "name" field and falling back to the nested
// OpenAI-style function.name.
func extractToolCallName(text string, callStart int) (string, bool) {
	valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"})
	if !ok || valueStart >= len(text) || text[valueStart] != '"' {
		fnStart, fnOK := findFunctionObjectStart(text, callStart)
		if !fnOK {
			return "", false
		}
		valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"})
		if !ok || valueStart >= len(text) || text[valueStart] != '"' {
			return "", false
		}
	}
	name, _, ok := parseJSONStringLiteral(text, valueStart)
	if !ok {
		return "", false
	}
	return name, true
}
|
||||
|
||||
// findToolCallArgsStart locates the arguments value of the call object at
// callStart (accepting several key aliases, with a fallback into the nested
// "function" object). It returns the value's start index, whether the value
// is a JSON string rather than an object/array, and success.
func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
	keys := []string{"input", "arguments", "args", "parameters", "params"}
	valueStart, ok := findObjectFieldValueStart(text, callStart, keys)
	if !ok {
		fnStart, fnOK := findFunctionObjectStart(text, callStart)
		if !fnOK {
			return -1, false, false
		}
		valueStart, ok = findObjectFieldValueStart(text, fnStart, keys)
		if !ok {
			return -1, false, false
		}
	}
	if valueStart >= len(text) {
		return -1, false, false
	}
	ch := text[valueStart]
	if ch == '{' || ch == '[' {
		return valueStart, false, true
	}
	if ch == '"' {
		return valueStart, true, true
	}
	// Unsupported value shape (number, bool, null, ...).
	return -1, false, false
}
|
||||
|
||||
// scanToolCallArgsProgress reports how far a streaming tool-call arguments
// value has arrived. It returns the byte offset just past the consumed
// portion, whether the value is complete, and whether scanning was possible.
// In stringMode the value is the body of a JSON string (start points past the
// opening quote); otherwise start must sit on '{' or '['.
func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
	if start < 0 || start > len(text) {
		return 0, false, false
	}
	if stringMode {
		esc := false
		for pos := start; pos < len(text); pos++ {
			switch {
			case esc:
				esc = false
			case text[pos] == '\\':
				esc = true
			case text[pos] == '"':
				// Closing quote reached: value complete, offset excludes it.
				return pos, true, true
			}
		}
		// Ran out of input mid-string: everything so far is consumable.
		return len(text), false, true
	}
	if start >= len(text) {
		return start, false, false
	}
	if c := text[start]; c != '{' && c != '[' {
		return 0, false, false
	}
	var (
		depth int
		quote byte
		esc   bool
	)
	for pos := start; pos < len(text); pos++ {
		c := text[pos]
		if quote != 0 {
			// Quoted section: honor escapes, wait for the matching quote.
			switch {
			case esc:
				esc = false
			case c == '\\':
				esc = true
			case c == quote:
				quote = 0
			}
			continue
		}
		switch c {
		case '"', '\'':
			quote = c
		case '{', '[':
			depth++
		case '}', ']':
			depth--
			if depth == 0 {
				// Opening bracket matched: the whole value is present.
				return pos + 1, true, true
			}
		}
	}
	// Container still open; every buffered byte belongs to the value.
	return len(text), false, true
}
|
||||
|
||||
// findFunctionObjectStart locates the OpenAI-style nested "function" object
// inside the tool-call object starting at callStart, returning the index of
// its opening '{'.
func findFunctionObjectStart(text string, callStart int) (int, bool) {
	valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"})
	if !ok || valueStart >= len(text) || text[valueStart] != '{' {
		return -1, false
	}
	return valueStart, true
}
|
||||
152
internal/adapter/openai/tool_sieve_jsonscan.go
Normal file
152
internal/adapter/openai/tool_sieve_jsonscan.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
// extractJSONObjectFrom scans a balanced {...} object beginning at start and
// returns its text, the index just past it, and whether it closed. Quoted
// sections (single or double) are skipped, honoring backslash escapes.
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '{' {
		return "", 0, false
	}
	var (
		depth int
		quote byte
		esc   bool
	)
	for pos := start; pos < len(text); pos++ {
		c := text[pos]
		if quote != 0 {
			switch {
			case esc:
				esc = false
			case c == '\\':
				esc = true
			case c == quote:
				quote = 0
			}
			continue
		}
		switch c {
		case '"', '\'':
			quote = c
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				// Outermost brace closed: the object is complete.
				return text[start : pos+1], pos + 1, true
			}
		}
	}
	// Unbalanced so far — caller should wait for more input.
	return "", 0, false
}
|
||||
|
||||
// findObjectFieldValueStart scans the JSON object starting at objStart (which
// must be '{') and returns the index of the value belonging to the first
// top-level field whose name is listed in keys. Nested objects are skipped;
// quoted sections honor backslash escapes. Returns false when no such field
// is visible yet — the object may still be streaming in.
func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
	if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
		return 0, false
	}
	depth := 0
	quote := byte(0)
	escaped := false
	for i := objStart; i < len(text); i++ {
		ch := text[i]
		if quote != 0 {
			// Inside a quoted value: only track escapes and the closing quote.
			if escaped {
				escaped = false
				continue
			}
			if ch == '\\' {
				escaped = true
				continue
			}
			if ch == quote {
				quote = 0
			}
			continue
		}
		if ch == '"' || ch == '\'' {
			if depth == 1 {
				// Top level of the target object: this string may be a field name.
				key, end, ok := parseJSONStringLiteral(text, i)
				if !ok {
					return 0, false
				}
				j := skipSpaces(text, end)
				if j >= len(text) || text[j] != ':' {
					// No colon follows — it was an ordinary string value.
					i = end - 1
					continue
				}
				j = skipSpaces(text, j+1)
				if j >= len(text) {
					// Colon seen but the value has not streamed in yet.
					return 0, false
				}
				if containsKey(keys, key) {
					return j, true
				}
				// Wrong key: resume scanning at its value.
				i = j - 1
				continue
			}
			quote = ch
			continue
		}
		if ch == '{' {
			depth++
			continue
		}
		if ch == '}' {
			depth--
			if depth == 0 {
				// Target object closed without a matching field.
				break
			}
		}
	}
	return 0, false
}
|
||||
|
||||
// parseJSONStringLiteral reads a double-quoted string starting at start and
// returns its content, the index just past the closing quote, and whether a
// complete literal was found. Escapes are handled by dropping the backslash
// and keeping the following byte verbatim, which is sufficient for
// identifier-style keys.
func parseJSONStringLiteral(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '"' {
		return "", 0, false
	}
	var out strings.Builder
	pos := start + 1
	for pos < len(text) {
		switch c := text[pos]; {
		case c == '\\' && pos+1 < len(text):
			// Keep the escaped byte, drop the backslash.
			out.WriteByte(text[pos+1])
			pos += 2
		case c == '"':
			return out.String(), pos + 1, true
		default:
			out.WriteByte(c)
			pos++
		}
	}
	// Unterminated literal.
	return "", 0, false
}
|
||||
|
||||
// containsKey reports whether value appears in keys.
func containsKey(keys []string, value string) bool {
	for i := range keys {
		if keys[i] == value {
			return true
		}
	}
	return false
}
|
||||
|
||||
// skipSpaces returns the index of the first byte at or after i that is not
// an ASCII space, tab, CR, or LF (len(text) when the rest is whitespace).
func skipSpaces(text string, i int) int {
	for i < len(text) && (text[i] == ' ' || text[i] == '\t' || text[i] == '\n' || text[i] == '\r') {
		i++
	}
	return i
}
|
||||
75
internal/adapter/openai/tool_sieve_state.go
Normal file
75
internal/adapter/openai/tool_sieve_state.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// toolStreamSieveState tracks the streaming tool-call sieve: text awaiting
// classification (pending), a suspected tool-call payload (capture), and the
// bookkeeping needed to emit incremental single-call deltas.
type toolStreamSieveState struct {
	pending strings.Builder // text not yet classified as safe or suspicious
	capture strings.Builder // suspected tool_calls payload being accumulated
	capturing bool // true while capture is active
	recentTextTail string // tail of already-emitted text, for code-fence context
	disableDeltas bool // set once a multi-call payload rules out deltas
	toolNameSent bool // the name delta has been emitted
	toolName string // extracted tool name for the current call
	toolArgsStart int // offset of the args value inside capture (-1 = unknown)
	toolArgsSent int // offset up to which args bytes were already streamed
	toolArgsString bool // args value is a JSON string rather than object/array
	toolArgsDone bool // args value fully streamed
}

// toolStreamEvent is one unit of sieve output: plain content, fully parsed
// tool calls, or incremental tool-call deltas.
type toolStreamEvent struct {
	Content string
	ToolCalls []util.ParsedToolCall
	ToolCallDeltas []toolCallDelta
}

// toolCallDelta is an OpenAI-style streaming fragment of a tool call.
type toolCallDelta struct {
	Index int
	Name string
	Arguments string
}

// toolSieveCaptureLimit caps how much suspected payload is buffered before
// giving up and flushing it as plain text.
const toolSieveCaptureLimit = 8 * 1024

// toolSieveContextTailLimit bounds the emitted-text tail kept for
// code-fence detection.
const toolSieveContextTailLimit = 256

// resetIncrementalToolState clears per-capture delta bookkeeping.
func (s *toolStreamSieveState) resetIncrementalToolState() {
	s.disableDeltas = false
	s.toolNameSent = false
	s.toolName = ""
	s.toolArgsStart = -1
	s.toolArgsSent = -1
	s.toolArgsString = false
	s.toolArgsDone = false
}
|
||||
|
||||
// noteText records emitted plain-text output in the rolling context tail so
// later code-fence checks can see text that was already flushed.
// Whitespace-only chunks are ignored.
func (s *toolStreamSieveState) noteText(content string) {
	if strings.TrimSpace(content) == "" {
		return
	}
	s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
}
|
||||
|
||||
// appendTail concatenates prev and next, keeping only the last max bytes.
// A non-positive max yields the empty string.
func appendTail(prev, next string, max int) string {
	if max <= 0 {
		return ""
	}
	joined := prev + next
	if overflow := len(joined) - max; overflow > 0 {
		joined = joined[overflow:]
	}
	return joined
}
|
||||
|
||||
// looksLikeToolExampleContext reports whether the preceding text leaves us
// inside a Markdown code fence, in which case tool-call-shaped JSON is
// treated as an illustrative example rather than a real invocation.
func looksLikeToolExampleContext(text string) bool {
	return insideCodeFence(text)
}
|
||||
|
||||
// insideCodeFence reports whether text ends inside an unterminated ``` code
// fence, i.e. contains an odd number of fence markers.
func insideCodeFence(text string) bool {
	return text != "" && strings.Count(text, "```")%2 != 0
}
|
||||
114
internal/admin/handler_accounts_crud.go
Normal file
114
internal/admin/handler_accounts_crud.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// listAccounts returns a paginated, newest-first view of configured accounts
// with secrets reduced to presence flags and a short token preview.
// Query params: page (>=1) and page_size (clamped to 1..100).
func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
	page := intFromQuery(r, "page", 1)
	pageSize := intFromQuery(r, "page_size", 10)
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = 1
	}
	if pageSize > 100 {
		pageSize = 100
	}
	accounts := h.Store.Snapshot().Accounts
	total := len(accounts)
	// Show most recently added accounts first.
	reverseAccounts(accounts)
	totalPages := 1
	if total > 0 {
		totalPages = (total + pageSize - 1) / pageSize
	}
	start := (page - 1) * pageSize
	if start > total {
		start = total
	}
	end := start + pageSize
	if end > total {
		end = total
	}
	items := make([]map[string]any, 0, end-start)
	for _, acc := range accounts[start:end] {
		token := strings.TrimSpace(acc.Token)
		preview := ""
		if token != "" {
			if len(token) > 20 {
				preview = token[:20] + "..."
			} else {
				preview = token
			}
		}
		// Never expose passwords or full tokens to the admin UI.
		items = append(items, map[string]any{
			"identifier": acc.Identifier(),
			"email": acc.Email,
			"mobile": acc.Mobile,
			"has_password": acc.Password != "",
			"has_token": token != "",
			"token_preview": preview,
		})
	}
	writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
}
|
||||
|
||||
// addAccount registers a new account. It requires an email or mobile
// identifier, rejects duplicates, and resets the pool so the account enters
// rotation immediately.
func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
	var req map[string]any
	// Decode errors are tolerated; validation below rejects empty payloads.
	_ = json.NewDecoder(r.Body).Decode(&req)
	acc := toAccount(req)
	if acc.Identifier() == "" {
		writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
		return
	}
	err := h.Store.Update(func(c *config.Config) error {
		for _, a := range c.Accounts {
			if acc.Email != "" && a.Email == acc.Email {
				return fmt.Errorf("邮箱已存在")
			}
			if acc.Mobile != "" && a.Mobile == acc.Mobile {
				return fmt.Errorf("手机号已存在")
			}
		}
		c.Accounts = append(c.Accounts, acc)
		return nil
	})
	if err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
		return
	}
	h.Pool.Reset()
	writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
}
|
||||
|
||||
// deleteAccount removes the account whose identifier (email or mobile) is
// given in the URL path, then resets the pool so it leaves rotation.
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
	identifier := chi.URLParam(r, "identifier")
	err := h.Store.Update(func(c *config.Config) error {
		idx := -1
		for i, a := range c.Accounts {
			if accountMatchesIdentifier(a, identifier) {
				idx = i
				break
			}
		}
		if idx < 0 {
			return fmt.Errorf("账号不存在")
		}
		c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
		return nil
	})
	if err != nil {
		writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
		return
	}
	h.Pool.Reset()
	writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
}
|
||||
7
internal/admin/handler_accounts_queue.go
Normal file
7
internal/admin/handler_accounts_queue.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package admin
|
||||
|
||||
import "net/http"
|
||||
|
||||
// queueStatus reports the account pool's live queue/inflight metrics.
func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) {
	writeJSON(w, http.StatusOK, h.Pool.Status())
}
|
||||
@@ -11,119 +11,11 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
authn "ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/sse"
|
||||
)
|
||||
|
||||
// listAccounts returns a paginated, newest-first view of configured accounts
// with secrets reduced to presence flags and a short token preview.
// Query params: page (>=1) and page_size (clamped to 1..100).
func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
	page := intFromQuery(r, "page", 1)
	pageSize := intFromQuery(r, "page_size", 10)
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = 1
	}
	if pageSize > 100 {
		pageSize = 100
	}
	accounts := h.Store.Snapshot().Accounts
	total := len(accounts)
	// Show most recently added accounts first.
	reverseAccounts(accounts)
	totalPages := 1
	if total > 0 {
		totalPages = (total + pageSize - 1) / pageSize
	}
	start := (page - 1) * pageSize
	if start > total {
		start = total
	}
	end := start + pageSize
	if end > total {
		end = total
	}
	items := make([]map[string]any, 0, end-start)
	for _, acc := range accounts[start:end] {
		token := strings.TrimSpace(acc.Token)
		preview := ""
		if token != "" {
			if len(token) > 20 {
				preview = token[:20] + "..."
			} else {
				preview = token
			}
		}
		// Never expose passwords or full tokens to the admin UI.
		items = append(items, map[string]any{
			"identifier": acc.Identifier(),
			"email": acc.Email,
			"mobile": acc.Mobile,
			"has_password": acc.Password != "",
			"has_token": token != "",
			"token_preview": preview,
		})
	}
	writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
}
|
||||
|
||||
// addAccount registers a new account. It requires an email or mobile
// identifier, rejects duplicates, and resets the pool so the account enters
// rotation immediately.
func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
	var req map[string]any
	// Decode errors are tolerated; validation below rejects empty payloads.
	_ = json.NewDecoder(r.Body).Decode(&req)
	acc := toAccount(req)
	if acc.Identifier() == "" {
		writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
		return
	}
	err := h.Store.Update(func(c *config.Config) error {
		for _, a := range c.Accounts {
			if acc.Email != "" && a.Email == acc.Email {
				return fmt.Errorf("邮箱已存在")
			}
			if acc.Mobile != "" && a.Mobile == acc.Mobile {
				return fmt.Errorf("手机号已存在")
			}
		}
		c.Accounts = append(c.Accounts, acc)
		return nil
	})
	if err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
		return
	}
	h.Pool.Reset()
	writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
}
|
||||
|
||||
// deleteAccount removes the account whose identifier (email or mobile) is
// given in the URL path, then resets the pool so it leaves rotation.
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
	identifier := chi.URLParam(r, "identifier")
	err := h.Store.Update(func(c *config.Config) error {
		idx := -1
		for i, a := range c.Accounts {
			if accountMatchesIdentifier(a, identifier) {
				idx = i
				break
			}
		}
		if idx < 0 {
			return fmt.Errorf("账号不存在")
		}
		c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
		return nil
	})
	if err != nil {
		writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
		return
	}
	h.Pool.Reset()
	writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
}
|
||||
|
||||
// queueStatus reports the account pool's live queue/inflight metrics.
func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) {
	writeJSON(w, http.StatusOK, h.Pool.Status())
}
|
||||
|
||||
func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
@@ -1,393 +0,0 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// getConfig returns a sanitized snapshot of the configuration for the admin
// UI: API keys, the Claude model mapping (preferring ClaudeMapping over the
// legacy ClaudeModelMap), and accounts with secrets reduced to presence
// flags plus a short token preview.
func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
	snap := h.Store.Snapshot()
	safe := map[string]any{
		"keys": snap.Keys,
		"accounts": []map[string]any{},
		"claude_mapping": func() map[string]string {
			if len(snap.ClaudeMapping) > 0 {
				return snap.ClaudeMapping
			}
			// Fall back to the legacy field for older config files.
			return snap.ClaudeModelMap
		}(),
	}
	accounts := make([]map[string]any, 0, len(snap.Accounts))
	for _, acc := range snap.Accounts {
		token := strings.TrimSpace(acc.Token)
		preview := ""
		if token != "" {
			if len(token) > 20 {
				preview = token[:20] + "..."
			} else {
				preview = token
			}
		}
		// Never expose passwords or full tokens.
		accounts = append(accounts, map[string]any{
			"identifier": acc.Identifier(),
			"email": acc.Email,
			"mobile": acc.Mobile,
			"has_password": strings.TrimSpace(acc.Password) != "",
			"has_token": token != "",
			"token_preview": preview,
		})
	}
	safe["accounts"] = accounts
	writeJSON(w, http.StatusOK, safe)
}
|
||||
|
||||
func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
old := h.Store.Snapshot()
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
if keys, ok := toStringSlice(req["keys"]); ok {
|
||||
c.Keys = keys
|
||||
}
|
||||
if accountsRaw, ok := req["accounts"].([]any); ok {
|
||||
existing := map[string]config.Account{}
|
||||
for _, a := range old.Accounts {
|
||||
existing[a.Identifier()] = a
|
||||
}
|
||||
accounts := make([]config.Account, 0, len(accountsRaw))
|
||||
for _, item := range accountsRaw {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
acc := toAccount(m)
|
||||
id := acc.Identifier()
|
||||
if prev, ok := existing[id]; ok {
|
||||
if strings.TrimSpace(acc.Password) == "" {
|
||||
acc.Password = prev.Password
|
||||
}
|
||||
if strings.TrimSpace(acc.Token) == "" {
|
||||
acc.Token = prev.Token
|
||||
}
|
||||
}
|
||||
accounts = append(accounts, acc)
|
||||
}
|
||||
c.Accounts = accounts
|
||||
}
|
||||
if m, ok := req["claude_mapping"].(map[string]any); ok {
|
||||
newMap := map[string]string{}
|
||||
for k, v := range m {
|
||||
newMap[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
c.ClaudeMapping = newMap
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "配置已更新"})
|
||||
}
|
||||
|
||||
func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
key, _ := req["key"].(string)
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"})
|
||||
return
|
||||
}
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
for _, k := range c.Keys {
|
||||
if k == key {
|
||||
return fmt.Errorf("Key 已存在")
|
||||
}
|
||||
}
|
||||
c.Keys = append(c.Keys, key)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
|
||||
key := chi.URLParam(r, "key")
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
idx := -1
|
||||
for i, k := range c.Keys {
|
||||
if k == key {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("Key 不存在")
|
||||
}
|
||||
c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
|
||||
}
|
||||
|
||||
func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "无效的 JSON 格式"})
|
||||
return
|
||||
}
|
||||
importedKeys, importedAccounts := 0, 0
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
if keys, ok := req["keys"].([]any); ok {
|
||||
existing := map[string]bool{}
|
||||
for _, k := range c.Keys {
|
||||
existing[k] = true
|
||||
}
|
||||
for _, k := range keys {
|
||||
key := strings.TrimSpace(fmt.Sprintf("%v", k))
|
||||
if key == "" || existing[key] {
|
||||
continue
|
||||
}
|
||||
c.Keys = append(c.Keys, key)
|
||||
existing[key] = true
|
||||
importedKeys++
|
||||
}
|
||||
}
|
||||
if accounts, ok := req["accounts"].([]any); ok {
|
||||
existing := map[string]bool{}
|
||||
for _, a := range c.Accounts {
|
||||
existing[a.Identifier()] = true
|
||||
}
|
||||
for _, item := range accounts {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
acc := toAccount(m)
|
||||
id := acc.Identifier()
|
||||
if id == "" || existing[id] {
|
||||
continue
|
||||
}
|
||||
c.Accounts = append(c.Accounts, acc)
|
||||
existing[id] = true
|
||||
importedAccounts++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts})
|
||||
}
|
||||
|
||||
func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
h.configExport(w, nil)
|
||||
}
|
||||
|
||||
func (h *Handler) configExport(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
jsonStr, b64, err := h.Store.ExportJSONAndBase64()
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"config": snap,
|
||||
"json": jsonStr,
|
||||
"base64": b64,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
|
||||
mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode")))
|
||||
if mode == "" {
|
||||
mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode")))
|
||||
}
|
||||
if mode == "" {
|
||||
mode = "merge"
|
||||
}
|
||||
if mode != "merge" && mode != "replace" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"})
|
||||
return
|
||||
}
|
||||
|
||||
payload := req
|
||||
if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 {
|
||||
payload = raw
|
||||
}
|
||||
rawJSON, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"})
|
||||
return
|
||||
}
|
||||
var incoming config.Config
|
||||
if err := json.Unmarshal(rawJSON, &incoming); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
importedKeys, importedAccounts := 0, 0
|
||||
err = h.Store.Update(func(c *config.Config) error {
|
||||
next := c.Clone()
|
||||
if mode == "replace" {
|
||||
next = incoming.Clone()
|
||||
next.VercelSyncHash = c.VercelSyncHash
|
||||
next.VercelSyncTime = c.VercelSyncTime
|
||||
importedKeys = len(next.Keys)
|
||||
importedAccounts = len(next.Accounts)
|
||||
} else {
|
||||
existingKeys := map[string]struct{}{}
|
||||
for _, k := range next.Keys {
|
||||
existingKeys[k] = struct{}{}
|
||||
}
|
||||
for _, k := range incoming.Keys {
|
||||
key := strings.TrimSpace(k)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := existingKeys[key]; ok {
|
||||
continue
|
||||
}
|
||||
existingKeys[key] = struct{}{}
|
||||
next.Keys = append(next.Keys, key)
|
||||
importedKeys++
|
||||
}
|
||||
|
||||
existingAccounts := map[string]struct{}{}
|
||||
for _, acc := range next.Accounts {
|
||||
existingAccounts[acc.Identifier()] = struct{}{}
|
||||
}
|
||||
for _, acc := range incoming.Accounts {
|
||||
id := acc.Identifier()
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := existingAccounts[id]; ok {
|
||||
continue
|
||||
}
|
||||
existingAccounts[id] = struct{}{}
|
||||
next.Accounts = append(next.Accounts, acc)
|
||||
importedAccounts++
|
||||
}
|
||||
|
||||
if len(incoming.ClaudeMapping) > 0 {
|
||||
if next.ClaudeMapping == nil {
|
||||
next.ClaudeMapping = map[string]string{}
|
||||
}
|
||||
for k, v := range incoming.ClaudeMapping {
|
||||
next.ClaudeMapping[k] = v
|
||||
}
|
||||
}
|
||||
if len(incoming.ClaudeModelMap) > 0 {
|
||||
if next.ClaudeModelMap == nil {
|
||||
next.ClaudeModelMap = map[string]string{}
|
||||
}
|
||||
for k, v := range incoming.ClaudeModelMap {
|
||||
next.ClaudeModelMap[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if len(incoming.ModelAliases) > 0 {
|
||||
if next.ModelAliases == nil {
|
||||
next.ModelAliases = map[string]string{}
|
||||
}
|
||||
for k, v := range incoming.ModelAliases {
|
||||
next.ModelAliases[k] = v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(incoming.Toolcall.Mode) != "" {
|
||||
next.Toolcall.Mode = incoming.Toolcall.Mode
|
||||
}
|
||||
if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" {
|
||||
next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence
|
||||
}
|
||||
if incoming.Responses.StoreTTLSeconds > 0 {
|
||||
next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds
|
||||
}
|
||||
if strings.TrimSpace(incoming.Embeddings.Provider) != "" {
|
||||
next.Embeddings.Provider = incoming.Embeddings.Provider
|
||||
}
|
||||
if strings.TrimSpace(incoming.Admin.PasswordHash) != "" {
|
||||
next.Admin.PasswordHash = incoming.Admin.PasswordHash
|
||||
}
|
||||
if incoming.Admin.JWTExpireHours > 0 {
|
||||
next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours
|
||||
}
|
||||
if incoming.Admin.JWTValidAfterUnix > 0 {
|
||||
next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix
|
||||
}
|
||||
if incoming.Runtime.AccountMaxInflight > 0 {
|
||||
next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight
|
||||
}
|
||||
if incoming.Runtime.AccountMaxQueue > 0 {
|
||||
next.Runtime.AccountMaxQueue = incoming.Runtime.AccountMaxQueue
|
||||
}
|
||||
if incoming.Runtime.GlobalMaxInflight > 0 {
|
||||
next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight
|
||||
}
|
||||
}
|
||||
|
||||
normalizeSettingsConfig(&next)
|
||||
if err := validateSettingsConfig(next); err != nil {
|
||||
return newRequestError(err.Error())
|
||||
}
|
||||
|
||||
*c = next
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if detail, ok := requestErrorDetail(err); ok {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"mode": mode,
|
||||
"imported_keys": importedKeys,
|
||||
"imported_accounts": importedAccounts,
|
||||
"message": "config imported",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) computeSyncHash() string {
|
||||
snap := h.Store.Snapshot().Clone()
|
||||
snap.VercelSyncHash = ""
|
||||
snap.VercelSyncTime = 0
|
||||
b, _ := json.Marshal(snap)
|
||||
sum := md5.Sum(b)
|
||||
return fmt.Sprintf("%x", sum)
|
||||
}
|
||||
182
internal/admin/handler_config_import.go
Normal file
182
internal/admin/handler_config_import.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
|
||||
mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode")))
|
||||
if mode == "" {
|
||||
mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode")))
|
||||
}
|
||||
if mode == "" {
|
||||
mode = "merge"
|
||||
}
|
||||
if mode != "merge" && mode != "replace" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"})
|
||||
return
|
||||
}
|
||||
|
||||
payload := req
|
||||
if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 {
|
||||
payload = raw
|
||||
}
|
||||
rawJSON, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"})
|
||||
return
|
||||
}
|
||||
var incoming config.Config
|
||||
if err := json.Unmarshal(rawJSON, &incoming); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
importedKeys, importedAccounts := 0, 0
|
||||
err = h.Store.Update(func(c *config.Config) error {
|
||||
next := c.Clone()
|
||||
if mode == "replace" {
|
||||
next = incoming.Clone()
|
||||
next.VercelSyncHash = c.VercelSyncHash
|
||||
next.VercelSyncTime = c.VercelSyncTime
|
||||
importedKeys = len(next.Keys)
|
||||
importedAccounts = len(next.Accounts)
|
||||
} else {
|
||||
existingKeys := map[string]struct{}{}
|
||||
for _, k := range next.Keys {
|
||||
existingKeys[k] = struct{}{}
|
||||
}
|
||||
for _, k := range incoming.Keys {
|
||||
key := strings.TrimSpace(k)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := existingKeys[key]; ok {
|
||||
continue
|
||||
}
|
||||
existingKeys[key] = struct{}{}
|
||||
next.Keys = append(next.Keys, key)
|
||||
importedKeys++
|
||||
}
|
||||
|
||||
existingAccounts := map[string]struct{}{}
|
||||
for _, acc := range next.Accounts {
|
||||
existingAccounts[acc.Identifier()] = struct{}{}
|
||||
}
|
||||
for _, acc := range incoming.Accounts {
|
||||
id := acc.Identifier()
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := existingAccounts[id]; ok {
|
||||
continue
|
||||
}
|
||||
existingAccounts[id] = struct{}{}
|
||||
next.Accounts = append(next.Accounts, acc)
|
||||
importedAccounts++
|
||||
}
|
||||
|
||||
if len(incoming.ClaudeMapping) > 0 {
|
||||
if next.ClaudeMapping == nil {
|
||||
next.ClaudeMapping = map[string]string{}
|
||||
}
|
||||
for k, v := range incoming.ClaudeMapping {
|
||||
next.ClaudeMapping[k] = v
|
||||
}
|
||||
}
|
||||
if len(incoming.ClaudeModelMap) > 0 {
|
||||
if next.ClaudeModelMap == nil {
|
||||
next.ClaudeModelMap = map[string]string{}
|
||||
}
|
||||
for k, v := range incoming.ClaudeModelMap {
|
||||
next.ClaudeModelMap[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if len(incoming.ModelAliases) > 0 {
|
||||
if next.ModelAliases == nil {
|
||||
next.ModelAliases = map[string]string{}
|
||||
}
|
||||
for k, v := range incoming.ModelAliases {
|
||||
next.ModelAliases[k] = v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(incoming.Toolcall.Mode) != "" {
|
||||
next.Toolcall.Mode = incoming.Toolcall.Mode
|
||||
}
|
||||
if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" {
|
||||
next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence
|
||||
}
|
||||
if incoming.Responses.StoreTTLSeconds > 0 {
|
||||
next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds
|
||||
}
|
||||
if strings.TrimSpace(incoming.Embeddings.Provider) != "" {
|
||||
next.Embeddings.Provider = incoming.Embeddings.Provider
|
||||
}
|
||||
if strings.TrimSpace(incoming.Admin.PasswordHash) != "" {
|
||||
next.Admin.PasswordHash = incoming.Admin.PasswordHash
|
||||
}
|
||||
if incoming.Admin.JWTExpireHours > 0 {
|
||||
next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours
|
||||
}
|
||||
if incoming.Admin.JWTValidAfterUnix > 0 {
|
||||
next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix
|
||||
}
|
||||
if incoming.Runtime.AccountMaxInflight > 0 {
|
||||
next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight
|
||||
}
|
||||
if incoming.Runtime.AccountMaxQueue > 0 {
|
||||
next.Runtime.AccountMaxQueue = incoming.Runtime.AccountMaxQueue
|
||||
}
|
||||
if incoming.Runtime.GlobalMaxInflight > 0 {
|
||||
next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight
|
||||
}
|
||||
}
|
||||
|
||||
normalizeSettingsConfig(&next)
|
||||
if err := validateSettingsConfig(next); err != nil {
|
||||
return newRequestError(err.Error())
|
||||
}
|
||||
|
||||
*c = next
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if detail, ok := requestErrorDetail(err); ok {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"mode": mode,
|
||||
"imported_keys": importedKeys,
|
||||
"imported_accounts": importedAccounts,
|
||||
"message": "config imported",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) computeSyncHash() string {
|
||||
snap := h.Store.Snapshot().Clone()
|
||||
snap.VercelSyncHash = ""
|
||||
snap.VercelSyncTime = 0
|
||||
b, _ := json.Marshal(snap)
|
||||
sum := md5.Sum(b)
|
||||
return fmt.Sprintf("%x", sum)
|
||||
}
|
||||
61
internal/admin/handler_config_read.go
Normal file
61
internal/admin/handler_config_read.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
safe := map[string]any{
|
||||
"keys": snap.Keys,
|
||||
"accounts": []map[string]any{},
|
||||
"claude_mapping": func() map[string]string {
|
||||
if len(snap.ClaudeMapping) > 0 {
|
||||
return snap.ClaudeMapping
|
||||
}
|
||||
return snap.ClaudeModelMap
|
||||
}(),
|
||||
}
|
||||
accounts := make([]map[string]any, 0, len(snap.Accounts))
|
||||
for _, acc := range snap.Accounts {
|
||||
token := strings.TrimSpace(acc.Token)
|
||||
preview := ""
|
||||
if token != "" {
|
||||
if len(token) > 20 {
|
||||
preview = token[:20] + "..."
|
||||
} else {
|
||||
preview = token
|
||||
}
|
||||
}
|
||||
accounts = append(accounts, map[string]any{
|
||||
"identifier": acc.Identifier(),
|
||||
"email": acc.Email,
|
||||
"mobile": acc.Mobile,
|
||||
"has_password": strings.TrimSpace(acc.Password) != "",
|
||||
"has_token": token != "",
|
||||
"token_preview": preview,
|
||||
})
|
||||
}
|
||||
safe["accounts"] = accounts
|
||||
writeJSON(w, http.StatusOK, safe)
|
||||
}
|
||||
|
||||
func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
h.configExport(w, nil)
|
||||
}
|
||||
|
||||
func (h *Handler) configExport(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
jsonStr, b64, err := h.Store.ExportJSONAndBase64()
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"config": snap,
|
||||
"json": jsonStr,
|
||||
"base64": b64,
|
||||
})
|
||||
}
|
||||
166
internal/admin/handler_config_write.go
Normal file
166
internal/admin/handler_config_write.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
old := h.Store.Snapshot()
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
if keys, ok := toStringSlice(req["keys"]); ok {
|
||||
c.Keys = keys
|
||||
}
|
||||
if accountsRaw, ok := req["accounts"].([]any); ok {
|
||||
existing := map[string]config.Account{}
|
||||
for _, a := range old.Accounts {
|
||||
existing[a.Identifier()] = a
|
||||
}
|
||||
accounts := make([]config.Account, 0, len(accountsRaw))
|
||||
for _, item := range accountsRaw {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
acc := toAccount(m)
|
||||
id := acc.Identifier()
|
||||
if prev, ok := existing[id]; ok {
|
||||
if strings.TrimSpace(acc.Password) == "" {
|
||||
acc.Password = prev.Password
|
||||
}
|
||||
if strings.TrimSpace(acc.Token) == "" {
|
||||
acc.Token = prev.Token
|
||||
}
|
||||
}
|
||||
accounts = append(accounts, acc)
|
||||
}
|
||||
c.Accounts = accounts
|
||||
}
|
||||
if m, ok := req["claude_mapping"].(map[string]any); ok {
|
||||
newMap := map[string]string{}
|
||||
for k, v := range m {
|
||||
newMap[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
c.ClaudeMapping = newMap
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "配置已更新"})
|
||||
}
|
||||
|
||||
func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
key, _ := req["key"].(string)
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"})
|
||||
return
|
||||
}
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
for _, k := range c.Keys {
|
||||
if k == key {
|
||||
return fmt.Errorf("Key 已存在")
|
||||
}
|
||||
}
|
||||
c.Keys = append(c.Keys, key)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
|
||||
key := chi.URLParam(r, "key")
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
idx := -1
|
||||
for i, k := range c.Keys {
|
||||
if k == key {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("Key 不存在")
|
||||
}
|
||||
c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
|
||||
}
|
||||
|
||||
func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "无效的 JSON 格式"})
|
||||
return
|
||||
}
|
||||
importedKeys, importedAccounts := 0, 0
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
if keys, ok := req["keys"].([]any); ok {
|
||||
existing := map[string]bool{}
|
||||
for _, k := range c.Keys {
|
||||
existing[k] = true
|
||||
}
|
||||
for _, k := range keys {
|
||||
key := strings.TrimSpace(fmt.Sprintf("%v", k))
|
||||
if key == "" || existing[key] {
|
||||
continue
|
||||
}
|
||||
c.Keys = append(c.Keys, key)
|
||||
existing[key] = true
|
||||
importedKeys++
|
||||
}
|
||||
}
|
||||
if accounts, ok := req["accounts"].([]any); ok {
|
||||
existing := map[string]bool{}
|
||||
for _, a := range c.Accounts {
|
||||
existing[a.Identifier()] = true
|
||||
}
|
||||
for _, item := range accounts {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
acc := toAccount(m)
|
||||
id := acc.Identifier()
|
||||
if id == "" || existing[id] {
|
||||
continue
|
||||
}
|
||||
c.Accounts = append(c.Accounts, acc)
|
||||
existing[id] = true
|
||||
importedAccounts++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts})
|
||||
}
|
||||
@@ -1,321 +0,0 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
authn "ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
recommended := defaultRuntimeRecommended(len(snap.Accounts), h.Store.RuntimeAccountMaxInflight())
|
||||
needsSync := config.IsVercel() && snap.VercelSyncHash != "" && snap.VercelSyncHash != h.computeSyncHash()
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"admin": map[string]any{
|
||||
"has_password_hash": strings.TrimSpace(snap.Admin.PasswordHash) != "",
|
||||
"jwt_expire_hours": h.Store.AdminJWTExpireHours(),
|
||||
"jwt_valid_after_unix": snap.Admin.JWTValidAfterUnix,
|
||||
"default_password_warning": authn.UsingDefaultAdminKey(h.Store),
|
||||
},
|
||||
"runtime": map[string]any{
|
||||
"account_max_inflight": h.Store.RuntimeAccountMaxInflight(),
|
||||
"account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended),
|
||||
"global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended),
|
||||
},
|
||||
"toolcall": snap.Toolcall,
|
||||
"responses": snap.Responses,
|
||||
"embeddings": snap.Embeddings,
|
||||
"claude_mapping": settingsClaudeMapping(snap),
|
||||
"model_aliases": snap.ModelAliases,
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
"needs_vercel_sync": needsSync,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
|
||||
adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
if runtimeCfg != nil {
|
||||
if err := validateMergedRuntimeSettings(h.Store.Snapshot().Runtime, runtimeCfg); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.Store.Update(func(c *config.Config) error {
|
||||
if adminCfg != nil {
|
||||
if adminCfg.JWTExpireHours > 0 {
|
||||
c.Admin.JWTExpireHours = adminCfg.JWTExpireHours
|
||||
}
|
||||
}
|
||||
if runtimeCfg != nil {
|
||||
if runtimeCfg.AccountMaxInflight > 0 {
|
||||
c.Runtime.AccountMaxInflight = runtimeCfg.AccountMaxInflight
|
||||
}
|
||||
if runtimeCfg.AccountMaxQueue > 0 {
|
||||
c.Runtime.AccountMaxQueue = runtimeCfg.AccountMaxQueue
|
||||
}
|
||||
if runtimeCfg.GlobalMaxInflight > 0 {
|
||||
c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight
|
||||
}
|
||||
}
|
||||
if toolcallCfg != nil {
|
||||
if strings.TrimSpace(toolcallCfg.Mode) != "" {
|
||||
c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode)
|
||||
}
|
||||
if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" {
|
||||
c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence)
|
||||
}
|
||||
}
|
||||
if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 {
|
||||
c.Responses.StoreTTLSeconds = responsesCfg.StoreTTLSeconds
|
||||
}
|
||||
if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" {
|
||||
c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider)
|
||||
}
|
||||
if claudeMap != nil {
|
||||
c.ClaudeMapping = claudeMap
|
||||
c.ClaudeModelMap = nil
|
||||
}
|
||||
if aliasMap != nil {
|
||||
c.ModelAliases = aliasMap
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.applyRuntimeSettings()
|
||||
needsSync := config.IsVercel() || h.Store.IsEnvBacked()
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"message": "settings updated and hot reloaded",
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
"needs_vercel_sync": needsSync,
|
||||
"manual_sync_message": "配置已保存。Vercel 部署请在 Vercel Sync 页面手动同步。",
|
||||
})
|
||||
}
|
||||
|
||||
func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *config.RuntimeConfig) error {
|
||||
merged := current
|
||||
if incoming != nil {
|
||||
if incoming.AccountMaxInflight > 0 {
|
||||
merged.AccountMaxInflight = incoming.AccountMaxInflight
|
||||
}
|
||||
if incoming.AccountMaxQueue > 0 {
|
||||
merged.AccountMaxQueue = incoming.AccountMaxQueue
|
||||
}
|
||||
if incoming.GlobalMaxInflight > 0 {
|
||||
merged.GlobalMaxInflight = incoming.GlobalMaxInflight
|
||||
}
|
||||
}
|
||||
return validateRuntimeSettings(merged)
|
||||
}
|
||||
|
||||
func (h *Handler) updateSettingsPassword(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
newPassword := strings.TrimSpace(fieldString(req, "new_password"))
|
||||
if newPassword == "" {
|
||||
newPassword = strings.TrimSpace(fieldString(req, "password"))
|
||||
}
|
||||
if len(newPassword) < 4 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "new password must be at least 4 characters"})
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
hash := authn.HashAdminPassword(newPassword)
|
||||
if err := h.Store.Update(func(c *config.Config) error {
|
||||
c.Admin.PasswordHash = hash
|
||||
c.Admin.JWTValidAfterUnix = now
|
||||
return nil
|
||||
}); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"message": "password updated",
|
||||
"force_relogin": true,
|
||||
"jwt_valid_after_unix": now,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) applyRuntimeSettings() {
|
||||
if h == nil || h.Store == nil || h.Pool == nil {
|
||||
return
|
||||
}
|
||||
accountCount := len(h.Store.Accounts())
|
||||
maxPer := h.Store.RuntimeAccountMaxInflight()
|
||||
recommended := defaultRuntimeRecommended(accountCount, maxPer)
|
||||
maxQueue := h.Store.RuntimeAccountMaxQueue(recommended)
|
||||
global := h.Store.RuntimeGlobalMaxInflight(recommended)
|
||||
h.Pool.ApplyRuntimeLimits(maxPer, maxQueue, global)
|
||||
}
|
||||
|
||||
func defaultRuntimeRecommended(accountCount, maxPer int) int {
|
||||
if maxPer <= 0 {
|
||||
maxPer = 1
|
||||
}
|
||||
if accountCount <= 0 {
|
||||
return maxPer
|
||||
}
|
||||
return accountCount * maxPer
|
||||
}
|
||||
|
||||
func settingsClaudeMapping(c config.Config) map[string]string {
|
||||
if len(c.ClaudeMapping) > 0 {
|
||||
return c.ClaudeMapping
|
||||
}
|
||||
if len(c.ClaudeModelMap) > 0 {
|
||||
return c.ClaudeModelMap
|
||||
}
|
||||
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
|
||||
}
|
||||
|
||||
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) {
|
||||
var (
|
||||
adminCfg *config.AdminConfig
|
||||
runtimeCfg *config.RuntimeConfig
|
||||
toolcallCfg *config.ToolcallConfig
|
||||
respCfg *config.ResponsesConfig
|
||||
embCfg *config.EmbeddingsConfig
|
||||
claudeMap map[string]string
|
||||
aliasMap map[string]string
|
||||
)
|
||||
|
||||
if raw, ok := req["admin"].(map[string]any); ok {
|
||||
cfg := &config.AdminConfig{}
|
||||
if v, exists := raw["jwt_expire_hours"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 720 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
|
||||
}
|
||||
cfg.JWTExpireHours = n
|
||||
}
|
||||
adminCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["runtime"].(map[string]any); ok {
|
||||
cfg := &config.RuntimeConfig{}
|
||||
if v, exists := raw["account_max_inflight"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 256 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
|
||||
}
|
||||
cfg.AccountMaxInflight = n
|
||||
}
|
||||
if v, exists := raw["account_max_queue"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 200000 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
|
||||
}
|
||||
cfg.AccountMaxQueue = n
|
||||
}
|
||||
if v, exists := raw["global_max_inflight"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 200000 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
|
||||
}
|
||||
cfg.GlobalMaxInflight = n
|
||||
}
|
||||
if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
|
||||
}
|
||||
runtimeCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["toolcall"].(map[string]any); ok {
|
||||
cfg := &config.ToolcallConfig{}
|
||||
if v, exists := raw["mode"]; exists {
|
||||
mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
switch mode {
|
||||
case "feature_match", "off":
|
||||
cfg.Mode = mode
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
|
||||
}
|
||||
}
|
||||
if v, exists := raw["early_emit_confidence"]; exists {
|
||||
level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
switch level {
|
||||
case "high", "low", "off":
|
||||
cfg.EarlyEmitConfidence = level
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
|
||||
}
|
||||
}
|
||||
toolcallCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["responses"].(map[string]any); ok {
|
||||
cfg := &config.ResponsesConfig{}
|
||||
if v, exists := raw["store_ttl_seconds"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 30 || n > 86400 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
|
||||
}
|
||||
cfg.StoreTTLSeconds = n
|
||||
}
|
||||
respCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["embeddings"].(map[string]any); ok {
|
||||
cfg := &config.EmbeddingsConfig{}
|
||||
if v, exists := raw["provider"]; exists {
|
||||
p := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if p == "" {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty")
|
||||
}
|
||||
cfg.Provider = p
|
||||
}
|
||||
embCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["claude_mapping"].(map[string]any); ok {
|
||||
claudeMap = map[string]string{}
|
||||
for k, v := range raw {
|
||||
key := strings.TrimSpace(k)
|
||||
val := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
claudeMap[key] = val
|
||||
}
|
||||
}
|
||||
|
||||
if raw, ok := req["model_aliases"].(map[string]any); ok {
|
||||
aliasMap = map[string]string{}
|
||||
for k, v := range raw {
|
||||
key := strings.TrimSpace(k)
|
||||
val := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
aliasMap[key] = val
|
||||
}
|
||||
}
|
||||
|
||||
return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil
|
||||
}
|
||||
134
internal/admin/handler_settings_parse.go
Normal file
134
internal/admin/handler_settings_parse.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) {
|
||||
var (
|
||||
adminCfg *config.AdminConfig
|
||||
runtimeCfg *config.RuntimeConfig
|
||||
toolcallCfg *config.ToolcallConfig
|
||||
respCfg *config.ResponsesConfig
|
||||
embCfg *config.EmbeddingsConfig
|
||||
claudeMap map[string]string
|
||||
aliasMap map[string]string
|
||||
)
|
||||
|
||||
if raw, ok := req["admin"].(map[string]any); ok {
|
||||
cfg := &config.AdminConfig{}
|
||||
if v, exists := raw["jwt_expire_hours"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 720 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
|
||||
}
|
||||
cfg.JWTExpireHours = n
|
||||
}
|
||||
adminCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["runtime"].(map[string]any); ok {
|
||||
cfg := &config.RuntimeConfig{}
|
||||
if v, exists := raw["account_max_inflight"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 256 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
|
||||
}
|
||||
cfg.AccountMaxInflight = n
|
||||
}
|
||||
if v, exists := raw["account_max_queue"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 200000 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
|
||||
}
|
||||
cfg.AccountMaxQueue = n
|
||||
}
|
||||
if v, exists := raw["global_max_inflight"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 200000 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
|
||||
}
|
||||
cfg.GlobalMaxInflight = n
|
||||
}
|
||||
if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
|
||||
}
|
||||
runtimeCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["toolcall"].(map[string]any); ok {
|
||||
cfg := &config.ToolcallConfig{}
|
||||
if v, exists := raw["mode"]; exists {
|
||||
mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
switch mode {
|
||||
case "feature_match", "off":
|
||||
cfg.Mode = mode
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
|
||||
}
|
||||
}
|
||||
if v, exists := raw["early_emit_confidence"]; exists {
|
||||
level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
switch level {
|
||||
case "high", "low", "off":
|
||||
cfg.EarlyEmitConfidence = level
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
|
||||
}
|
||||
}
|
||||
toolcallCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["responses"].(map[string]any); ok {
|
||||
cfg := &config.ResponsesConfig{}
|
||||
if v, exists := raw["store_ttl_seconds"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 30 || n > 86400 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
|
||||
}
|
||||
cfg.StoreTTLSeconds = n
|
||||
}
|
||||
respCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["embeddings"].(map[string]any); ok {
|
||||
cfg := &config.EmbeddingsConfig{}
|
||||
if v, exists := raw["provider"]; exists {
|
||||
p := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if p == "" {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty")
|
||||
}
|
||||
cfg.Provider = p
|
||||
}
|
||||
embCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["claude_mapping"].(map[string]any); ok {
|
||||
claudeMap = map[string]string{}
|
||||
for k, v := range raw {
|
||||
key := strings.TrimSpace(k)
|
||||
val := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
claudeMap[key] = val
|
||||
}
|
||||
}
|
||||
|
||||
if raw, ok := req["model_aliases"].(map[string]any); ok {
|
||||
aliasMap = map[string]string{}
|
||||
for k, v := range raw {
|
||||
key := strings.TrimSpace(k)
|
||||
val := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
aliasMap[key] = val
|
||||
}
|
||||
}
|
||||
|
||||
return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil
|
||||
}
|
||||
36
internal/admin/handler_settings_read.go
Normal file
36
internal/admin/handler_settings_read.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
authn "ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
recommended := defaultRuntimeRecommended(len(snap.Accounts), h.Store.RuntimeAccountMaxInflight())
|
||||
needsSync := config.IsVercel() && snap.VercelSyncHash != "" && snap.VercelSyncHash != h.computeSyncHash()
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"admin": map[string]any{
|
||||
"has_password_hash": strings.TrimSpace(snap.Admin.PasswordHash) != "",
|
||||
"jwt_expire_hours": h.Store.AdminJWTExpireHours(),
|
||||
"jwt_valid_after_unix": snap.Admin.JWTValidAfterUnix,
|
||||
"default_password_warning": authn.UsingDefaultAdminKey(h.Store),
|
||||
},
|
||||
"runtime": map[string]any{
|
||||
"account_max_inflight": h.Store.RuntimeAccountMaxInflight(),
|
||||
"account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended),
|
||||
"global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended),
|
||||
},
|
||||
"toolcall": snap.Toolcall,
|
||||
"responses": snap.Responses,
|
||||
"embeddings": snap.Embeddings,
|
||||
"claude_mapping": settingsClaudeMapping(snap),
|
||||
"model_aliases": snap.ModelAliases,
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
"needs_vercel_sync": needsSync,
|
||||
})
|
||||
}
|
||||
51
internal/admin/handler_settings_runtime.go
Normal file
51
internal/admin/handler_settings_runtime.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package admin
|
||||
|
||||
import "ds2api/internal/config"
|
||||
|
||||
func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *config.RuntimeConfig) error {
|
||||
merged := current
|
||||
if incoming != nil {
|
||||
if incoming.AccountMaxInflight > 0 {
|
||||
merged.AccountMaxInflight = incoming.AccountMaxInflight
|
||||
}
|
||||
if incoming.AccountMaxQueue > 0 {
|
||||
merged.AccountMaxQueue = incoming.AccountMaxQueue
|
||||
}
|
||||
if incoming.GlobalMaxInflight > 0 {
|
||||
merged.GlobalMaxInflight = incoming.GlobalMaxInflight
|
||||
}
|
||||
}
|
||||
return validateRuntimeSettings(merged)
|
||||
}
|
||||
|
||||
func (h *Handler) applyRuntimeSettings() {
|
||||
if h == nil || h.Store == nil || h.Pool == nil {
|
||||
return
|
||||
}
|
||||
accountCount := len(h.Store.Accounts())
|
||||
maxPer := h.Store.RuntimeAccountMaxInflight()
|
||||
recommended := defaultRuntimeRecommended(accountCount, maxPer)
|
||||
maxQueue := h.Store.RuntimeAccountMaxQueue(recommended)
|
||||
global := h.Store.RuntimeGlobalMaxInflight(recommended)
|
||||
h.Pool.ApplyRuntimeLimits(maxPer, maxQueue, global)
|
||||
}
|
||||
|
||||
// defaultRuntimeRecommended returns the suggested total concurrency for the
// given account count: accounts * per-account inflight. A non-positive
// per-account limit is treated as 1, and a non-positive account count yields
// a single account's worth of capacity.
func defaultRuntimeRecommended(accountCount, maxPer int) int {
	perAccount := maxPer
	if perAccount < 1 {
		perAccount = 1
	}
	if accountCount < 1 {
		return perAccount
	}
	return accountCount * perAccount
}
|
||||
|
||||
func settingsClaudeMapping(c config.Config) map[string]string {
|
||||
if len(c.ClaudeMapping) > 0 {
|
||||
return c.ClaudeMapping
|
||||
}
|
||||
if len(c.ClaudeModelMap) > 0 {
|
||||
return c.ClaudeModelMap
|
||||
}
|
||||
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
|
||||
}
|
||||
119
internal/admin/handler_settings_write.go
Normal file
119
internal/admin/handler_settings_write.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
authn "ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
|
||||
adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
if runtimeCfg != nil {
|
||||
if err := validateMergedRuntimeSettings(h.Store.Snapshot().Runtime, runtimeCfg); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.Store.Update(func(c *config.Config) error {
|
||||
if adminCfg != nil {
|
||||
if adminCfg.JWTExpireHours > 0 {
|
||||
c.Admin.JWTExpireHours = adminCfg.JWTExpireHours
|
||||
}
|
||||
}
|
||||
if runtimeCfg != nil {
|
||||
if runtimeCfg.AccountMaxInflight > 0 {
|
||||
c.Runtime.AccountMaxInflight = runtimeCfg.AccountMaxInflight
|
||||
}
|
||||
if runtimeCfg.AccountMaxQueue > 0 {
|
||||
c.Runtime.AccountMaxQueue = runtimeCfg.AccountMaxQueue
|
||||
}
|
||||
if runtimeCfg.GlobalMaxInflight > 0 {
|
||||
c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight
|
||||
}
|
||||
}
|
||||
if toolcallCfg != nil {
|
||||
if strings.TrimSpace(toolcallCfg.Mode) != "" {
|
||||
c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode)
|
||||
}
|
||||
if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" {
|
||||
c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence)
|
||||
}
|
||||
}
|
||||
if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 {
|
||||
c.Responses.StoreTTLSeconds = responsesCfg.StoreTTLSeconds
|
||||
}
|
||||
if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" {
|
||||
c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider)
|
||||
}
|
||||
if claudeMap != nil {
|
||||
c.ClaudeMapping = claudeMap
|
||||
c.ClaudeModelMap = nil
|
||||
}
|
||||
if aliasMap != nil {
|
||||
c.ModelAliases = aliasMap
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.applyRuntimeSettings()
|
||||
needsSync := config.IsVercel() || h.Store.IsEnvBacked()
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"message": "settings updated and hot reloaded",
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
"needs_vercel_sync": needsSync,
|
||||
"manual_sync_message": "配置已保存。Vercel 部署请在 Vercel Sync 页面手动同步。",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) updateSettingsPassword(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
newPassword := strings.TrimSpace(fieldString(req, "new_password"))
|
||||
if newPassword == "" {
|
||||
newPassword = strings.TrimSpace(fieldString(req, "password"))
|
||||
}
|
||||
if len(newPassword) < 4 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "new password must be at least 4 characters"})
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
hash := authn.HashAdminPassword(newPassword)
|
||||
if err := h.Store.Update(func(c *config.Config) error {
|
||||
c.Admin.PasswordHash = hash
|
||||
c.Admin.JWTValidAfterUnix = now
|
||||
return nil
|
||||
}); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"message": "password updated",
|
||||
"force_relogin": true,
|
||||
"jwt_valid_after_unix": now,
|
||||
})
|
||||
}
|
||||
24
internal/config/account.go
Normal file
24
internal/config/account.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (a Account) Identifier() string {
|
||||
if strings.TrimSpace(a.Email) != "" {
|
||||
return strings.TrimSpace(a.Email)
|
||||
}
|
||||
if strings.TrimSpace(a.Mobile) != "" {
|
||||
return strings.TrimSpace(a.Mobile)
|
||||
}
|
||||
// Backward compatibility: old configs may contain token-only accounts.
|
||||
// Use a stable non-sensitive synthetic id so they can still join the pool.
|
||||
token := strings.TrimSpace(a.Token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return "token:" + hex.EncodeToString(sum[:8])
|
||||
}
|
||||
241
internal/config/codec.go
Normal file
241
internal/config/codec.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c Config) MarshalJSON() ([]byte, error) {
|
||||
m := map[string]any{}
|
||||
for k, v := range c.AdditionalFields {
|
||||
m[k] = v
|
||||
}
|
||||
if len(c.Keys) > 0 {
|
||||
m["keys"] = c.Keys
|
||||
}
|
||||
if len(c.Accounts) > 0 {
|
||||
m["accounts"] = c.Accounts
|
||||
}
|
||||
if len(c.ClaudeMapping) > 0 {
|
||||
m["claude_mapping"] = c.ClaudeMapping
|
||||
}
|
||||
if len(c.ClaudeModelMap) > 0 {
|
||||
m["claude_model_mapping"] = c.ClaudeModelMap
|
||||
}
|
||||
if len(c.ModelAliases) > 0 {
|
||||
m["model_aliases"] = c.ModelAliases
|
||||
}
|
||||
if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 {
|
||||
m["admin"] = c.Admin
|
||||
}
|
||||
if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 {
|
||||
m["runtime"] = c.Runtime
|
||||
}
|
||||
if c.Compat.WideInputStrictOutput != nil {
|
||||
m["compat"] = c.Compat
|
||||
}
|
||||
if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" {
|
||||
m["toolcall"] = c.Toolcall
|
||||
}
|
||||
if c.Responses.StoreTTLSeconds > 0 {
|
||||
m["responses"] = c.Responses
|
||||
}
|
||||
if strings.TrimSpace(c.Embeddings.Provider) != "" {
|
||||
m["embeddings"] = c.Embeddings
|
||||
}
|
||||
if c.VercelSyncHash != "" {
|
||||
m["_vercel_sync_hash"] = c.VercelSyncHash
|
||||
}
|
||||
if c.VercelSyncTime != 0 {
|
||||
m["_vercel_sync_time"] = c.VercelSyncTime
|
||||
}
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
raw := map[string]json.RawMessage{}
|
||||
if err := json.Unmarshal(b, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
c.AdditionalFields = map[string]any{}
|
||||
for k, v := range raw {
|
||||
switch k {
|
||||
case "keys":
|
||||
if err := json.Unmarshal(v, &c.Keys); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "accounts":
|
||||
if err := json.Unmarshal(v, &c.Accounts); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "claude_mapping":
|
||||
if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "claude_model_mapping":
|
||||
if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "model_aliases":
|
||||
if err := json.Unmarshal(v, &c.ModelAliases); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "admin":
|
||||
if err := json.Unmarshal(v, &c.Admin); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "runtime":
|
||||
if err := json.Unmarshal(v, &c.Runtime); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "compat":
|
||||
if err := json.Unmarshal(v, &c.Compat); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "toolcall":
|
||||
if err := json.Unmarshal(v, &c.Toolcall); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "responses":
|
||||
if err := json.Unmarshal(v, &c.Responses); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "embeddings":
|
||||
if err := json.Unmarshal(v, &c.Embeddings); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "_vercel_sync_hash":
|
||||
if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "_vercel_sync_time":
|
||||
if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
default:
|
||||
var anyVal any
|
||||
if err := json.Unmarshal(v, &anyVal); err == nil {
|
||||
c.AdditionalFields[k] = anyVal
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Config) Clone() Config {
|
||||
clone := Config{
|
||||
Keys: slices.Clone(c.Keys),
|
||||
Accounts: slices.Clone(c.Accounts),
|
||||
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
|
||||
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
|
||||
ModelAliases: cloneStringMap(c.ModelAliases),
|
||||
Admin: c.Admin,
|
||||
Runtime: c.Runtime,
|
||||
Compat: CompatConfig{
|
||||
WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
|
||||
},
|
||||
Toolcall: c.Toolcall,
|
||||
Responses: c.Responses,
|
||||
Embeddings: c.Embeddings,
|
||||
VercelSyncHash: c.VercelSyncHash,
|
||||
VercelSyncTime: c.VercelSyncTime,
|
||||
AdditionalFields: map[string]any{},
|
||||
}
|
||||
for k, v := range c.AdditionalFields {
|
||||
clone.AdditionalFields[k] = v
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
// cloneStringMap returns an independent copy of in, or nil when in is empty,
// matching the omit-when-empty convention used by Config serialization.
func cloneStringMap(in map[string]string) map[string]string {
	if len(in) == 0 {
		return nil
	}
	out := make(map[string]string, len(in))
	for key := range in {
		out[key] = in[key]
	}
	return out
}
|
||||
|
||||
// cloneBoolPtr copies the pointed-to bool into a fresh allocation; nil stays
// nil so "unset" is preserved.
func cloneBoolPtr(in *bool) *bool {
	if in == nil {
		return nil
	}
	out := *in
	return &out
}
|
||||
|
||||
func parseConfigString(raw string) (Config, error) {
|
||||
var cfg Config
|
||||
candidates := []string{raw}
|
||||
if normalized := normalizeConfigInput(raw); normalized != raw {
|
||||
candidates = append(candidates, normalized)
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
if err := json.Unmarshal([]byte(candidate), &cfg); err == nil {
|
||||
return cfg, nil
|
||||
}
|
||||
}
|
||||
|
||||
base64Input := candidates[len(candidates)-1]
|
||||
decoded, err := decodeConfigBase64(base64Input)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(decoded, &cfg); err != nil {
|
||||
return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON decoded JSON: %w", err)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// normalizeConfigInput strips surrounding whitespace, matching single or
// double quotes, and any case-insensitive "base64:" prefix — repeatedly,
// until the value is stable — so copy-pasted env values parse cleanly.
func normalizeConfigInput(raw string) string {
	s := strings.TrimSpace(raw)
	if s == "" {
		return s
	}
	for {
		before := s
		if len(s) >= 2 {
			head, tail := s[0], s[len(s)-1]
			if (head == '"' && tail == '"') || (head == '\'' && tail == '\'') {
				s = strings.TrimSpace(s[1 : len(s)-1])
			}
		}
		if strings.HasPrefix(strings.ToLower(s), "base64:") {
			s = strings.TrimSpace(s[len("base64:"):])
		}
		if s == before {
			return strings.TrimSpace(s)
		}
	}
}
|
||||
|
||||
// decodeConfigBase64 tries the four common base64 alphabets (standard and
// URL-safe, padded and raw) and returns the first successful decode; on total
// failure it returns the last decoder's error.
func decodeConfigBase64(raw string) ([]byte, error) {
	var lastErr error
	for _, enc := range []*base64.Encoding{
		base64.StdEncoding,
		base64.RawStdEncoding,
		base64.URLEncoding,
		base64.RawURLEncoding,
	} {
		out, err := enc.DecodeString(raw)
		if err == nil {
			return out, nil
		}
		lastErr = err
	}
	if lastErr == nil {
		// Unreachable with a non-empty encoding list; kept as a safeguard.
		lastErr = errors.New("base64 decode failed")
	}
	return nil, lastErr
}
|
||||
@@ -1,63 +1,5 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Logger is the process-wide structured logger, configured once from the
// LOG_LEVEL environment variable.
var Logger = newLogger()

// newLogger builds a text slog.Logger on stdout whose level is taken from
// LOG_LEVEL (DEBUG/WARN/ERROR; anything else means INFO).
func newLogger() *slog.Logger {
	level := new(slog.LevelVar)
	switch strings.ToUpper(strings.TrimSpace(os.Getenv("LOG_LEVEL"))) {
	case "DEBUG":
		level.Set(slog.LevelDebug)
	case "WARN":
		level.Set(slog.LevelWarn)
	case "ERROR":
		level.Set(slog.LevelError)
	default:
		level.Set(slog.LevelInfo)
	}
	handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})
	return slog.New(handler)
}
|
||||
|
||||
// Account holds the credentials for one upstream account. Either Email or
// Mobile identifies it (see Identifier); Password and Token are optional —
// old configs may carry token-only entries. All fields are omitted from JSON
// when empty.
type Account struct {
	Email    string `json:"email,omitempty"`
	Mobile   string `json:"mobile,omitempty"`
	Password string `json:"password,omitempty"`
	Token    string `json:"token,omitempty"`
}
|
||||
|
||||
func (a Account) Identifier() string {
|
||||
if strings.TrimSpace(a.Email) != "" {
|
||||
return strings.TrimSpace(a.Email)
|
||||
}
|
||||
if strings.TrimSpace(a.Mobile) != "" {
|
||||
return strings.TrimSpace(a.Mobile)
|
||||
}
|
||||
// Backward compatibility: old configs may contain token-only accounts.
|
||||
// Use a stable non-sensitive synthetic id so they can still join the pool.
|
||||
token := strings.TrimSpace(a.Token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return "token:" + hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Keys []string `json:"keys,omitempty"`
|
||||
Accounts []Account `json:"accounts,omitempty"`
|
||||
@@ -75,6 +17,13 @@ type Config struct {
|
||||
AdditionalFields map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// Account holds the credentials for one upstream account. Either Email or
// Mobile identifies it (see Identifier); Password and Token are optional —
// old configs may carry token-only entries. All fields are omitted from JSON
// when empty.
type Account struct {
	Email    string `json:"email,omitempty"`
	Mobile   string `json:"mobile,omitempty"`
	Password string `json:"password,omitempty"`
	Token    string `json:"token,omitempty"`
}
|
||||
|
||||
// CompatConfig holds compatibility toggles. WideInputStrictOutput is a
// tri-state pointer: nil means "unset / use the default behavior".
type CompatConfig struct {
	WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"`
}
|
||||
@@ -103,641 +52,3 @@ type ResponsesConfig struct {
|
||||
// EmbeddingsConfig selects the backend provider used for the embeddings
// endpoint; an empty Provider means the built-in default is used.
type EmbeddingsConfig struct {
	Provider string `json:"provider,omitempty"`
}
|
||||
|
||||
func (c Config) MarshalJSON() ([]byte, error) {
|
||||
m := map[string]any{}
|
||||
for k, v := range c.AdditionalFields {
|
||||
m[k] = v
|
||||
}
|
||||
if len(c.Keys) > 0 {
|
||||
m["keys"] = c.Keys
|
||||
}
|
||||
if len(c.Accounts) > 0 {
|
||||
m["accounts"] = c.Accounts
|
||||
}
|
||||
if len(c.ClaudeMapping) > 0 {
|
||||
m["claude_mapping"] = c.ClaudeMapping
|
||||
}
|
||||
if len(c.ClaudeModelMap) > 0 {
|
||||
m["claude_model_mapping"] = c.ClaudeModelMap
|
||||
}
|
||||
if len(c.ModelAliases) > 0 {
|
||||
m["model_aliases"] = c.ModelAliases
|
||||
}
|
||||
if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 {
|
||||
m["admin"] = c.Admin
|
||||
}
|
||||
if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 {
|
||||
m["runtime"] = c.Runtime
|
||||
}
|
||||
if c.Compat.WideInputStrictOutput != nil {
|
||||
m["compat"] = c.Compat
|
||||
}
|
||||
if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" {
|
||||
m["toolcall"] = c.Toolcall
|
||||
}
|
||||
if c.Responses.StoreTTLSeconds > 0 {
|
||||
m["responses"] = c.Responses
|
||||
}
|
||||
if strings.TrimSpace(c.Embeddings.Provider) != "" {
|
||||
m["embeddings"] = c.Embeddings
|
||||
}
|
||||
if c.VercelSyncHash != "" {
|
||||
m["_vercel_sync_hash"] = c.VercelSyncHash
|
||||
}
|
||||
if c.VercelSyncTime != 0 {
|
||||
m["_vercel_sync_time"] = c.VercelSyncTime
|
||||
}
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
raw := map[string]json.RawMessage{}
|
||||
if err := json.Unmarshal(b, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
c.AdditionalFields = map[string]any{}
|
||||
for k, v := range raw {
|
||||
switch k {
|
||||
case "keys":
|
||||
if err := json.Unmarshal(v, &c.Keys); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "accounts":
|
||||
if err := json.Unmarshal(v, &c.Accounts); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "claude_mapping":
|
||||
if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "claude_model_mapping":
|
||||
if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "model_aliases":
|
||||
if err := json.Unmarshal(v, &c.ModelAliases); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "admin":
|
||||
if err := json.Unmarshal(v, &c.Admin); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "runtime":
|
||||
if err := json.Unmarshal(v, &c.Runtime); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "compat":
|
||||
if err := json.Unmarshal(v, &c.Compat); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "toolcall":
|
||||
if err := json.Unmarshal(v, &c.Toolcall); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "responses":
|
||||
if err := json.Unmarshal(v, &c.Responses); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "embeddings":
|
||||
if err := json.Unmarshal(v, &c.Embeddings); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "_vercel_sync_hash":
|
||||
if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "_vercel_sync_time":
|
||||
if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
default:
|
||||
var anyVal any
|
||||
if err := json.Unmarshal(v, &anyVal); err == nil {
|
||||
c.AdditionalFields[k] = anyVal
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Config) Clone() Config {
|
||||
clone := Config{
|
||||
Keys: slices.Clone(c.Keys),
|
||||
Accounts: slices.Clone(c.Accounts),
|
||||
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
|
||||
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
|
||||
ModelAliases: cloneStringMap(c.ModelAliases),
|
||||
Admin: c.Admin,
|
||||
Runtime: c.Runtime,
|
||||
Compat: CompatConfig{
|
||||
WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
|
||||
},
|
||||
Toolcall: c.Toolcall,
|
||||
Responses: c.Responses,
|
||||
Embeddings: c.Embeddings,
|
||||
VercelSyncHash: c.VercelSyncHash,
|
||||
VercelSyncTime: c.VercelSyncTime,
|
||||
AdditionalFields: map[string]any{},
|
||||
}
|
||||
for k, v := range c.AdditionalFields {
|
||||
clone.AdditionalFields[k] = v
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func cloneStringMap(in map[string]string) map[string]string {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]string, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneBoolPtr(in *bool) *bool {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
v := *in
|
||||
return &v
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
cfg Config
|
||||
path string
|
||||
fromEnv bool
|
||||
keyMap map[string]struct{} // O(1) API key lookup index
|
||||
accMap map[string]int // O(1) account lookup: identifier -> slice index
|
||||
}
|
||||
|
||||
func BaseDir() string {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "."
|
||||
}
|
||||
return cwd
|
||||
}
|
||||
|
||||
func IsVercel() bool {
|
||||
return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != ""
|
||||
}
|
||||
|
||||
func ResolvePath(envKey, defaultRel string) string {
|
||||
raw := strings.TrimSpace(os.Getenv(envKey))
|
||||
if raw != "" {
|
||||
if filepath.IsAbs(raw) {
|
||||
return raw
|
||||
}
|
||||
return filepath.Join(BaseDir(), raw)
|
||||
}
|
||||
return filepath.Join(BaseDir(), defaultRel)
|
||||
}
|
||||
|
||||
func ConfigPath() string {
|
||||
return ResolvePath("DS2API_CONFIG_PATH", "config.json")
|
||||
}
|
||||
|
||||
func WASMPath() string {
|
||||
return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm")
|
||||
}
|
||||
|
||||
func StaticAdminDir() string {
|
||||
return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin")
|
||||
}
|
||||
|
||||
func LoadStore() *Store {
|
||||
cfg, fromEnv, err := loadConfig()
|
||||
if err != nil {
|
||||
Logger.Warn("[config] load failed", "error", err)
|
||||
}
|
||||
if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 {
|
||||
Logger.Warn("[config] empty config loaded")
|
||||
}
|
||||
s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv}
|
||||
s.rebuildIndexes()
|
||||
return s
|
||||
}
|
||||
|
||||
// rebuildIndexes must be called with the lock already held (or during init).
|
||||
func (s *Store) rebuildIndexes() {
|
||||
s.keyMap = make(map[string]struct{}, len(s.cfg.Keys))
|
||||
for _, k := range s.cfg.Keys {
|
||||
s.keyMap[k] = struct{}{}
|
||||
}
|
||||
s.accMap = make(map[string]int, len(s.cfg.Accounts))
|
||||
for i, acc := range s.cfg.Accounts {
|
||||
id := acc.Identifier()
|
||||
if id != "" {
|
||||
s.accMap[id] = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func loadConfig() (Config, bool, error) {
|
||||
rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON"))
|
||||
if rawCfg == "" {
|
||||
rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON"))
|
||||
}
|
||||
if rawCfg != "" {
|
||||
cfg, err := parseConfigString(rawCfg)
|
||||
return cfg, true, err
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(ConfigPath())
|
||||
if err != nil {
|
||||
if IsVercel() {
|
||||
// Vercel one-click deploy may start without a writable/present config file.
|
||||
// Keep an in-memory config so users can bootstrap via WebUI then sync env.
|
||||
return Config{}, true, nil
|
||||
}
|
||||
return Config{}, false, err
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(content, &cfg); err != nil {
|
||||
return Config{}, false, err
|
||||
}
|
||||
if IsVercel() {
|
||||
// Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors.
|
||||
return cfg, true, nil
|
||||
}
|
||||
return cfg, false, nil
|
||||
}
|
||||
|
||||
func parseConfigString(raw string) (Config, error) {
|
||||
var cfg Config
|
||||
candidates := []string{raw}
|
||||
if normalized := normalizeConfigInput(raw); normalized != raw {
|
||||
candidates = append(candidates, normalized)
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
if err := json.Unmarshal([]byte(candidate), &cfg); err == nil {
|
||||
return cfg, nil
|
||||
}
|
||||
}
|
||||
|
||||
base64Input := candidates[len(candidates)-1]
|
||||
decoded, err := decodeConfigBase64(base64Input)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(decoded, &cfg); err != nil {
|
||||
return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON decoded JSON: %w", err)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func normalizeConfigInput(raw string) string {
|
||||
normalized := strings.TrimSpace(raw)
|
||||
if normalized == "" {
|
||||
return normalized
|
||||
}
|
||||
for {
|
||||
changed := false
|
||||
if len(normalized) >= 2 {
|
||||
first := normalized[0]
|
||||
last := normalized[len(normalized)-1]
|
||||
if (first == '"' && last == '"') || (first == '\'' && last == '\'') {
|
||||
normalized = strings.TrimSpace(normalized[1 : len(normalized)-1])
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(normalized), "base64:") {
|
||||
normalized = strings.TrimSpace(normalized[len("base64:"):])
|
||||
changed = true
|
||||
}
|
||||
if !changed {
|
||||
break
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(normalized)
|
||||
}
|
||||
|
||||
func decodeConfigBase64(raw string) ([]byte, error) {
|
||||
encodings := []*base64.Encoding{
|
||||
base64.StdEncoding,
|
||||
base64.RawStdEncoding,
|
||||
base64.URLEncoding,
|
||||
base64.RawURLEncoding,
|
||||
}
|
||||
var lastErr error
|
||||
for _, enc := range encodings {
|
||||
decoded, err := enc.DecodeString(raw)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errors.New("base64 decode failed")
|
||||
}
|
||||
|
||||
func (s *Store) Snapshot() Config {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.cfg.Clone()
|
||||
}
|
||||
|
||||
func (s *Store) HasAPIKey(k string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
_, ok := s.keyMap[k]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *Store) Keys() []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return slices.Clone(s.cfg.Keys)
|
||||
}
|
||||
|
||||
func (s *Store) Accounts() []Account {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return slices.Clone(s.cfg.Accounts)
|
||||
}
|
||||
|
||||
func (s *Store) FindAccount(identifier string) (Account, bool) {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if idx, ok := s.findAccountIndexLocked(identifier); ok {
|
||||
return s.cfg.Accounts[idx], true
|
||||
}
|
||||
return Account{}, false
|
||||
}
|
||||
|
||||
func (s *Store) UpdateAccountToken(identifier, token string) error {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
idx, ok := s.findAccountIndexLocked(identifier)
|
||||
if !ok {
|
||||
return errors.New("account not found")
|
||||
}
|
||||
oldID := s.cfg.Accounts[idx].Identifier()
|
||||
s.cfg.Accounts[idx].Token = token
|
||||
newID := s.cfg.Accounts[idx].Identifier()
|
||||
// Keep historical aliases usable for long-lived queues while also adding
|
||||
// the latest identifier after token refresh.
|
||||
if identifier != "" {
|
||||
s.accMap[identifier] = idx
|
||||
}
|
||||
if oldID != "" {
|
||||
s.accMap[oldID] = idx
|
||||
}
|
||||
if newID != "" {
|
||||
s.accMap[newID] = idx
|
||||
}
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *Store) Replace(cfg Config) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cfg = cfg.Clone()
|
||||
s.rebuildIndexes()
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *Store) Update(mutator func(*Config) error) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cfg := s.cfg.Clone()
|
||||
if err := mutator(&cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
s.cfg = cfg
|
||||
s.rebuildIndexes()
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *Store) Save() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.fromEnv {
|
||||
Logger.Info("[save_config] source from env, skip write")
|
||||
return nil
|
||||
}
|
||||
b, err := json.MarshalIndent(s.cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(s.path, b, 0o644)
|
||||
}
|
||||
|
||||
func (s *Store) saveLocked() error {
|
||||
if s.fromEnv {
|
||||
Logger.Info("[save_config] source from env, skip write")
|
||||
return nil
|
||||
}
|
||||
b, err := json.MarshalIndent(s.cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(s.path, b, 0o644)
|
||||
}
|
||||
|
||||
// findAccountIndexLocked expects the store lock to already be held.
|
||||
func (s *Store) findAccountIndexLocked(identifier string) (int, bool) {
|
||||
if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) {
|
||||
return idx, true
|
||||
}
|
||||
// Fallback for token-only accounts whose derived identifier changed after
|
||||
// a token refresh; this preserves correctness on map misses.
|
||||
for i, acc := range s.cfg.Accounts {
|
||||
if acc.Identifier() == identifier {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
return -1, false
|
||||
}
|
||||
|
||||
func (s *Store) IsEnvBacked() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.fromEnv
|
||||
}
|
||||
|
||||
func (s *Store) SetVercelSync(hash string, ts int64) error {
|
||||
return s.Update(func(c *Config) error {
|
||||
c.VercelSyncHash = hash
|
||||
c.VercelSyncTime = ts
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Store) ExportJSONAndBase64() (string, string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
b, err := json.Marshal(s.cfg)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return string(b), base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func (s *Store) ClaudeMapping() map[string]string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if len(s.cfg.ClaudeModelMap) > 0 {
|
||||
return cloneStringMap(s.cfg.ClaudeModelMap)
|
||||
}
|
||||
if len(s.cfg.ClaudeMapping) > 0 {
|
||||
return cloneStringMap(s.cfg.ClaudeMapping)
|
||||
}
|
||||
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
|
||||
}
|
||||
|
||||
func (s *Store) ModelAliases() map[string]string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
out := DefaultModelAliases()
|
||||
for k, v := range s.cfg.ModelAliases {
|
||||
key := strings.TrimSpace(lower(k))
|
||||
val := strings.TrimSpace(lower(v))
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = val
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *Store) CompatWideInputStrictOutput() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Compat.WideInputStrictOutput == nil {
|
||||
return true
|
||||
}
|
||||
return *s.cfg.Compat.WideInputStrictOutput
|
||||
}
|
||||
|
||||
func (s *Store) ToolcallMode() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode))
|
||||
if mode == "" {
|
||||
return "feature_match"
|
||||
}
|
||||
return mode
|
||||
}
|
||||
|
||||
func (s *Store) ToolcallEarlyEmitConfidence() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence))
|
||||
if level == "" {
|
||||
return "high"
|
||||
}
|
||||
return level
|
||||
}
|
||||
|
||||
func (s *Store) ResponsesStoreTTLSeconds() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Responses.StoreTTLSeconds > 0 {
|
||||
return s.cfg.Responses.StoreTTLSeconds
|
||||
}
|
||||
return 900
|
||||
}
|
||||
|
||||
func (s *Store) EmbeddingsProvider() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return strings.TrimSpace(s.cfg.Embeddings.Provider)
|
||||
}
|
||||
|
||||
func (s *Store) AdminPasswordHash() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return strings.TrimSpace(s.cfg.Admin.PasswordHash)
|
||||
}
|
||||
|
||||
func (s *Store) AdminJWTExpireHours() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Admin.JWTExpireHours > 0 {
|
||||
return s.cfg.Admin.JWTExpireHours
|
||||
}
|
||||
if raw := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); raw != "" {
|
||||
if n, err := strconv.Atoi(raw); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 24
|
||||
}
|
||||
|
||||
func (s *Store) AdminJWTValidAfterUnix() int64 {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.cfg.Admin.JWTValidAfterUnix
|
||||
}
|
||||
|
||||
func (s *Store) RuntimeAccountMaxInflight() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Runtime.AccountMaxInflight > 0 {
|
||||
return s.cfg.Runtime.AccountMaxInflight
|
||||
}
|
||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
func (s *Store) RuntimeAccountMaxQueue(defaultSize int) int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Runtime.AccountMaxQueue > 0 {
|
||||
return s.cfg.Runtime.AccountMaxQueue
|
||||
}
|
||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n >= 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
if defaultSize < 0 {
|
||||
return 0
|
||||
}
|
||||
return defaultSize
|
||||
}
|
||||
|
||||
func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Runtime.GlobalMaxInflight > 0 {
|
||||
return s.cfg.Runtime.GlobalMaxInflight
|
||||
}
|
||||
for _, key := range []string{"DS2API_GLOBAL_MAX_INFLIGHT", "DS2API_MAX_INFLIGHT"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
if defaultSize < 0 {
|
||||
return 0
|
||||
}
|
||||
return defaultSize
|
||||
}
|
||||
|
||||
25
internal/config/logger.go
Normal file
25
internal/config/logger.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var Logger = newLogger()
|
||||
|
||||
func newLogger() *slog.Logger {
|
||||
level := new(slog.LevelVar)
|
||||
switch strings.ToUpper(strings.TrimSpace(os.Getenv("LOG_LEVEL"))) {
|
||||
case "DEBUG":
|
||||
level.Set(slog.LevelDebug)
|
||||
case "WARN":
|
||||
level.Set(slog.LevelWarn)
|
||||
case "ERROR":
|
||||
level.Set(slog.LevelError)
|
||||
default:
|
||||
level.Set(slog.LevelInfo)
|
||||
}
|
||||
h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})
|
||||
return slog.New(h)
|
||||
}
|
||||
42
internal/config/paths.go
Normal file
42
internal/config/paths.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func BaseDir() string {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "."
|
||||
}
|
||||
return cwd
|
||||
}
|
||||
|
||||
func IsVercel() bool {
|
||||
return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != ""
|
||||
}
|
||||
|
||||
func ResolvePath(envKey, defaultRel string) string {
|
||||
raw := strings.TrimSpace(os.Getenv(envKey))
|
||||
if raw != "" {
|
||||
if filepath.IsAbs(raw) {
|
||||
return raw
|
||||
}
|
||||
return filepath.Join(BaseDir(), raw)
|
||||
}
|
||||
return filepath.Join(BaseDir(), defaultRel)
|
||||
}
|
||||
|
||||
func ConfigPath() string {
|
||||
return ResolvePath("DS2API_CONFIG_PATH", "config.json")
|
||||
}
|
||||
|
||||
func WASMPath() string {
|
||||
return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm")
|
||||
}
|
||||
|
||||
func StaticAdminDir() string {
|
||||
return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin")
|
||||
}
|
||||
193
internal/config/store.go
Normal file
193
internal/config/store.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
cfg Config
|
||||
path string
|
||||
fromEnv bool
|
||||
keyMap map[string]struct{} // O(1) API key lookup index
|
||||
accMap map[string]int // O(1) account lookup: identifier -> slice index
|
||||
}
|
||||
|
||||
func LoadStore() *Store {
|
||||
cfg, fromEnv, err := loadConfig()
|
||||
if err != nil {
|
||||
Logger.Warn("[config] load failed", "error", err)
|
||||
}
|
||||
if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 {
|
||||
Logger.Warn("[config] empty config loaded")
|
||||
}
|
||||
s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv}
|
||||
s.rebuildIndexes()
|
||||
return s
|
||||
}
|
||||
|
||||
func loadConfig() (Config, bool, error) {
|
||||
rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON"))
|
||||
if rawCfg == "" {
|
||||
rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON"))
|
||||
}
|
||||
if rawCfg != "" {
|
||||
cfg, err := parseConfigString(rawCfg)
|
||||
return cfg, true, err
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(ConfigPath())
|
||||
if err != nil {
|
||||
if IsVercel() {
|
||||
// Vercel one-click deploy may start without a writable/present config file.
|
||||
// Keep an in-memory config so users can bootstrap via WebUI then sync env.
|
||||
return Config{}, true, nil
|
||||
}
|
||||
return Config{}, false, err
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(content, &cfg); err != nil {
|
||||
return Config{}, false, err
|
||||
}
|
||||
if IsVercel() {
|
||||
// Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors.
|
||||
return cfg, true, nil
|
||||
}
|
||||
return cfg, false, nil
|
||||
}
|
||||
|
||||
func (s *Store) Snapshot() Config {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.cfg.Clone()
|
||||
}
|
||||
|
||||
func (s *Store) HasAPIKey(k string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
_, ok := s.keyMap[k]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *Store) Keys() []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return slices.Clone(s.cfg.Keys)
|
||||
}
|
||||
|
||||
func (s *Store) Accounts() []Account {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return slices.Clone(s.cfg.Accounts)
|
||||
}
|
||||
|
||||
func (s *Store) FindAccount(identifier string) (Account, bool) {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if idx, ok := s.findAccountIndexLocked(identifier); ok {
|
||||
return s.cfg.Accounts[idx], true
|
||||
}
|
||||
return Account{}, false
|
||||
}
|
||||
|
||||
func (s *Store) UpdateAccountToken(identifier, token string) error {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
idx, ok := s.findAccountIndexLocked(identifier)
|
||||
if !ok {
|
||||
return errors.New("account not found")
|
||||
}
|
||||
oldID := s.cfg.Accounts[idx].Identifier()
|
||||
s.cfg.Accounts[idx].Token = token
|
||||
newID := s.cfg.Accounts[idx].Identifier()
|
||||
// Keep historical aliases usable for long-lived queues while also adding
|
||||
// the latest identifier after token refresh.
|
||||
if identifier != "" {
|
||||
s.accMap[identifier] = idx
|
||||
}
|
||||
if oldID != "" {
|
||||
s.accMap[oldID] = idx
|
||||
}
|
||||
if newID != "" {
|
||||
s.accMap[newID] = idx
|
||||
}
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *Store) Replace(cfg Config) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cfg = cfg.Clone()
|
||||
s.rebuildIndexes()
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *Store) Update(mutator func(*Config) error) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cfg := s.cfg.Clone()
|
||||
if err := mutator(&cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
s.cfg = cfg
|
||||
s.rebuildIndexes()
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *Store) Save() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.fromEnv {
|
||||
Logger.Info("[save_config] source from env, skip write")
|
||||
return nil
|
||||
}
|
||||
b, err := json.MarshalIndent(s.cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(s.path, b, 0o644)
|
||||
}
|
||||
|
||||
func (s *Store) saveLocked() error {
|
||||
if s.fromEnv {
|
||||
Logger.Info("[save_config] source from env, skip write")
|
||||
return nil
|
||||
}
|
||||
b, err := json.MarshalIndent(s.cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(s.path, b, 0o644)
|
||||
}
|
||||
|
||||
func (s *Store) IsEnvBacked() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.fromEnv
|
||||
}
|
||||
|
||||
func (s *Store) SetVercelSync(hash string, ts int64) error {
|
||||
return s.Update(func(c *Config) error {
|
||||
c.VercelSyncHash = hash
|
||||
c.VercelSyncTime = ts
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Store) ExportJSONAndBase64() (string, string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
b, err := json.Marshal(s.cfg)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return string(b), base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
167
internal/config/store_accessors.go
Normal file
167
internal/config/store_accessors.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (s *Store) ClaudeMapping() map[string]string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if len(s.cfg.ClaudeModelMap) > 0 {
|
||||
return cloneStringMap(s.cfg.ClaudeModelMap)
|
||||
}
|
||||
if len(s.cfg.ClaudeMapping) > 0 {
|
||||
return cloneStringMap(s.cfg.ClaudeMapping)
|
||||
}
|
||||
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
|
||||
}
|
||||
|
||||
func (s *Store) ModelAliases() map[string]string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
out := DefaultModelAliases()
|
||||
for k, v := range s.cfg.ModelAliases {
|
||||
key := strings.TrimSpace(lower(k))
|
||||
val := strings.TrimSpace(lower(v))
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = val
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *Store) CompatWideInputStrictOutput() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Compat.WideInputStrictOutput == nil {
|
||||
return true
|
||||
}
|
||||
return *s.cfg.Compat.WideInputStrictOutput
|
||||
}
|
||||
|
||||
func (s *Store) ToolcallMode() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode))
|
||||
if mode == "" {
|
||||
return "feature_match"
|
||||
}
|
||||
return mode
|
||||
}
|
||||
|
||||
func (s *Store) ToolcallEarlyEmitConfidence() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence))
|
||||
if level == "" {
|
||||
return "high"
|
||||
}
|
||||
return level
|
||||
}
|
||||
|
||||
func (s *Store) ResponsesStoreTTLSeconds() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Responses.StoreTTLSeconds > 0 {
|
||||
return s.cfg.Responses.StoreTTLSeconds
|
||||
}
|
||||
return 900
|
||||
}
|
||||
|
||||
func (s *Store) EmbeddingsProvider() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return strings.TrimSpace(s.cfg.Embeddings.Provider)
|
||||
}
|
||||
|
||||
func (s *Store) AdminPasswordHash() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return strings.TrimSpace(s.cfg.Admin.PasswordHash)
|
||||
}
|
||||
|
||||
func (s *Store) AdminJWTExpireHours() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Admin.JWTExpireHours > 0 {
|
||||
return s.cfg.Admin.JWTExpireHours
|
||||
}
|
||||
if raw := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); raw != "" {
|
||||
if n, err := strconv.Atoi(raw); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 24
|
||||
}
|
||||
|
||||
func (s *Store) AdminJWTValidAfterUnix() int64 {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.cfg.Admin.JWTValidAfterUnix
|
||||
}
|
||||
|
||||
func (s *Store) RuntimeAccountMaxInflight() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Runtime.AccountMaxInflight > 0 {
|
||||
return s.cfg.Runtime.AccountMaxInflight
|
||||
}
|
||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
func (s *Store) RuntimeAccountMaxQueue(defaultSize int) int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Runtime.AccountMaxQueue > 0 {
|
||||
return s.cfg.Runtime.AccountMaxQueue
|
||||
}
|
||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n >= 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
if defaultSize < 0 {
|
||||
return 0
|
||||
}
|
||||
return defaultSize
|
||||
}
|
||||
|
||||
func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Runtime.GlobalMaxInflight > 0 {
|
||||
return s.cfg.Runtime.GlobalMaxInflight
|
||||
}
|
||||
for _, key := range []string{"DS2API_GLOBAL_MAX_INFLIGHT", "DS2API_MAX_INFLIGHT"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
if defaultSize < 0 {
|
||||
return 0
|
||||
}
|
||||
return defaultSize
|
||||
}
|
||||
31
internal/config/store_index.go
Normal file
31
internal/config/store_index.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package config
|
||||
|
||||
// rebuildIndexes must be called with the lock already held (or during init).
|
||||
func (s *Store) rebuildIndexes() {
|
||||
s.keyMap = make(map[string]struct{}, len(s.cfg.Keys))
|
||||
for _, k := range s.cfg.Keys {
|
||||
s.keyMap[k] = struct{}{}
|
||||
}
|
||||
s.accMap = make(map[string]int, len(s.cfg.Accounts))
|
||||
for i, acc := range s.cfg.Accounts {
|
||||
id := acc.Identifier()
|
||||
if id != "" {
|
||||
s.accMap[id] = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// findAccountIndexLocked expects the store lock to already be held.
|
||||
func (s *Store) findAccountIndexLocked(identifier string) (int, bool) {
|
||||
if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) {
|
||||
return idx, true
|
||||
}
|
||||
// Fallback for token-only accounts whose derived identifier changed after
|
||||
// a token refresh; this preserves correctness on map misses.
|
||||
for i, acc := range s.cfg.Accounts {
|
||||
if acc.Identifier() == identifier {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
return -1, false
|
||||
}
|
||||
@@ -1,347 +0,0 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
trans "ds2api/internal/deepseek/transport"
|
||||
"ds2api/internal/devcapture"
|
||||
"ds2api/internal/util"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
)
|
||||
|
||||
// intFrom is a package-internal alias for the shared util version.
|
||||
var intFrom = util.IntFrom
|
||||
|
||||
type Client struct {
|
||||
Store *config.Store
|
||||
Auth *auth.Resolver
|
||||
capture *devcapture.Store
|
||||
regular trans.Doer
|
||||
stream trans.Doer
|
||||
fallback *http.Client
|
||||
fallbackS *http.Client
|
||||
powSolver *PowSolver
|
||||
maxRetries int
|
||||
}
|
||||
|
||||
func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
|
||||
return &Client{
|
||||
Store: store,
|
||||
Auth: resolver,
|
||||
capture: devcapture.Global(),
|
||||
regular: trans.New(60 * time.Second),
|
||||
stream: trans.New(0),
|
||||
fallback: &http.Client{Timeout: 60 * time.Second},
|
||||
fallbackS: &http.Client{Timeout: 0},
|
||||
powSolver: NewPowSolver(config.WASMPath()),
|
||||
maxRetries: 3,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) PreloadPow(ctx context.Context) error {
|
||||
return c.powSolver.init(ctx)
|
||||
}
|
||||
|
||||
func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) {
|
||||
payload := map[string]any{
|
||||
"password": strings.TrimSpace(acc.Password),
|
||||
"device_id": "deepseek_to_api",
|
||||
"os": "android",
|
||||
}
|
||||
if email := strings.TrimSpace(acc.Email); email != "" {
|
||||
payload["email"] = email
|
||||
} else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" {
|
||||
payload["mobile"] = mobile
|
||||
payload["area_code"] = nil
|
||||
} else {
|
||||
return "", errors.New("missing email/mobile")
|
||||
}
|
||||
resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code := intFrom(resp["code"])
|
||||
if code != 0 {
|
||||
return "", fmt.Errorf("login failed: %v", resp["msg"])
|
||||
}
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
if intFrom(data["biz_code"]) != 0 {
|
||||
return "", fmt.Errorf("login failed: %v", data["biz_msg"])
|
||||
}
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
user, _ := bizData["user"].(map[string]any)
|
||||
token, _ := user["token"].(string)
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return "", errors.New("missing login token")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// CreateSession opens a new DeepSeek chat session and returns its ID.
// It retries up to maxAttempts times (falling back to c.maxRetries when
// maxAttempts <= 0). For config-token accounts an auth failure triggers at
// most one token refresh, and the resolver may rotate to another account
// before retrying.
func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
	if maxAttempts <= 0 {
		maxAttempts = c.maxRetries
	}
	attempts := 0
	refreshed := false // allows at most one token refresh per account
	for attempts < maxAttempts {
		headers := c.authHeaders(a.DeepSeekToken)
		resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"})
		if err != nil {
			config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID)
			attempts++
			continue
		}
		code := intFrom(resp["code"])
		if status == http.StatusOK && code == 0 {
			// Success path: the session id lives at data.biz_data.id.
			data, _ := resp["data"].(map[string]any)
			bizData, _ := data["biz_data"].(map[string]any)
			sessionID, _ := bizData["id"].(string)
			if sessionID != "" {
				return sessionID, nil
			}
		}
		msg, _ := resp["msg"].(string)
		config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
		if a.UseConfigToken {
			if isTokenInvalid(status, code, msg) && !refreshed {
				if c.Auth.RefreshToken(ctx, a) {
					// Retry immediately with the refreshed token; this path
					// intentionally does not consume an attempt.
					refreshed = true
					continue
				}
			}
			if c.Auth.SwitchAccount(ctx, a) {
				// A new account gets its own one-refresh budget.
				refreshed = false
				attempts++
				continue
			}
		}
		attempts++
	}
	return "", errors.New("create session failed")
}
|
||||
|
||||
func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
attempts := 0
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
|
||||
if err != nil {
|
||||
config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID)
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
code := intFrom(resp["code"])
|
||||
if status == http.StatusOK && code == 0 {
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
challenge, _ := bizData["challenge"].(map[string]any)
|
||||
answer, err := c.powSolver.Compute(ctx, challenge)
|
||||
if err != nil {
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
return BuildPowHeader(challenge, answer)
|
||||
}
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, msg) {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c.Auth.SwitchAccount(ctx, a) {
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
}
|
||||
attempts++
|
||||
}
|
||||
return "", errors.New("get pow failed")
|
||||
}
|
||||
|
||||
func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
headers["x-ds-pow-response"] = powResp
|
||||
captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload)
|
||||
attempts := 0
|
||||
for attempts < maxAttempts {
|
||||
resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload)
|
||||
if err != nil {
|
||||
attempts++
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
if captureSession != nil {
|
||||
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
if captureSession != nil {
|
||||
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
attempts++
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
return nil, errors.New("completion failed")
|
||||
}
|
||||
|
||||
func (c *Client) postJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, error) {
|
||||
body, status, err := c.postJSONWithStatus(ctx, doer, url, headers, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status == 0 {
|
||||
return nil, errors.New("request failed")
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// postJSONWithStatus marshals payload as JSON and POSTs it to url with the
// given headers, first via the fingerprinted transport (doer) and, if that
// transport errors, once more via the plain net/http fallback client.
// It returns the decoded JSON body (an empty map when the body is missing
// or not valid JSON — parse failures are logged, not fatal) together with
// the HTTP status code.
func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, int, error) {
	b, err := json.Marshal(payload)
	if err != nil {
		return nil, 0, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
	if err != nil {
		return nil, 0, err
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	resp, err := doer.Do(req)
	if err != nil {
		// Fingerprinted transport failed: rebuild the request with a fresh
		// body reader and retry once on the standard client.
		config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err)
		req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
		if reqErr != nil {
			// Surface the original transport error, not the rebuild error.
			return nil, 0, err
		}
		for k, v := range headers {
			req2.Header.Set(k, v)
		}
		resp, err = c.fallback.Do(req2)
		if err != nil {
			return nil, 0, err
		}
	}
	defer resp.Body.Close()
	payloadBytes, err := readResponseBody(resp)
	if err != nil {
		return nil, resp.StatusCode, err
	}
	out := map[string]any{}
	if len(payloadBytes) > 0 {
		// Callers still get the status code and an empty map on parse failure.
		if err := json.Unmarshal(payloadBytes, &out); err != nil {
			config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes))
		}
	}
	return out, resp.StatusCode, nil
}
|
||||
|
||||
// streamPost POSTs payload as JSON and returns the raw *http.Response so
// the caller can consume the body as a stream (SSE). The fingerprinted
// streaming transport is tried first; on transport error the request is
// rebuilt and sent once through the standard no-timeout fallback client.
// The caller owns closing the response body.
func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) {
	b, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
	if err != nil {
		return nil, err
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	resp, err := c.stream.Do(req)
	if err != nil {
		config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err)
		req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
		if reqErr != nil {
			// Surface the original transport error, not the rebuild error.
			return nil, err
		}
		for k, v := range headers {
			req2.Header.Set(k, v)
		}
		return c.fallbackS.Do(req2)
	}
	return resp, nil
}
|
||||
|
||||
func (c *Client) authHeaders(token string) map[string]string {
|
||||
headers := make(map[string]string, len(BaseHeaders)+1)
|
||||
for k, v := range BaseHeaders {
|
||||
headers[k] = v
|
||||
}
|
||||
headers["authorization"] = "Bearer " + token
|
||||
return headers
|
||||
}
|
||||
|
||||
// isTokenInvalid reports whether a DeepSeek response indicates an expired
// or otherwise unusable auth token, judged by HTTP status (401/403),
// business code (40001-40003), or the message text mentioning
// "token"/"unauthorized" (case-insensitive).
func isTokenInvalid(status int, code int, msg string) bool {
	switch status {
	case http.StatusUnauthorized, http.StatusForbidden:
		return true
	}
	switch code {
	case 40001, 40002, 40003:
		return true
	}
	lowered := strings.ToLower(msg)
	return strings.Contains(lowered, "token") || strings.Contains(lowered, "unauthorized")
}
|
||||
|
||||
// readResponseBody fully reads the response body, transparently
// decompressing gzip and brotli Content-Encoding values; any other
// encoding (or none) is read verbatim. The caller remains responsible for
// closing resp.Body.
func readResponseBody(resp *http.Response) ([]byte, error) {
	encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding")))
	var reader io.Reader = resp.Body
	switch encoding {
	case "gzip":
		gz, err := gzip.NewReader(resp.Body)
		if err != nil {
			return nil, err
		}
		defer gz.Close()
		reader = gz
	case "br":
		reader = brotli.NewReader(resp.Body)
	}
	return io.ReadAll(reader)
}
|
||||
|
||||
func preview(b []byte) string {
|
||||
s := strings.TrimSpace(string(b))
|
||||
if len(s) > 160 {
|
||||
return s[:160]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 2*1024*1024)
|
||||
for scanner.Scan() {
|
||||
if !onLine(scanner.Bytes()) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
153
internal/deepseek/client_auth.go
Normal file
153
internal/deepseek/client_auth.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) {
|
||||
payload := map[string]any{
|
||||
"password": strings.TrimSpace(acc.Password),
|
||||
"device_id": "deepseek_to_api",
|
||||
"os": "android",
|
||||
}
|
||||
if email := strings.TrimSpace(acc.Email); email != "" {
|
||||
payload["email"] = email
|
||||
} else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" {
|
||||
payload["mobile"] = mobile
|
||||
payload["area_code"] = nil
|
||||
} else {
|
||||
return "", errors.New("missing email/mobile")
|
||||
}
|
||||
resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code := intFrom(resp["code"])
|
||||
if code != 0 {
|
||||
return "", fmt.Errorf("login failed: %v", resp["msg"])
|
||||
}
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
if intFrom(data["biz_code"]) != 0 {
|
||||
return "", fmt.Errorf("login failed: %v", data["biz_msg"])
|
||||
}
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
user, _ := bizData["user"].(map[string]any)
|
||||
token, _ := user["token"].(string)
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return "", errors.New("missing login token")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
attempts := 0
|
||||
refreshed := false
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"})
|
||||
if err != nil {
|
||||
config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID)
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
code := intFrom(resp["code"])
|
||||
if status == http.StatusOK && code == 0 {
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
sessionID, _ := bizData["id"].(string)
|
||||
if sessionID != "" {
|
||||
return sessionID, nil
|
||||
}
|
||||
}
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, msg) && !refreshed {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
refreshed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c.Auth.SwitchAccount(ctx, a) {
|
||||
refreshed = false
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
}
|
||||
attempts++
|
||||
}
|
||||
return "", errors.New("create session failed")
|
||||
}
|
||||
|
||||
func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
attempts := 0
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
|
||||
if err != nil {
|
||||
config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID)
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
code := intFrom(resp["code"])
|
||||
if status == http.StatusOK && code == 0 {
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
challenge, _ := bizData["challenge"].(map[string]any)
|
||||
answer, err := c.powSolver.Compute(ctx, challenge)
|
||||
if err != nil {
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
return BuildPowHeader(challenge, answer)
|
||||
}
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, msg) {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c.Auth.SwitchAccount(ctx, a) {
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
}
|
||||
attempts++
|
||||
}
|
||||
return "", errors.New("get pow failed")
|
||||
}
|
||||
|
||||
func (c *Client) authHeaders(token string) map[string]string {
|
||||
headers := make(map[string]string, len(BaseHeaders)+1)
|
||||
for k, v := range BaseHeaders {
|
||||
headers[k] = v
|
||||
}
|
||||
headers["authorization"] = "Bearer " + token
|
||||
return headers
|
||||
}
|
||||
|
||||
func isTokenInvalid(status int, code int, msg string) bool {
|
||||
msg = strings.ToLower(msg)
|
||||
if status == http.StatusUnauthorized || status == http.StatusForbidden {
|
||||
return true
|
||||
}
|
||||
if code == 40001 || code == 40002 || code == 40003 {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
|
||||
}
|
||||
71
internal/deepseek/client_completion.go
Normal file
71
internal/deepseek/client_completion.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
headers["x-ds-pow-response"] = powResp
|
||||
captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload)
|
||||
attempts := 0
|
||||
for attempts < maxAttempts {
|
||||
resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload)
|
||||
if err != nil {
|
||||
attempts++
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
if captureSession != nil {
|
||||
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
if captureSession != nil {
|
||||
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
attempts++
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
return nil, errors.New("completion failed")
|
||||
}
|
||||
|
||||
func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) {
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := c.stream.Do(req)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err)
|
||||
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if reqErr != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
}
|
||||
return c.fallbackS.Do(req2)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
46
internal/deepseek/client_core.go
Normal file
46
internal/deepseek/client_core.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
trans "ds2api/internal/deepseek/transport"
|
||||
"ds2api/internal/devcapture"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// intFrom is a package-internal alias for the shared util version.
|
||||
var intFrom = util.IntFrom
|
||||
|
||||
type Client struct {
|
||||
Store *config.Store
|
||||
Auth *auth.Resolver
|
||||
capture *devcapture.Store
|
||||
regular trans.Doer
|
||||
stream trans.Doer
|
||||
fallback *http.Client
|
||||
fallbackS *http.Client
|
||||
powSolver *PowSolver
|
||||
maxRetries int
|
||||
}
|
||||
|
||||
func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
|
||||
return &Client{
|
||||
Store: store,
|
||||
Auth: resolver,
|
||||
capture: devcapture.Global(),
|
||||
regular: trans.New(60 * time.Second),
|
||||
stream: trans.New(0),
|
||||
fallback: &http.Client{Timeout: 60 * time.Second},
|
||||
fallbackS: &http.Client{Timeout: 0},
|
||||
powSolver: NewPowSolver(config.WASMPath()),
|
||||
maxRetries: 3,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) PreloadPow(ctx context.Context) error {
|
||||
return c.powSolver.init(ctx)
|
||||
}
|
||||
51
internal/deepseek/client_http_helpers.go
Normal file
51
internal/deepseek/client_http_helpers.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
)
|
||||
|
||||
func readResponseBody(resp *http.Response) ([]byte, error) {
|
||||
encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding")))
|
||||
var reader io.Reader = resp.Body
|
||||
switch encoding {
|
||||
case "gzip":
|
||||
gz, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer gz.Close()
|
||||
reader = gz
|
||||
case "br":
|
||||
reader = brotli.NewReader(resp.Body)
|
||||
}
|
||||
return io.ReadAll(reader)
|
||||
}
|
||||
|
||||
func preview(b []byte) string {
|
||||
s := strings.TrimSpace(string(b))
|
||||
if len(s) > 160 {
|
||||
return s[:160]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 2*1024*1024)
|
||||
for scanner.Scan() {
|
||||
if !onLine(scanner.Bytes()) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
64
internal/deepseek/client_http_json.go
Normal file
64
internal/deepseek/client_http_json.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/config"
|
||||
trans "ds2api/internal/deepseek/transport"
|
||||
)
|
||||
|
||||
func (c *Client) postJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, error) {
|
||||
body, status, err := c.postJSONWithStatus(ctx, doer, url, headers, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status == 0 {
|
||||
return nil, errors.New("request failed")
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, int, error) {
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := doer.Do(req)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err)
|
||||
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if reqErr != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
}
|
||||
resp, err = c.fallback.Do(req2)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
payloadBytes, err := readResponseBody(resp)
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
out := map[string]any{}
|
||||
if len(payloadBytes) > 0 {
|
||||
if err := json.Unmarshal(payloadBytes, &out); err != nil {
|
||||
config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes))
|
||||
}
|
||||
}
|
||||
return out, resp.StatusCode, nil
|
||||
}
|
||||
@@ -1,307 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
messageObj["reasoning_content"] = finalThinking
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
|
||||
return map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// BuildResponseObject assembles a completed Responses-API response object.
// Tool calls are detected in finalText first, then (if none) in
// finalThinking, so mixed prose + tool_call payloads still count as tool
// calls — matching the chat/completions semantics above.
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
	// Align responses tool-call semantics with chat/completions:
	// mixed prose + tool_call payloads should still be interpreted as tool calls.
	detected := util.ParseToolCalls(finalText, toolNames)
	if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
		detected = util.ParseToolCalls(finalThinking, toolNames)
	}
	exposedOutputText := finalText
	output := make([]any, 0, 2)
	if len(detected) > 0 {
		// Tool-call branch: suppress output_text, emit reasoning (if any),
		// function_call items, and a legacy tool_calls aggregate.
		exposedOutputText = ""
		if strings.TrimSpace(finalThinking) != "" {
			output = append(output, map[string]any{
				"type": "reasoning",
				"text": finalThinking,
			})
		}
		formatted := util.FormatOpenAIToolCalls(detected)
		output = append(output, toResponsesFunctionCallItems(formatted)...)
		output = append(output, map[string]any{
			"type":       "tool_calls",
			"tool_calls": formatted,
		})
	} else {
		// Plain-message branch: reasoning part (prepended) then output_text.
		content := make([]any, 0, 2)
		if finalThinking != "" {
			content = append([]any{map[string]any{
				"type": "reasoning",
				"text": finalThinking,
			}}, content...)
		}
		if strings.TrimSpace(finalText) != "" {
			content = append(content, map[string]any{
				"type": "output_text",
				"text": finalText,
			})
		}
		// With no answer text, expose the reasoning as output_text so
		// clients that only read output_text still see something.
		if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" {
			exposedOutputText = finalThinking
		}
		output = append(output, map[string]any{
			"type":    "message",
			"id":      "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
			"role":    "assistant",
			"content": content,
		})
	}
	promptTokens := util.EstimateTokens(finalPrompt)
	reasoningTokens := util.EstimateTokens(finalThinking)
	completionTokens := util.EstimateTokens(finalText)
	return map[string]any{
		"id":          responseID,
		"type":        "response",
		"object":      "response",
		"created_at":  time.Now().Unix(),
		"status":      "completed",
		"model":       model,
		"output":      output,
		"output_text": exposedOutputText,
		"usage": map[string]any{
			"input_tokens":  promptTokens,
			"output_tokens": reasoningTokens + completionTokens,
			"total_tokens":  promptTokens + reasoningTokens + completionTokens,
		},
	}
}
|
||||
|
||||
func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
callID, _ := tc["id"].(string)
|
||||
if strings.TrimSpace(callID) == "" {
|
||||
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
}
|
||||
name := ""
|
||||
args := "{}"
|
||||
if fn, ok := tc["function"].(map[string]any); ok {
|
||||
if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
|
||||
name = n
|
||||
}
|
||||
if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
|
||||
args = a
|
||||
}
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
"type": "function_call",
|
||||
"call_id": callID,
|
||||
"name": name,
|
||||
"arguments": normalizeJSONString(args),
|
||||
"status": "completed",
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeJSONString re-serializes a JSON string into canonical compact
// form. Empty/whitespace input becomes "{}"; anything that fails to parse
// or re-marshal is returned unchanged.
func normalizeJSONString(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return "{}"
	}
	var parsed any
	if json.Unmarshal([]byte(trimmed), &parsed) != nil {
		return raw
	}
	compact, marshalErr := json.Marshal(parsed)
	if marshalErr != nil {
		return raw
	}
	return string(compact)
}
|
||||
|
||||
// BuildChatStreamDeltaChoice wraps a streaming delta into a chat chunk
// choice entry.
func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
	choice := make(map[string]any, 2)
	choice["index"] = index
	choice["delta"] = delta
	return choice
}
|
||||
|
||||
// BuildChatStreamFinishChoice builds the terminal chunk choice carrying an
// empty delta and the finish_reason.
func BuildChatStreamFinishChoice(index int, finishReason string) map[string]any {
	choice := make(map[string]any, 3)
	choice["index"] = index
	choice["delta"] = map[string]any{}
	choice["finish_reason"] = finishReason
	return choice
}
|
||||
|
||||
// BuildChatStreamChunk assembles one chat.completion.chunk SSE payload;
// the usage block is attached only when non-empty.
func BuildChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any {
	chunk := make(map[string]any, 6)
	chunk["id"] = completionID
	chunk["object"] = "chat.completion.chunk"
	chunk["created"] = created
	chunk["model"] = model
	chunk["choices"] = choices
	if len(usage) > 0 {
		chunk["usage"] = usage
	}
	return chunk
}
|
||||
|
||||
func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
return map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// BuildResponsesCreatedPayload emits the response.created SSE event body
// announcing a new in-progress response.
func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
	payload := make(map[string]any, 6)
	payload["type"] = "response.created"
	payload["id"] = responseID
	payload["response_id"] = responseID
	payload["object"] = "response"
	payload["model"] = model
	payload["status"] = "in_progress"
	return payload
}
|
||||
|
||||
// BuildResponsesTextDeltaPayload emits a response.output_text.delta event.
func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
	payload := make(map[string]any, 4)
	payload["type"] = "response.output_text.delta"
	payload["id"] = responseID
	payload["response_id"] = responseID
	payload["delta"] = delta
	return payload
}
|
||||
|
||||
// BuildResponsesReasoningDeltaPayload emits a response.reasoning.delta event.
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
	payload := make(map[string]any, 4)
	payload["type"] = "response.reasoning.delta"
	payload["id"] = responseID
	payload["response_id"] = responseID
	payload["delta"] = delta
	return payload
}
|
||||
|
||||
// BuildResponsesReasoningTextDeltaPayload emits a
// response.reasoning_text.delta event addressed to a specific output item
// and content part.
func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
	payload := make(map[string]any, 7)
	payload["type"] = "response.reasoning_text.delta"
	payload["id"] = responseID
	payload["response_id"] = responseID
	payload["item_id"] = itemID
	payload["output_index"] = outputIndex
	payload["content_index"] = contentIndex
	payload["delta"] = delta
	return payload
}
|
||||
|
||||
// BuildResponsesReasoningTextDonePayload emits the terminal
// response.reasoning_text.done event carrying the full reasoning text.
func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
	payload := make(map[string]any, 7)
	payload["type"] = "response.reasoning_text.done"
	payload["id"] = responseID
	payload["response_id"] = responseID
	payload["item_id"] = itemID
	payload["output_index"] = outputIndex
	payload["content_index"] = contentIndex
	payload["text"] = text
	return payload
}
|
||||
|
||||
// BuildResponsesToolCallDeltaPayload emits a
// response.output_tool_call.delta event with the partial tool calls.
func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
	payload := make(map[string]any, 4)
	payload["type"] = "response.output_tool_call.delta"
	payload["id"] = responseID
	payload["response_id"] = responseID
	payload["tool_calls"] = toolCalls
	return payload
}
|
||||
|
||||
// BuildResponsesToolCallDonePayload emits the terminal
// response.output_tool_call.done event with the final tool calls.
func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
	payload := make(map[string]any, 4)
	payload["type"] = "response.output_tool_call.done"
	payload["id"] = responseID
	payload["response_id"] = responseID
	payload["tool_calls"] = toolCalls
	return payload
}
|
||||
|
||||
// BuildResponsesFunctionCallArgumentsDeltaPayload emits a
// response.function_call_arguments.delta event with a raw (possibly
// partial) arguments fragment.
func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
	payload := make(map[string]any, 7)
	payload["type"] = "response.function_call_arguments.delta"
	payload["id"] = responseID
	payload["response_id"] = responseID
	payload["item_id"] = itemID
	payload["output_index"] = outputIndex
	payload["call_id"] = callID
	payload["delta"] = delta
	return payload
}
|
||||
|
||||
func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.function_call_arguments.done",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"call_id": callID,
|
||||
"name": name,
|
||||
"arguments": normalizeJSONString(arguments),
|
||||
}
|
||||
}
|
||||
|
||||
// BuildResponsesCompletedPayload wraps a finished response object in the
// terminal "response.completed" SSE event. response_id is lifted from the
// response's "id" field and is the empty string when absent or non-string.
func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
	id, _ := response["id"].(string)
	ev := make(map[string]any, 3)
	ev["type"] = "response.completed"
	ev["response_id"] = id
	ev["response"] = response
	return ev
}
|
||||
60
internal/format/openai/render_chat.go
Normal file
60
internal/format/openai/render_chat.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
messageObj["reasoning_content"] = finalThinking
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
|
||||
"usage": BuildChatUsage(finalPrompt, finalThinking, finalText),
|
||||
}
|
||||
}
|
||||
|
||||
// BuildChatStreamDeltaChoice wraps one streaming delta into a chunk choice.
func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
	choice := make(map[string]any, 2)
	choice["index"] = index
	choice["delta"] = delta
	return choice
}
|
||||
|
||||
// BuildChatStreamFinishChoice produces the terminal chunk choice: an empty
// delta object plus the finish_reason.
func BuildChatStreamFinishChoice(index int, finishReason string) map[string]any {
	choice := make(map[string]any, 3)
	choice["index"] = index
	choice["delta"] = map[string]any{}
	choice["finish_reason"] = finishReason
	return choice
}
|
||||
|
||||
// BuildChatStreamChunk assembles a chat.completion.chunk envelope. Usage is
// attached only when non-empty (conventionally on the final chunk).
func BuildChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any {
	chunk := make(map[string]any, 6)
	chunk["id"] = completionID
	chunk["object"] = "chat.completion.chunk"
	chunk["created"] = created
	chunk["model"] = model
	chunk["choices"] = choices
	if len(usage) != 0 {
		chunk["usage"] = usage
	}
	return chunk
}
|
||||
119
internal/format/openai/render_responses.go
Normal file
119
internal/format/openai/render_responses.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package openai
|
||||
|
||||
import (
	"bytes"
	"encoding/json"
	"strings"
	"time"

	"github.com/google/uuid"

	"ds2api/internal/util"
)
|
||||
|
||||
// BuildResponseObject renders a complete (non-streaming) Responses-API
// response object. Tool-call markup is detected first in finalText and, as a
// fallback, in finalThinking; when present the output contains function_call
// items plus a legacy "tool_calls" item and output_text is blanked.
// Otherwise the output is a single assistant message whose content carries
// an optional reasoning part followed by the answer text.
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
	// Align responses tool-call semantics with chat/completions:
	// mixed prose + tool_call payloads should still be interpreted as tool calls.
	detected := util.ParseToolCalls(finalText, toolNames)
	if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
		// Fallback: some models emit the tool call inside the reasoning text.
		detected = util.ParseToolCalls(finalThinking, toolNames)
	}
	exposedOutputText := finalText
	output := make([]any, 0, 2)
	if len(detected) > 0 {
		// Tool-call path: the raw text is not exposed as output_text.
		exposedOutputText = ""
		if strings.TrimSpace(finalThinking) != "" {
			output = append(output, map[string]any{
				"type": "reasoning",
				"text": finalThinking,
			})
		}
		formatted := util.FormatOpenAIToolCalls(detected)
		// One function_call item per call, followed by an aggregate
		// "tool_calls" item (kept for backward compatibility — TODO confirm
		// which consumers still read it).
		output = append(output, toResponsesFunctionCallItems(formatted)...)
		output = append(output, map[string]any{
			"type":       "tool_calls",
			"tool_calls": formatted,
		})
	} else {
		content := make([]any, 0, 2)
		if finalThinking != "" {
			// Reasoning part is prepended so it precedes the answer text.
			content = append([]any{map[string]any{
				"type": "reasoning",
				"text": finalThinking,
			}}, content...)
		}
		if strings.TrimSpace(finalText) != "" {
			content = append(content, map[string]any{
				"type": "output_text",
				"text": finalText,
			})
		}
		if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" {
			// No answer text: expose the reasoning so output_text is non-empty.
			exposedOutputText = finalThinking
		}
		output = append(output, map[string]any{
			"type":    "message",
			"id":      "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
			"role":    "assistant",
			"content": content,
		})
	}
	return map[string]any{
		"id":          responseID,
		"type":        "response",
		"object":      "response",
		"created_at":  time.Now().Unix(),
		"status":      "completed",
		"model":       model,
		"output":      output,
		"output_text": exposedOutputText,
		"usage":       BuildResponsesUsage(finalPrompt, finalThinking, finalText),
	}
}
|
||||
|
||||
func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
callID, _ := tc["id"].(string)
|
||||
if strings.TrimSpace(callID) == "" {
|
||||
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
}
|
||||
name := ""
|
||||
args := "{}"
|
||||
if fn, ok := tc["function"].(map[string]any); ok {
|
||||
if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
|
||||
name = n
|
||||
}
|
||||
if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
|
||||
args = a
|
||||
}
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
"type": "function_call",
|
||||
"call_id": callID,
|
||||
"name": name,
|
||||
"arguments": normalizeJSONString(args),
|
||||
"status": "completed",
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeJSONString validates and compacts a JSON document so it can be
// embedded verbatim in tool-call "arguments" fields.
//
// Behavior:
//   - empty or whitespace-only input yields "{}", so callers always receive
//     a valid JSON object;
//   - valid JSON is compacted with json.Compact, which removes insignificant
//     whitespace WITHOUT re-decoding — the previous Unmarshal/Marshal round
//     trip forced every number through float64, corrupting integers above
//     2^53 and reordering object keys;
//   - input that is not valid JSON is returned unchanged, leaving the
//     caller's payload intact rather than failing.
func normalizeJSONString(raw string) string {
	s := strings.TrimSpace(raw)
	if s == "" {
		return "{}"
	}
	var compact bytes.Buffer
	if err := json.Compact(&compact, []byte(s)); err != nil {
		// Not valid JSON: pass the original through untouched.
		return raw
	}
	return compact.String()
}
|
||||
106
internal/format/openai/render_stream_events.go
Normal file
106
internal/format/openai/render_stream_events.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package openai
|
||||
|
||||
// BuildResponsesCreatedPayload emits the initial "response.created" SSE
// event announcing an in-progress response for the given model.
func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
	ev := make(map[string]any, 6)
	ev["type"] = "response.created"
	ev["id"] = responseID
	ev["response_id"] = responseID
	ev["object"] = "response"
	ev["model"] = model
	ev["status"] = "in_progress"
	return ev
}
|
||||
|
||||
// BuildResponsesTextDeltaPayload emits a "response.output_text.delta" SSE
// event carrying one increment of assistant-visible output text.
func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
	ev := make(map[string]any, 4)
	ev["type"] = "response.output_text.delta"
	ev["id"] = responseID
	ev["response_id"] = responseID
	ev["delta"] = delta
	return ev
}
|
||||
|
||||
// BuildResponsesReasoningDeltaPayload emits a "response.reasoning.delta"
// SSE event carrying one increment of reasoning text.
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
	ev := make(map[string]any, 4)
	ev["type"] = "response.reasoning.delta"
	ev["id"] = responseID
	ev["response_id"] = responseID
	ev["delta"] = delta
	return ev
}
|
||||
|
||||
// BuildResponsesReasoningTextDeltaPayload emits a
// "response.reasoning_text.delta" SSE event addressed to one content part
// (contentIndex) of one output item (outputIndex).
func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
	ev := make(map[string]any, 7)
	ev["type"] = "response.reasoning_text.delta"
	ev["id"] = responseID
	ev["response_id"] = responseID
	ev["item_id"] = itemID
	ev["output_index"] = outputIndex
	ev["content_index"] = contentIndex
	ev["delta"] = delta
	return ev
}
|
||||
|
||||
// BuildResponsesReasoningTextDonePayload emits the terminal
// "response.reasoning_text.done" SSE event carrying the complete reasoning
// text for one content part (contentIndex) of one output item (outputIndex).
func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
	return map[string]any{
		"type":          "response.reasoning_text.done",
		"id":            responseID,
		"response_id":   responseID,
		"item_id":       itemID,
		"output_index":  outputIndex,
		"content_index": contentIndex,
		"text":          text,
	}
}
|
||||
|
||||
// BuildResponsesToolCallDeltaPayload emits a "response.output_tool_call.delta"
// SSE event carrying an incremental tool-call update.
func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
	return map[string]any{
		"type":        "response.output_tool_call.delta",
		"id":          responseID,
		"response_id": responseID,
		"tool_calls":  toolCalls,
	}
}
|
||||
|
||||
// BuildResponsesToolCallDonePayload emits the terminal
// "response.output_tool_call.done" SSE event with the finished tool calls.
func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
	return map[string]any{
		"type":        "response.output_tool_call.done",
		"id":          responseID,
		"response_id": responseID,
		"tool_calls":  toolCalls,
	}
}
|
||||
|
||||
// BuildResponsesFunctionCallArgumentsDeltaPayload emits a
// "response.function_call_arguments.delta" SSE event carrying one increment
// of a function call's argument JSON.
func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
	return map[string]any{
		"type":         "response.function_call_arguments.delta",
		"id":           responseID,
		"response_id":  responseID,
		"item_id":      itemID,
		"output_index": outputIndex,
		"call_id":      callID,
		"delta":        delta,
	}
}
|
||||
|
||||
// BuildResponsesFunctionCallArgumentsDonePayload emits the terminal
// "response.function_call_arguments.done" SSE event. The accumulated
// arguments string is normalized via normalizeJSONString before emission.
func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any {
	return map[string]any{
		"type":         "response.function_call_arguments.done",
		"id":           responseID,
		"response_id":  responseID,
		"item_id":      itemID,
		"output_index": outputIndex,
		"call_id":      callID,
		"name":         name,
		"arguments":    normalizeJSONString(arguments),
	}
}
|
||||
|
||||
// BuildResponsesCompletedPayload wraps a finished response object in the
// terminal "response.completed" SSE event. response_id is lifted from the
// response's "id" field and is empty when absent or non-string.
func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
	responseID, _ := response["id"].(string)
	return map[string]any{
		"type":        "response.completed",
		"response_id": responseID,
		"response":    response,
	}
}
|
||||
28
internal/format/openai/render_usage.go
Normal file
28
internal/format/openai/render_usage.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package openai
|
||||
|
||||
import "ds2api/internal/util"
|
||||
|
||||
func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
return map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
return map[string]any{
|
||||
"input_tokens": promptTokens,
|
||||
"output_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -125,72 +124,6 @@ func (r *Runner) caseStreamAbortRelease(ctx context.Context, cc *caseContext) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cc *caseContext) abortStreamRequest(ctx context.Context, spec requestSpec) error {
|
||||
cc.seq++
|
||||
traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), cc.seq)
|
||||
cc.traceIDsSet[traceID] = struct{}{}
|
||||
fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
headers := map[string]string{}
|
||||
for k, v := range spec.Headers {
|
||||
headers[k] = v
|
||||
}
|
||||
headers["X-Ds2-Test-Trace"] = traceID
|
||||
bodyBytes, _ := json.Marshal(spec.Body)
|
||||
headers["Content-Type"] = "application/json"
|
||||
cc.requests = append(cc.requests, requestLog{
|
||||
Seq: cc.seq,
|
||||
Attempt: 1,
|
||||
TraceID: traceID,
|
||||
Method: spec.Method,
|
||||
URL: fullURL,
|
||||
Headers: headers,
|
||||
Body: spec.Body,
|
||||
Timestamp: time.Now().Format(time.RFC3339Nano),
|
||||
})
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
start := time.Now()
|
||||
resp, err := cc.runner.httpClient.Do(req)
|
||||
if err != nil {
|
||||
cc.responses = append(cc.responses, responseLog{
|
||||
Seq: cc.seq,
|
||||
Attempt: 1,
|
||||
TraceID: traceID,
|
||||
StatusCode: 0,
|
||||
DurationMS: time.Since(start).Milliseconds(),
|
||||
NetworkErr: err.Error(),
|
||||
ReceivedAt: time.Now().Format(time.RFC3339Nano),
|
||||
})
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
buf := make([]byte, 512)
|
||||
_, _ = resp.Body.Read(buf)
|
||||
_ = resp.Body.Close()
|
||||
cc.responses = append(cc.responses, responseLog{
|
||||
Seq: cc.seq,
|
||||
Attempt: 1,
|
||||
TraceID: traceID,
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: resp.Header,
|
||||
BodyText: "aborted_after_first_chunk",
|
||||
DurationMS: time.Since(start).Milliseconds(),
|
||||
ReceivedAt: time.Now().Format(time.RFC3339Nano),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseToolcallStreamMixed(ctx context.Context, cc *caseContext) error {
|
||||
payload := toolcallPayload(true)
|
||||
payload["messages"] = []map[string]any{
|
||||
@@ -293,167 +226,6 @@ func (r *Runner) caseSSEJSONIntegrity(ctx context.Context, cc *caseContext) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseInvalidModel(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.requestOnce(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/chat/completions",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "deepseek-not-exists",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "hi"},
|
||||
},
|
||||
"stream": false,
|
||||
},
|
||||
Retryable: false,
|
||||
}, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_503", resp.StatusCode == http.StatusServiceUnavailable, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
e, _ := m["error"].(map[string]any)
|
||||
cc.assert("error_type_service_unavailable", asString(e["type"]) == "service_unavailable_error", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseMissingMessages(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/chat/completions",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"stream": false,
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_400", resp.StatusCode == http.StatusBadRequest, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
e, _ := m["error"].(map[string]any)
|
||||
cc.assert("error_type_invalid_request", asString(e["type"]) == "invalid_request_error", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseAdminUnauthorized(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodGet,
|
||||
Path: "/admin/config",
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseTokenRefreshManagedAccount(ctx context.Context, cc *caseContext) error {
|
||||
if len(r.configRaw.Accounts) == 0 {
|
||||
cc.assert("account_present", false, "no account in config")
|
||||
return nil
|
||||
}
|
||||
acc := r.configRaw.Accounts[0]
|
||||
id := strings.TrimSpace(acc.Email)
|
||||
if id == "" {
|
||||
id = strings.TrimSpace(acc.Mobile)
|
||||
}
|
||||
if id == "" {
|
||||
cc.assert("account_identifier", false, "first account has no identifier")
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(acc.Password) == "" {
|
||||
r.warnings = append(r.warnings, "token refresh edge case skipped strict check: first account password empty")
|
||||
cc.assert("account_password_present", true, "skipped strict refresh check due empty password")
|
||||
return nil
|
||||
}
|
||||
invalidToken := "invalid-testsuite-refresh-token-" + sanitizeID(r.runID)
|
||||
update := map[string]any{
|
||||
"keys": r.configRaw.Keys,
|
||||
"accounts": []map[string]any{
|
||||
{
|
||||
"email": acc.Email,
|
||||
"mobile": acc.Mobile,
|
||||
"password": acc.Password,
|
||||
"token": invalidToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
updResp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/admin/config",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.adminJWT,
|
||||
},
|
||||
Body: update,
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("update_config_status_200", updResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", updResp.StatusCode))
|
||||
|
||||
chatResp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/chat/completions",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
"X-Ds2-Target-Account": id,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "token refresh test"},
|
||||
},
|
||||
"stream": false,
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("chat_status_200", chatResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d body=%s", chatResp.StatusCode, string(chatResp.Body)))
|
||||
|
||||
cfgResp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodGet,
|
||||
Path: "/admin/config",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.adminJWT,
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var cfg map[string]any
|
||||
_ = json.Unmarshal(cfgResp.Body, &cfg)
|
||||
accounts, _ := cfg["accounts"].([]any)
|
||||
preview := ""
|
||||
hasToken := false
|
||||
for _, item := range accounts {
|
||||
m, _ := item.(map[string]any)
|
||||
e := asString(m["email"])
|
||||
mo := asString(m["mobile"])
|
||||
if e == acc.Email && mo == acc.Mobile {
|
||||
preview = asString(m["token_preview"])
|
||||
hasToken, _ = m["has_token"].(bool)
|
||||
break
|
||||
}
|
||||
}
|
||||
cc.assert("has_token_after_refresh", hasToken, fmt.Sprintf("config=%s", string(cfgResp.Body)))
|
||||
cc.assert("token_preview_changed_from_invalid", !strings.HasPrefix(preview, invalidToken[:20]), fmt.Sprintf("preview=%s invalid_prefix=%s", preview, invalidToken[:20]))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) fetchQueueStatus(ctx context.Context, cc *caseContext) (map[string]any, error) {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodGet,
|
||||
|
||||
76
internal/testsuite/edge_cases_abort.go
Normal file
76
internal/testsuite/edge_cases_abort.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// abortStreamRequest issues one streaming request, reads at most the first
// 512 bytes of the body, then closes the body to simulate a client abort.
// The request and its truncated response are recorded in the case logs; the
// response body text is logged as the sentinel "aborted_after_first_chunk".
func (cc *caseContext) abortStreamRequest(ctx context.Context, spec requestSpec) error {
	cc.seq++
	// Unique trace ID so server-side logs can be correlated with this case.
	traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), cc.seq)
	cc.traceIDsSet[traceID] = struct{}{}
	fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID)
	if err != nil {
		return err
	}
	// Copy the spec headers so the logged map is not aliased to the spec.
	headers := map[string]string{}
	for k, v := range spec.Headers {
		headers[k] = v
	}
	headers["X-Ds2-Test-Trace"] = traceID
	bodyBytes, _ := json.Marshal(spec.Body)
	headers["Content-Type"] = "application/json"
	cc.requests = append(cc.requests, requestLog{
		Seq:       cc.seq,
		Attempt:   1,
		TraceID:   traceID,
		Method:    spec.Method,
		URL:       fullURL,
		Headers:   headers,
		Body:      spec.Body,
		Timestamp: time.Now().Format(time.RFC3339Nano),
	})

	reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout)
	defer cancel()
	req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes))
	if err != nil {
		return err
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	start := time.Now()
	resp, err := cc.runner.httpClient.Do(req)
	if err != nil {
		// Network-level failure: log with StatusCode 0 and the error text.
		cc.responses = append(cc.responses, responseLog{
			Seq:        cc.seq,
			Attempt:    1,
			TraceID:    traceID,
			StatusCode: 0,
			DurationMS: time.Since(start).Milliseconds(),
			NetworkErr: err.Error(),
			ReceivedAt: time.Now().Format(time.RFC3339Nano),
		})
		return err
	}
	defer resp.Body.Close()
	// Read only the first chunk (up to 512 bytes); the read error is ignored
	// on purpose — any bytes-or-error outcome counts as "stream started".
	buf := make([]byte, 512)
	_, _ = resp.Body.Read(buf)
	// NOTE(review): this explicit Close duplicates the deferred Close above.
	// The early Close is what aborts the stream; the defer is then a no-op
	// double Close (tolerated by net/http, but one of the two could go).
	_ = resp.Body.Close()
	cc.responses = append(cc.responses, responseLog{
		Seq:        cc.seq,
		Attempt:    1,
		TraceID:    traceID,
		StatusCode: resp.StatusCode,
		Headers:    resp.Header,
		BodyText:   "aborted_after_first_chunk",
		DurationMS: time.Since(start).Milliseconds(),
		ReceivedAt: time.Now().Format(time.RFC3339Nano),
	})
	return nil
}
|
||||
170
internal/testsuite/edge_cases_error_contract.go
Normal file
170
internal/testsuite/edge_cases_error_contract.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// caseInvalidModel verifies that requesting a non-existent model returns
// HTTP 503 with error.type "service_unavailable_error". It uses requestOnce
// with a single attempt: the outcome is deterministic, so retrying would
// only slow the suite down.
func (r *Runner) caseInvalidModel(ctx context.Context, cc *caseContext) error {
	resp, err := cc.requestOnce(ctx, requestSpec{
		Method: http.MethodPost,
		Path:   "/v1/chat/completions",
		Headers: map[string]string{
			"Authorization": "Bearer " + r.apiKey,
		},
		Body: map[string]any{
			"model": "deepseek-not-exists",
			"messages": []map[string]any{
				{"role": "user", "content": "hi"},
			},
			"stream": false,
		},
		Retryable: false,
	}, 1)
	if err != nil {
		return err
	}
	cc.assert("status_503", resp.StatusCode == http.StatusServiceUnavailable, fmt.Sprintf("status=%d", resp.StatusCode))
	// Best-effort decode: on malformed JSON the assertions below fail with
	// the raw body attached for diagnosis.
	var m map[string]any
	_ = json.Unmarshal(resp.Body, &m)
	e, _ := m["error"].(map[string]any)
	cc.assert("error_type_service_unavailable", asString(e["type"]) == "service_unavailable_error", fmt.Sprintf("body=%s", string(resp.Body)))
	return nil
}
|
||||
|
||||
// caseMissingMessages verifies that a chat request without a "messages"
// field is rejected with HTTP 400 and error.type "invalid_request_error".
func (r *Runner) caseMissingMessages(ctx context.Context, cc *caseContext) error {
	resp, err := cc.request(ctx, requestSpec{
		Method: http.MethodPost,
		Path:   "/v1/chat/completions",
		Headers: map[string]string{
			"Authorization": "Bearer " + r.apiKey,
		},
		Body: map[string]any{
			"model":  "deepseek-chat",
			"stream": false,
		},
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("status_400", resp.StatusCode == http.StatusBadRequest, fmt.Sprintf("status=%d", resp.StatusCode))
	// Best-effort decode; a failed decode surfaces through the assertion text.
	var m map[string]any
	_ = json.Unmarshal(resp.Body, &m)
	e, _ := m["error"].(map[string]any)
	cc.assert("error_type_invalid_request", asString(e["type"]) == "invalid_request_error", fmt.Sprintf("body=%s", string(resp.Body)))
	return nil
}
|
||||
|
||||
// caseAdminUnauthorized verifies that reading /admin/config without any
// Authorization header is rejected with HTTP 401.
func (r *Runner) caseAdminUnauthorized(ctx context.Context, cc *caseContext) error {
	resp, err := cc.request(ctx, requestSpec{
		Method:    http.MethodGet,
		Path:      "/admin/config",
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode))
	return nil
}
|
||||
|
||||
// caseTokenRefreshManagedAccount exercises automatic token refresh: it
// plants a deliberately invalid token on the first configured account,
// issues a chat request pinned to that account (which should force a
// re-login), then reads the config back and asserts the account holds a
// token whose preview no longer matches the planted invalid one.
// Preconditions (first account present, with an identifier and a password)
// are reported as failed/ skipped assertions rather than errors.
func (r *Runner) caseTokenRefreshManagedAccount(ctx context.Context, cc *caseContext) error {
	if len(r.configRaw.Accounts) == 0 {
		cc.assert("account_present", false, "no account in config")
		return nil
	}
	acc := r.configRaw.Accounts[0]
	// Identifier preference: email first, then mobile.
	id := strings.TrimSpace(acc.Email)
	if id == "" {
		id = strings.TrimSpace(acc.Mobile)
	}
	if id == "" {
		cc.assert("account_identifier", false, "first account has no identifier")
		return nil
	}
	if strings.TrimSpace(acc.Password) == "" {
		// Without a password the service cannot re-login, so the strict
		// refresh check is skipped (recorded as a warning, not a failure).
		r.warnings = append(r.warnings, "token refresh edge case skipped strict check: first account password empty")
		cc.assert("account_password_present", true, "skipped strict refresh check due empty password")
		return nil
	}
	invalidToken := "invalid-testsuite-refresh-token-" + sanitizeID(r.runID)
	update := map[string]any{
		"keys": r.configRaw.Keys,
		"accounts": []map[string]any{
			{
				"email":    acc.Email,
				"mobile":   acc.Mobile,
				"password": acc.Password,
				"token":    invalidToken,
			},
		},
	}
	// Step 1: overwrite the account's token with the invalid sentinel.
	updResp, err := cc.request(ctx, requestSpec{
		Method: http.MethodPost,
		Path:   "/admin/config",
		Headers: map[string]string{
			"Authorization": "Bearer " + r.adminJWT,
		},
		Body:      update,
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("update_config_status_200", updResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", updResp.StatusCode))

	// Step 2: a chat request targeted at this account should trigger a
	// token refresh and still succeed.
	chatResp, err := cc.request(ctx, requestSpec{
		Method: http.MethodPost,
		Path:   "/v1/chat/completions",
		Headers: map[string]string{
			"Authorization":        "Bearer " + r.apiKey,
			"X-Ds2-Target-Account": id,
		},
		Body: map[string]any{
			"model": "deepseek-chat",
			"messages": []map[string]any{
				{"role": "user", "content": "token refresh test"},
			},
			"stream": false,
		},
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("chat_status_200", chatResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d body=%s", chatResp.StatusCode, string(chatResp.Body)))

	// Step 3: re-read the config and locate the account entry.
	cfgResp, err := cc.request(ctx, requestSpec{
		Method: http.MethodGet,
		Path:   "/admin/config",
		Headers: map[string]string{
			"Authorization": "Bearer " + r.adminJWT,
		},
		Retryable: true,
	})
	if err != nil {
		return err
	}
	var cfg map[string]any
	_ = json.Unmarshal(cfgResp.Body, &cfg)
	accounts, _ := cfg["accounts"].([]any)
	preview := ""
	hasToken := false
	for _, item := range accounts {
		m, _ := item.(map[string]any)
		e := asString(m["email"])
		mo := asString(m["mobile"])
		if e == acc.Email && mo == acc.Mobile {
			preview = asString(m["token_preview"])
			hasToken, _ = m["has_token"].(bool)
			break
		}
	}
	cc.assert("has_token_after_refresh", hasToken, fmt.Sprintf("config=%s", string(cfgResp.Body)))
	// The invalid sentinel is always longer than 20 chars, so the slice is safe.
	cc.assert("token_preview_changed_from_invalid", !strings.HasPrefix(preview, invalidToken[:20]), fmt.Sprintf("preview=%s invalid_prefix=%s", preview, invalidToken[:20]))
	return nil
}
|
||||
File diff suppressed because it is too large
Load Diff
161
internal/testsuite/runner_cases_admin.go
Normal file
161
internal/testsuite/runner_cases_admin.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// caseAdminLoginVerify logs into /admin/login with the configured admin key,
// asserts a JWT is returned, then calls /admin/verify with that token and
// asserts the server reports it valid. If no token comes back, the verify
// step is skipped (the earlier assertion already records the failure).
func (r *Runner) caseAdminLoginVerify(ctx context.Context, cc *caseContext) error {
	loginResp, err := cc.request(ctx, requestSpec{
		Method:    http.MethodPost,
		Path:      "/admin/login",
		Body:      map[string]any{"admin_key": r.adminKey, "expire_hours": 24},
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("login_status_200", loginResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", loginResp.StatusCode))
	var payload map[string]any
	_ = json.Unmarshal(loginResp.Body, &payload)
	token := asString(payload["token"])
	cc.assert("token_exists", token != "", fmt.Sprintf("body=%s", string(loginResp.Body)))
	if token == "" {
		return nil
	}
	verifyResp, err := cc.request(ctx, requestSpec{
		Method: http.MethodGet,
		Path:   "/admin/verify",
		Headers: map[string]string{
			"Authorization": "Bearer " + token,
		},
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("verify_status_200", verifyResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", verifyResp.StatusCode))
	var v map[string]any
	_ = json.Unmarshal(verifyResp.Body, &v)
	valid, _ := v["valid"].(bool)
	cc.assert("verify_valid_true", valid, fmt.Sprintf("body=%s", string(verifyResp.Body)))
	return nil
}
|
||||
|
||||
// caseAdminQueueStatus reads /admin/queue/status with the admin JWT and
// asserts the payload exposes the "recommended_concurrency" and
// "max_queue_size" fields.
func (r *Runner) caseAdminQueueStatus(ctx context.Context, cc *caseContext) error {
	resp, err := cc.request(ctx, requestSpec{
		Method: http.MethodGet,
		Path:   "/admin/queue/status",
		Headers: map[string]string{
			"Authorization": "Bearer " + r.adminJWT,
		},
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
	// Only key presence is checked; the values depend on runtime config.
	var m map[string]any
	_ = json.Unmarshal(resp.Body, &m)
	_, hasRec := m["recommended_concurrency"]
	_, hasQueue := m["max_queue_size"]
	cc.assert("has_recommended_concurrency", hasRec, fmt.Sprintf("body=%s", string(resp.Body)))
	cc.assert("has_max_queue_size", hasQueue, fmt.Sprintf("body=%s", string(resp.Body)))
	return nil
}
|
||||
// caseAdminAccountTest sends a "ping" through /admin/accounts/test for the
// runner's configured account and asserts the endpoint returns 200 with
// success == true. When no account is configured, the precondition is
// recorded as a failed assertion instead of an error.
func (r *Runner) caseAdminAccountTest(ctx context.Context, cc *caseContext) error {
	if strings.TrimSpace(r.accountID) == "" {
		cc.assert("account_present", false, "no account in config")
		return nil
	}
	resp, err := cc.request(ctx, requestSpec{
		Method: http.MethodPost,
		Path:   "/admin/accounts/test",
		Headers: map[string]string{
			"Authorization": "Bearer " + r.adminJWT,
		},
		Body: map[string]any{
			"identifier": r.accountID,
			"model":      "deepseek-chat",
			"message":    "ping",
		},
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
	var m map[string]any
	_ = json.Unmarshal(resp.Body, &m)
	ok, _ := m["success"].(bool)
	cc.assert("success_true", ok, fmt.Sprintf("body=%s", string(resp.Body)))
	return nil
}
|
||||
func (r *Runner) caseConfigWriteIsolated(ctx context.Context, cc *caseContext) error {
|
||||
k := "testsuite-temp-" + sanitizeID(r.runID)
|
||||
add, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/admin/keys",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.adminJWT,
|
||||
},
|
||||
Body: map[string]any{"key": k},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("add_key_status_200", add.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", add.StatusCode))
|
||||
|
||||
cfg1, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodGet,
|
||||
Path: "/admin/config",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.adminJWT,
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
containsAdded := strings.Contains(string(cfg1.Body), k)
|
||||
cc.assert("key_present_in_isolated_config", containsAdded, "added key not found in isolated config")
|
||||
|
||||
delPath := "/admin/keys/" + url.PathEscape(k)
|
||||
del, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodDelete,
|
||||
Path: delPath,
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.adminJWT,
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("delete_key_status_200", del.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", del.StatusCode))
|
||||
|
||||
cfg2, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodGet,
|
||||
Path: "/admin/config",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.adminJWT,
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("key_removed_in_isolated_config", !strings.Contains(string(cfg2.Body), k), "temporary key still present")
|
||||
|
||||
if err := r.ensureOriginalConfigUntouched(); err != nil {
|
||||
cc.assert("original_config_unchanged", false, err.Error())
|
||||
} else {
|
||||
cc.assert("original_config_unchanged", true, "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
103
internal/testsuite/runner_cases_claude.go
Normal file
103
internal/testsuite/runner_cases_claude.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
ids := extractModelIDs(resp.Body)
|
||||
cc.assert("non_empty", len(ids) > 0, fmt.Sprintf("models=%v", ids))
|
||||
return nil
|
||||
}
|
||||
func (r *Runner) caseAnthropicNonstream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/anthropic/v1/messages",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "hello"},
|
||||
},
|
||||
"stream": false,
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("type_message", asString(m["type"]) == "message", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseAnthropicStream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/anthropic/v1/messages",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "stream hello"},
|
||||
},
|
||||
"stream": true,
|
||||
},
|
||||
Stream: true,
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
events := parseClaudeStreamEvents(resp.Body)
|
||||
cc.assert("has_message_start", contains(events, "message_start"), fmt.Sprintf("events=%v", events))
|
||||
cc.assert("has_message_stop", contains(events, "message_stop"), fmt.Sprintf("events=%v", events))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseAnthropicCountTokens(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/anthropic/v1/messages/count_tokens",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "count me"},
|
||||
},
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
v := toInt(m["input_tokens"])
|
||||
cc.assert("input_tokens_gt_zero", v > 0, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
221
internal/testsuite/runner_cases_openai.go
Normal file
221
internal/testsuite/runner_cases_openai.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/healthz", Retryable: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/readyz", Retryable: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models", Retryable: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
ids := extractModelIDs(resp.Body)
|
||||
cc.assert("has_deepseek_chat", contains(ids, "deepseek-chat"), strings.Join(ids, ","))
|
||||
cc.assert("has_deepseek_reasoner", contains(ids, "deepseek-reasoner"), strings.Join(ids, ","))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("object_model", asString(m["object"]) == "model", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
cc.assert("id_deepseek_chat", asString(m["id"]) == "deepseek-chat", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
func (r *Runner) caseChatNonstream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/chat/completions",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "请简单回复一句话"},
|
||||
},
|
||||
"stream": false,
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("object_chat_completion", asString(m["object"]) == "chat.completion", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
choices, _ := m["choices"].([]any)
|
||||
cc.assert("choices_non_empty", len(choices) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/chat/completions",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "请流式回复一句话"},
|
||||
},
|
||||
"stream": true,
|
||||
},
|
||||
Stream: true,
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
frames, done := parseSSEFrames(resp.Body)
|
||||
cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames)))
|
||||
cc.assert("done_terminated", done, "expected [DONE]")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/responses",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "请简要回答 hello",
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("object_response", asString(m["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
responseID := asString(m["id"])
|
||||
cc.assert("response_id_present", responseID != "", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
if responseID != "" {
|
||||
getResp, getErr := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodGet,
|
||||
Path: "/v1/responses/" + responseID,
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if getErr != nil {
|
||||
return getErr
|
||||
}
|
||||
cc.assert("get_status_200", getResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", getResp.StatusCode))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/responses",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "请流式回答 hello",
|
||||
"stream": true,
|
||||
},
|
||||
Stream: true,
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
frames, done := parseSSEFrames(resp.Body)
|
||||
cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames)))
|
||||
hasCreated := false
|
||||
hasCompleted := false
|
||||
for _, f := range frames {
|
||||
switch asString(f["type"]) {
|
||||
case "response.created":
|
||||
hasCreated = true
|
||||
case "response.completed":
|
||||
hasCompleted = true
|
||||
}
|
||||
}
|
||||
cc.assert("has_response_created", hasCreated, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
cc.assert("has_response_completed", hasCompleted, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
cc.assert("done_terminated", done, "expected [DONE]")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/embeddings",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": []string{"hello", "world"},
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200_or_501", resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
cc.assert("object_list", asString(m["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
data, _ := m["data"].([]any)
|
||||
cc.assert("data_non_empty", len(data) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
errObj, _ := m["error"].(map[string]any)
|
||||
_, hasCode := errObj["code"]
|
||||
_, hasParam := errObj["param"]
|
||||
cc.assert("error_has_code", hasCode, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
236
internal/testsuite/runner_cases_openai_advanced.go
Normal file
236
internal/testsuite/runner_cases_openai_advanced.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/chat/completions",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "deepseek-reasoner",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "先思考后回答:1+1"},
|
||||
},
|
||||
"stream": true,
|
||||
},
|
||||
Stream: true,
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
frames, done := parseSSEFrames(resp.Body)
|
||||
hasReasoning := false
|
||||
for _, f := range frames {
|
||||
choices, _ := f["choices"].([]any)
|
||||
for _, c := range choices {
|
||||
ch, _ := c.(map[string]any)
|
||||
delta, _ := ch["delta"].(map[string]any)
|
||||
if asString(delta["reasoning_content"]) != "" {
|
||||
hasReasoning = true
|
||||
}
|
||||
}
|
||||
}
|
||||
cc.assert("has_reasoning_content", hasReasoning, "reasoning_content not found")
|
||||
cc.assert("done_terminated", done, "expected [DONE]")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseToolcallNonstream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/chat/completions",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: toolcallPayload(false),
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
choices, _ := m["choices"].([]any)
|
||||
if len(choices) == 0 {
|
||||
cc.assert("choices_non_empty", false, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
c0, _ := choices[0].(map[string]any)
|
||||
cc.assert("finish_reason_tool_calls", asString(c0["finish_reason"]) == "tool_calls", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
msg, _ := c0["message"].(map[string]any)
|
||||
tc, _ := msg["tool_calls"].([]any)
|
||||
cc.assert("tool_calls_present", len(tc) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseToolcallStream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/chat/completions",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: toolcallPayload(true),
|
||||
Stream: true,
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
frames, done := parseSSEFrames(resp.Body)
|
||||
hasTool := false
|
||||
rawLeak := false
|
||||
for _, f := range frames {
|
||||
choices, _ := f["choices"].([]any)
|
||||
for _, c := range choices {
|
||||
ch, _ := c.(map[string]any)
|
||||
delta, _ := ch["delta"].(map[string]any)
|
||||
if _, ok := delta["tool_calls"]; ok {
|
||||
hasTool = true
|
||||
}
|
||||
content := asString(delta["content"])
|
||||
if strings.Contains(strings.ToLower(content), `"tool_calls"`) {
|
||||
rawLeak = true
|
||||
}
|
||||
}
|
||||
}
|
||||
cc.assert("tool_calls_delta_present", hasTool, "tool_calls delta missing")
|
||||
cc.assert("no_raw_tool_json_leak", !rawLeak, "raw tool_calls JSON leaked in content")
|
||||
cc.assert("done_terminated", done, "expected [DONE]")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseConcurrencyBurst(ctx context.Context, cc *caseContext) error {
|
||||
accountCount := len(r.configRaw.Accounts)
|
||||
n := accountCount*2 + 2
|
||||
if n < 2 {
|
||||
n = 2
|
||||
}
|
||||
type one struct {
|
||||
Status int
|
||||
Err string
|
||||
}
|
||||
results := make([]one, n)
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < n; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/chat/completions",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": fmt.Sprintf("并发请求 #%d,请回复ok", idx)},
|
||||
},
|
||||
"stream": true,
|
||||
},
|
||||
Stream: true,
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
results[idx] = one{Err: err.Error()}
|
||||
return
|
||||
}
|
||||
results[idx] = one{Status: resp.StatusCode}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
dist := map[int]int{}
|
||||
success := 0
|
||||
for _, it := range results {
|
||||
if it.Status > 0 {
|
||||
dist[it.Status]++
|
||||
if it.Status == http.StatusOK {
|
||||
success++
|
||||
}
|
||||
}
|
||||
}
|
||||
cc.assert("success_gt_zero", success > 0, fmt.Sprintf("distribution=%v", dist))
|
||||
_, has5xx := has5xx(dist)
|
||||
cc.assert("no_5xx", !has5xx, fmt.Sprintf("distribution=%v", dist))
|
||||
if err := r.ping("/healthz"); err != nil {
|
||||
cc.assert("server_alive", false, err.Error())
|
||||
} else {
|
||||
cc.assert("server_alive", true, "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseInvalidKey(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/chat/completions",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer invalid-testsuite-key-" + sanitizeID(r.runID),
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "deepseek-chat",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "hi"},
|
||||
},
|
||||
"stream": false,
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
e, _ := m["error"].(map[string]any)
|
||||
cc.assert("error_object_present", len(e) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
cc.assert("error_message_present", asString(e["message"]) != "", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// toolcallPayload builds a chat-completions request body that instructs the
// model to invoke the "search" tool; stream toggles SSE streaming.
func toolcallPayload(stream bool) map[string]any {
	searchTool := map[string]any{
		"type": "function",
		"function": map[string]any{
			"name":        "search",
			"description": "search documents",
			"parameters": map[string]any{
				"type": "object",
				"properties": map[string]any{
					"q": map[string]any{
						"type": "string",
					},
				},
				"required": []string{"q"},
			},
		},
	}
	return map[string]any{
		"model": "deepseek-chat",
		"messages": []map[string]any{
			{
				"role":    "user",
				"content": "你必须调用工具 search 查询 golang,并仅返回工具调用。",
			},
		},
		"tools":  []map[string]any{searchTool},
		"stream": stream,
	}
}
|
||||
290
internal/testsuite/runner_core.go
Normal file
290
internal/testsuite/runner_core.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Options configures a testsuite run.
type Options struct {
	ConfigPath  string        // path to the config file copied for the isolated run
	AdminKey    string        // admin key used to obtain an admin JWT; falls back to DS2API_ADMIN_KEY then "admin"
	OutputDir   string        // root directory under which per-run artifact directories are created
	Port        int           // server port; 0 presumably lets the runner pick one — TODO confirm
	Timeout     time.Duration // overall timeout budget (defaults to 120s when <= 0)
	Retries     int           // retry count for retryable requests (negative is clamped to 0)
	NoPreflight bool          // skip the preflight build/test steps when true
	MaxKeepRuns int           // how many recent run directories to keep; <= 0 disables pruning
}
|
||||
|
||||
// runSummary is the JSON document describing a whole testsuite run
// (timestamps, aggregate stats, per-case results, and warnings).
type runSummary struct {
	RunID       string         `json:"run_id"`
	StartedAt   string         `json:"started_at"`
	EndedAt     string         `json:"ended_at"`
	DurationMS  int64          `json:"duration_ms"`
	Stats       map[string]any `json:"stats"`
	Environment map[string]any `json:"environment"`
	Cases       []caseResult   `json:"cases"`
	Warnings    []string       `json:"warnings,omitempty"`
}

// caseResult records the outcome of a single test case, including the trace
// IDs and HTTP status codes observed and the path to its artifacts.
type caseResult struct {
	CaseID       string            `json:"case_id"`
	Passed       bool              `json:"passed"`
	DurationMS   int64             `json:"duration_ms"`
	TraceIDs     []string          `json:"trace_ids"`
	StatusCodes  []int             `json:"status_codes"`
	Error        string            `json:"error,omitempty"`
	ArtifactPath string            `json:"artifact_path"`
	Assertions   []assertionResult `json:"assertions"`
}

// assertionResult is one named pass/fail check recorded inside a case.
type assertionResult struct {
	Name   string `json:"name"`
	Passed bool   `json:"passed"`
	Detail string `json:"detail,omitempty"`
}
|
||||
|
||||
// requestLog is the JSON artifact describing one outgoing request attempt.
type requestLog struct {
	Seq       int               `json:"seq"`     // per-case request sequence number
	Attempt   int               `json:"attempt"` // retry attempt counter — base index not visible here
	TraceID   string            `json:"trace_id"`
	Method    string            `json:"method"`
	URL       string            `json:"url"`
	Headers   map[string]string `json:"headers"`
	Body      any               `json:"body,omitempty"`
	Timestamp string            `json:"timestamp"`
}

// responseLog is the JSON artifact describing the response (or network error)
// paired with a requestLog entry by Seq/Attempt.
type responseLog struct {
	Seq        int                 `json:"seq"`
	Attempt    int                 `json:"attempt"`
	TraceID    string              `json:"trace_id"`
	StatusCode int                 `json:"status_code"`
	Headers    map[string][]string `json:"headers"`
	BodyText   string              `json:"body_text"`
	DurationMS int64               `json:"duration_ms"`
	NetworkErr string              `json:"network_error,omitempty"` // set when the request failed before an HTTP response
	ReceivedAt string              `json:"received_at"`
}
|
||||
|
||||
// caseContext carries the mutable per-case state: request sequencing, the
// collected request/response logs, recorded assertions, and the raw stream
// transcript.
type caseContext struct {
	runner    *Runner
	id        string // case identifier (also the artifact subdirectory name)
	dir       string // artifact directory for this case
	startedAt time.Time
	mu        sync.Mutex // presumably guards seq/logs/assertions against concurrent requests — confirm in cc.request
	seq       int
	assertions  []assertionResult
	requests    []requestLog
	responses   []responseLog
	streamRaw   strings.Builder     // accumulated raw streamed bytes — presumably the SSE transcript
	traceIDsSet map[string]struct{} // set of unique trace IDs observed across requests
}

// requestSpec describes one HTTP request a case wants to send.
type requestSpec struct {
	Method    string
	Path      string
	Headers   map[string]string
	Body      any  // JSON-encoded when non-nil — TODO confirm encoding in cc.request
	Stream    bool // expect an SSE/streaming response body
	Retryable bool // allow the runner's retry policy to re-send on failure
}

// responseResult is the normalized response handed back to case code.
type responseResult struct {
	StatusCode int
	Headers    http.Header
	Body       []byte
	TraceID    string
	URL        string
}
|
||||
|
||||
// Runner drives a single testsuite run: it owns the spawned server process,
// the isolated config copy, auth material, and the accumulated results.
type Runner struct {
	opts Options

	runID        string // timestamp-formatted identifier ("20060102T150405Z")
	runDir       string // artifact directory for this run
	serverLog    string // path of the captured server log
	preflightLog string // path of the preflight command log

	baseURL     string
	httpClient  *http.Client
	serverCmd   *exec.Cmd
	serverLogFd *os.File

	configCopyPath     string // isolated copy of the config used by the spawned server
	originalConfigPath string
	originalConfigHash string // hash of the original config, used to detect accidental mutation

	configRaw runConfig
	apiKey    string
	adminKey  string
	adminJWT  string
	accountID string

	warnings []string
	results  []caseResult
}

// runConfig mirrors the subset of the config file the testsuite reads: the API
// keys and the account credentials.
type runConfig struct {
	Keys     []string `json:"keys"`
	Accounts []struct {
		Email    string `json:"email,omitempty"`
		Mobile   string `json:"mobile,omitempty"`
		Password string `json:"password,omitempty"`
		Token    string `json:"token,omitempty"`
	} `json:"accounts"`
}
|
||||
|
||||
// Run executes the full testsuite: preflight checks, config isolation, server
// startup, every registered case, summary writing, and pruning of old run
// directories. It returns an error when setup fails or any case fails.
func Run(ctx context.Context, opts Options) error {
	r, err := newRunner(opts)
	if err != nil {
		return err
	}
	start := time.Now()
	// Always tear the spawned server down, even on the early returns below.
	defer func() {
		_ = r.stopServer()
	}()

	if err := r.prepareRunDir(); err != nil {
		return err
	}

	// A summary is still written on preflight failure so the run leaves a record.
	if !r.opts.NoPreflight {
		if err := r.runPreflight(ctx); err != nil {
			_ = r.writeSummary(start, time.Now())
			return err
		}
	}

	if err := r.prepareConfigIsolation(); err != nil {
		_ = r.writeSummary(start, time.Now())
		return err
	}

	if err := r.startServer(ctx); err != nil {
		_ = r.writeSummary(start, time.Now())
		return err
	}

	// Auth failure is non-fatal here: individual cases will surface it.
	if err := r.prepareAuth(ctx); err != nil {
		r.warnings = append(r.warnings, "auth prepare failed: "+err.Error())
	}

	for _, c := range r.cases() {
		r.runCase(ctx, c)
	}

	// The isolated config work must not have leaked into the original file.
	if err := r.ensureOriginalConfigUntouched(); err != nil {
		r.warnings = append(r.warnings, err.Error())
	}

	end := time.Now()
	if err := r.writeSummary(start, end); err != nil {
		return err
	}

	// Prune old test runs, keeping only the most recent N.
	// NOTE(review): this warning is appended after writeSummary, so it never
	// reaches the written summary — confirm that is intended.
	if err := r.pruneOldRuns(); err != nil {
		r.warnings = append(r.warnings, "prune old runs: "+err.Error())
	}

	failed := 0
	for _, cs := range r.results {
		if !cs.Passed {
			failed++
		}
	}
	if failed > 0 {
		return fmt.Errorf("testsuite failed: %d case(s) failed, see %s", failed, filepath.Join(r.runDir, "summary.md"))
	}
	return nil
}
|
||||
|
||||
func newRunner(opts Options) (*Runner, error) {
|
||||
if strings.TrimSpace(opts.ConfigPath) == "" {
|
||||
opts.ConfigPath = "config.json"
|
||||
}
|
||||
if strings.TrimSpace(opts.OutputDir) == "" {
|
||||
opts.OutputDir = "artifacts/testsuite"
|
||||
}
|
||||
if opts.Timeout <= 0 {
|
||||
opts.Timeout = 120 * time.Second
|
||||
}
|
||||
if opts.Retries < 0 {
|
||||
opts.Retries = 0
|
||||
}
|
||||
adminKey := strings.TrimSpace(opts.AdminKey)
|
||||
if adminKey == "" {
|
||||
adminKey = strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY"))
|
||||
}
|
||||
if adminKey == "" {
|
||||
adminKey = "admin"
|
||||
}
|
||||
opts.AdminKey = adminKey
|
||||
|
||||
return &Runner{
|
||||
opts: opts,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 0,
|
||||
},
|
||||
runID: time.Now().UTC().Format("20060102T150405Z"),
|
||||
adminKey: adminKey,
|
||||
}, nil
|
||||
}
|
||||
// runCase executes a single case definition, collects its assertions and
// request traces, writes the per-case artifacts, and appends the result to
// r.results.
func (r *Runner) runCase(ctx context.Context, c caseDef) {
	caseDir := filepath.Join(r.runDir, "cases", c.ID)
	_ = os.MkdirAll(caseDir, 0o755)
	cc := &caseContext{
		runner:      r,
		id:          c.ID,
		dir:         caseDir,
		startedAt:   time.Now(),
		traceIDsSet: map[string]struct{}{},
	}
	err := c.Run(ctx, cc)
	duration := time.Since(cc.startedAt).Milliseconds()

	// A hard error from the case body is recorded as a synthetic failed assertion.
	if err != nil {
		cc.assertions = append(cc.assertions, assertionResult{
			Name:   "case_error",
			Passed: false,
			Detail: err.Error(),
		})
	}
	// A case passes only when it returned nil AND every assertion passed.
	passed := err == nil
	for _, a := range cc.assertions {
		if !a.Passed {
			passed = false
			break
		}
	}

	// Deduplicated trace IDs, sorted for stable output.
	traceIDs := make([]string, 0, len(cc.traceIDsSet))
	for t := range cc.traceIDsSet {
		traceIDs = append(traceIDs, t)
	}
	sort.Strings(traceIDs)
	statuses := uniqueStatusCodes(cc.responses)
	cs := caseResult{
		CaseID:       c.ID,
		Passed:       passed,
		DurationMS:   duration,
		TraceIDs:     traceIDs,
		StatusCodes:  statuses,
		ArtifactPath: caseDir,
		Assertions:   cc.assertions,
	}
	if err != nil {
		cs.Error = err.Error()
	}
	// Artifact write failures are intentionally ignored: the in-memory result
	// still counts toward the run outcome.
	_ = cc.flushArtifacts(cs)
	r.results = append(r.results, cs)
}
|
||||
20
internal/testsuite/runner_defaults.go
Normal file
20
internal/testsuite/runner_defaults.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DefaultOptions returns the standard testsuite configuration: config.json in
// the working directory, artifacts under artifacts/testsuite, a 120s timeout,
// 2 retries, preflight enabled, and the 5 most recent runs kept.
func DefaultOptions() Options {
	return Options{
		ConfigPath:  "config.json",
		AdminKey:    strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")),
		OutputDir:   "artifacts/testsuite",
		Port:        0,
		Timeout:     120 * time.Second,
		Retries:     2,
		NoPreflight: false,
		MaxKeepRuns: 5,
	}
}
|
||||
261
internal/testsuite/runner_env.go
Normal file
261
internal/testsuite/runner_env.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (r *Runner) prepareRunDir() error {
|
||||
r.runDir = filepath.Join(r.opts.OutputDir, r.runID)
|
||||
if err := os.MkdirAll(r.runDir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Join(r.runDir, "cases"), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
r.serverLog = filepath.Join(r.runDir, "server.log")
|
||||
r.preflightLog = filepath.Join(r.runDir, "preflight.log")
|
||||
return nil
|
||||
}
|
||||
|
||||
// pruneOldRuns removes old test run directories, keeping the most recent MaxKeepRuns.
|
||||
// Run IDs use the format "20060102T150405Z", so alphabetical order == chronological order.
|
||||
func (r *Runner) pruneOldRuns() error {
|
||||
keep := r.opts.MaxKeepRuns
|
||||
if keep <= 0 {
|
||||
return nil // 0 or negative means no pruning
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(r.opts.OutputDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Collect only directories (each run is a directory).
|
||||
var runDirs []string
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
runDirs = append(runDirs, e.Name())
|
||||
}
|
||||
|
||||
sort.Strings(runDirs)
|
||||
|
||||
if len(runDirs) <= keep {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove oldest runs (those at the beginning of the sorted list).
|
||||
toRemove := runDirs[:len(runDirs)-keep]
|
||||
var errs []string
|
||||
for _, name := range toRemove {
|
||||
dirPath := filepath.Join(r.opts.OutputDir, name)
|
||||
if err := os.RemoveAll(dirPath); err != nil {
|
||||
errs = append(errs, fmt.Sprintf("remove %s: %v", name, err))
|
||||
} else {
|
||||
fmt.Fprintf(os.Stdout, "pruned old test run: %s\n", name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errors.New(strings.Join(errs, "; "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) runPreflight(ctx context.Context) error {
|
||||
steps := [][]string{
|
||||
{"go", "test", "./...", "-count=1"},
|
||||
{"node", "--check", "api/chat-stream.js"},
|
||||
{"node", "--check", "api/helpers/stream-tool-sieve.js"},
|
||||
{"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js", "api/compat/js_compat_test.js"},
|
||||
{"npm", "run", "build", "--prefix", "webui"},
|
||||
}
|
||||
f, err := os.OpenFile(r.preflightLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
for _, step := range steps {
|
||||
if _, err := fmt.Fprintf(f, "\n$ %s\n", strings.Join(step, " ")); err != nil {
|
||||
return err
|
||||
}
|
||||
cmd := exec.CommandContext(ctx, step[0], step[1:]...)
|
||||
cmd.Stdout = f
|
||||
cmd.Stderr = f
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("preflight failed at `%s`: %w", strings.Join(step, " "), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) prepareConfigIsolation() error {
|
||||
abs, err := filepath.Abs(r.opts.ConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.originalConfigPath = abs
|
||||
raw, err := os.ReadFile(abs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sum := sha256.Sum256(raw)
|
||||
r.originalConfigHash = hex.EncodeToString(sum[:])
|
||||
|
||||
tmpDir := filepath.Join(r.runDir, "tmp")
|
||||
if err := os.MkdirAll(tmpDir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
r.configCopyPath = filepath.Join(tmpDir, "config.json")
|
||||
if err := os.WriteFile(r.configCopyPath, raw, 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
var cfg runConfig
|
||||
if err := json.Unmarshal(raw, &cfg); err != nil {
|
||||
return fmt.Errorf("parse config failed: %w", err)
|
||||
}
|
||||
r.configRaw = cfg
|
||||
if len(cfg.Keys) > 0 {
|
||||
r.apiKey = strings.TrimSpace(cfg.Keys[0])
|
||||
}
|
||||
for _, acc := range cfg.Accounts {
|
||||
id := strings.TrimSpace(acc.Email)
|
||||
if id == "" {
|
||||
id = strings.TrimSpace(acc.Mobile)
|
||||
}
|
||||
if id != "" {
|
||||
r.accountID = id
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startServer launches the ds2api server via `go run` against the isolated
// config copy and waits (up to 90s) for both health endpoints to respond.
func (r *Runner) startServer(ctx context.Context) error {
	port := r.opts.Port
	if port <= 0 {
		// Port <= 0 in the options means "pick any free loopback port".
		p, err := findFreePort()
		if err != nil {
			return err
		}
		port = p
	}
	r.baseURL = "http://127.0.0.1:" + strconv.Itoa(port)

	logFd, err := os.OpenFile(r.serverLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
	if err != nil {
		return err
	}
	r.serverLogFd = logFd
	cmd := exec.CommandContext(ctx, "go", "run", "./cmd/ds2api")
	cmd.Stdout = logFd
	cmd.Stderr = logFd
	// Point the server at the isolated config copy and blank out any inline
	// config JSON from the parent environment so it cannot leak into the run.
	cmd.Env = prepareServerEnv(os.Environ(), map[string]string{
		"PORT":                    strconv.Itoa(port),
		"DS2API_CONFIG_PATH":      r.configCopyPath,
		"DS2API_AUTO_BUILD_WEBUI": "false",
		"DS2API_CONFIG_JSON":      "",
		"CONFIG_JSON":             "",
	})
	if err := cmd.Start(); err != nil {
		_ = logFd.Close()
		return err
	}
	r.serverCmd = cmd

	// Poll liveness and readiness every 500ms until the deadline.
	deadline := time.Now().Add(90 * time.Second)
	for time.Now().Before(deadline) {
		if r.ping("/healthz") == nil && r.ping("/readyz") == nil {
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return errors.New("server readiness timeout")
}
|
||||
|
||||
func (r *Runner) stopServer() error {
|
||||
var errs []string
|
||||
if r.serverCmd != nil && r.serverCmd.Process != nil {
|
||||
_ = r.serverCmd.Process.Signal(os.Interrupt)
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- r.serverCmd.Wait() }()
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
_ = r.serverCmd.Process.Kill()
|
||||
<-done
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
if r.serverLogFd != nil {
|
||||
if err := r.serverLogFd.Close(); err != nil {
|
||||
errs = append(errs, err.Error())
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errors.New(strings.Join(errs, "; "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) ping(path string) error {
|
||||
resp, err := r.httpClient.Get(r.baseURL + path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("status=%d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) prepareAuth(ctx context.Context) error {
|
||||
reqBody := map[string]any{
|
||||
"admin_key": r.adminKey,
|
||||
"expire_hours": 24,
|
||||
}
|
||||
resp, err := r.doSimpleJSON(ctx, http.MethodPost, "/admin/login", nil, reqBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("admin login status=%d body=%s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(resp.Body, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
token, _ := m["token"].(string)
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return errors.New("empty admin jwt token")
|
||||
}
|
||||
r.adminJWT = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) ensureOriginalConfigUntouched() error {
|
||||
raw, err := os.ReadFile(r.originalConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sum := sha256.Sum256(raw)
|
||||
current := hex.EncodeToString(sum[:])
|
||||
if current != r.originalConfigHash {
|
||||
return fmt.Errorf("original config changed unexpectedly: %s", r.originalConfigPath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
217
internal/testsuite/runner_http.go
Normal file
217
internal/testsuite/runner_http.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (cc *caseContext) assert(name string, ok bool, detail string) {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
cc.assertions = append(cc.assertions, assertionResult{
|
||||
Name: name,
|
||||
Passed: ok,
|
||||
Detail: detail,
|
||||
})
|
||||
}
|
||||
|
||||
func (cc *caseContext) request(ctx context.Context, spec requestSpec) (*responseResult, error) {
|
||||
retries := cc.runner.opts.Retries
|
||||
if !spec.Retryable {
|
||||
retries = 0
|
||||
}
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= retries+1; attempt++ {
|
||||
resp, err := cc.requestOnce(ctx, spec, attempt)
|
||||
if err == nil && resp.StatusCode < 500 {
|
||||
return resp, nil
|
||||
}
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
} else if resp.StatusCode >= 500 {
|
||||
lastErr = fmt.Errorf("status=%d", resp.StatusCode)
|
||||
}
|
||||
if attempt <= retries {
|
||||
sleep := time.Duration(300*(1<<(attempt-1))) * time.Millisecond
|
||||
time.Sleep(sleep)
|
||||
}
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// requestOnce performs a single attempt of spec: it mints a unique trace ID,
// records the outgoing request and the response (including network failures)
// in the case's logs, and — for stream specs — appends the raw body to the
// case's stream capture buffer.
func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attempt int) (*responseResult, error) {
	// Reserve a per-case sequence number and register the derived trace ID.
	cc.mu.Lock()
	cc.seq++
	seq := cc.seq
	traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), seq)
	cc.traceIDsSet[traceID] = struct{}{}
	cc.mu.Unlock()

	fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID)
	if err != nil {
		return nil, err
	}

	// Copy spec headers so the caller's shared map is never mutated.
	headers := map[string]string{}
	for k, v := range spec.Headers {
		headers[k] = v
	}
	headers["X-Ds2-Test-Trace"] = traceID

	var bodyBytes []byte
	var bodyAny any
	if spec.Body != nil {
		b, err := json.Marshal(spec.Body)
		if err != nil {
			return nil, err
		}
		bodyBytes = b
		bodyAny = spec.Body
		headers["Content-Type"] = "application/json"
	}
	// Log the request before sending so even failed sends are recorded.
	cc.mu.Lock()
	cc.requests = append(cc.requests, requestLog{
		Seq:       seq,
		Attempt:   attempt,
		TraceID:   traceID,
		Method:    spec.Method,
		URL:       fullURL,
		Headers:   headers,
		Body:      bodyAny,
		Timestamp: time.Now().Format(time.RFC3339Nano),
	})
	cc.mu.Unlock()

	// Each attempt gets its own timeout derived from the suite options.
	reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout)
	defer cancel()
	req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes))
	if err != nil {
		return nil, err
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	start := time.Now()
	resp, err := cc.runner.httpClient.Do(req)
	if err != nil {
		// Network-level failure: log with StatusCode 0 and the error text.
		cc.mu.Lock()
		cc.responses = append(cc.responses, responseLog{
			Seq:        seq,
			Attempt:    attempt,
			TraceID:    traceID,
			StatusCode: 0,
			DurationMS: time.Since(start).Milliseconds(),
			NetworkErr: err.Error(),
			ReceivedAt: time.Now().Format(time.RFC3339Nano),
		})
		cc.mu.Unlock()
		return nil, err
	}
	defer resp.Body.Close()
	// Best-effort read; a truncated body is still worth logging.
	body, _ := io.ReadAll(resp.Body)

	cc.mu.Lock()
	cc.responses = append(cc.responses, responseLog{
		Seq:        seq,
		Attempt:    attempt,
		TraceID:    traceID,
		StatusCode: resp.StatusCode,
		Headers:    resp.Header,
		BodyText:   string(body),
		DurationMS: time.Since(start).Milliseconds(),
		ReceivedAt: time.Now().Format(time.RFC3339Nano),
	})

	if spec.Stream {
		// Append the raw SSE payload, delimited by a trace header line.
		cc.streamRaw.WriteString(fmt.Sprintf("### trace=%s url=%s\n", traceID, fullURL))
		cc.streamRaw.Write(body)
		cc.streamRaw.WriteString("\n\n")
	}
	cc.mu.Unlock()

	return &responseResult{
		StatusCode: resp.StatusCode,
		Headers:    resp.Header,
		Body:       body,
		TraceID:    traceID,
		URL:        fullURL,
	}, nil
}
|
||||
|
||||
// flushArtifacts writes all per-case artifacts (request log, response headers
// and bodies, raw stream capture, assertion results, and a meta summary) into
// the case directory.
func (cc *caseContext) flushArtifacts(cs caseResult) error {
	requestPath := filepath.Join(cc.dir, "request.json")
	headersPath := filepath.Join(cc.dir, "response.headers")
	bodyPath := filepath.Join(cc.dir, "response.body")
	streamPath := filepath.Join(cc.dir, "stream.raw")
	assertPath := filepath.Join(cc.dir, "assertions.json")
	metaPath := filepath.Join(cc.dir, "meta.json")

	if err := writeJSONFile(requestPath, cc.requests); err != nil {
		return err
	}
	// Split each response into a headers view and a body view so large bodies
	// do not drown the header listing.
	respHeaders := make([]map[string]any, 0, len(cc.responses))
	respBodies := make([]map[string]any, 0, len(cc.responses))
	for _, r := range cc.responses {
		respHeaders = append(respHeaders, map[string]any{
			"seq":         r.Seq,
			"attempt":     r.Attempt,
			"trace_id":    r.TraceID,
			"status_code": r.StatusCode,
			"headers":     r.Headers,
		})
		respBodies = append(respBodies, map[string]any{
			"seq":           r.Seq,
			"attempt":       r.Attempt,
			"trace_id":      r.TraceID,
			"status_code":   r.StatusCode,
			"body_text":     r.BodyText,
			"network_error": r.NetworkErr,
			"duration_ms":   r.DurationMS,
		})
	}
	if err := writeJSONFile(headersPath, respHeaders); err != nil {
		return err
	}
	if err := writeJSONFile(bodyPath, respBodies); err != nil {
		return err
	}
	if err := os.WriteFile(streamPath, []byte(cc.streamRaw.String()), 0o644); err != nil {
		return err
	}
	if err := writeJSONFile(assertPath, cc.assertions); err != nil {
		return err
	}
	// meta.json mirrors the summary row for this case.
	meta := map[string]any{
		"case_id":       cs.CaseID,
		"trace_id":      strings.Join(cs.TraceIDs, ","),
		"attempt":       len(cc.responses),
		"duration_ms":   cs.DurationMS,
		"status":        map[bool]string{true: "passed", false: "failed"}[cs.Passed],
		"status_codes":  cs.StatusCodes,
		"assertions":    cs.Assertions,
		"artifact_path": cs.ArtifactPath,
	}
	return writeJSONFile(metaPath, meta)
}
|
||||
func (r *Runner) doSimpleJSON(ctx context.Context, method, path string, headers map[string]string, body any) (*responseResult, error) {
|
||||
cc := &caseContext{
|
||||
runner: r,
|
||||
id: "auth_prepare",
|
||||
traceIDsSet: map[string]struct{}{},
|
||||
}
|
||||
return cc.request(ctx, requestSpec{
|
||||
Method: method,
|
||||
Path: path,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
Retryable: true,
|
||||
})
|
||||
}
|
||||
43
internal/testsuite/runner_registry.go
Normal file
43
internal/testsuite/runner_registry.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package testsuite
|
||||
|
||||
import "context"
|
||||
|
||||
// caseDef binds a stable case identifier to the function that executes it.
type caseDef struct {
	ID  string
	Run func(context.Context, *caseContext) error
}
|
||||
|
||||
// cases returns the ordered registry of live test cases. The order only
// affects report readability; each case runs independently.
func (r *Runner) cases() []caseDef {
	return []caseDef{
		{ID: "healthz_ok", Run: r.caseHealthz},
		{ID: "readyz_ok", Run: r.caseReadyz},
		{ID: "models_openai", Run: r.caseModelsOpenAI},
		{ID: "model_openai_by_id", Run: r.caseModelOpenAIByID},
		{ID: "models_claude", Run: r.caseModelsClaude},
		{ID: "admin_login_verify", Run: r.caseAdminLoginVerify},
		{ID: "admin_queue_status", Run: r.caseAdminQueueStatus},
		{ID: "chat_nonstream_basic", Run: r.caseChatNonstream},
		{ID: "chat_stream_basic", Run: r.caseChatStream},
		{ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream},
		{ID: "responses_stream_basic", Run: r.caseResponsesStream},
		{ID: "embeddings_contract", Run: r.caseEmbeddings},
		{ID: "reasoner_stream", Run: r.caseReasonerStream},
		{ID: "toolcall_nonstream", Run: r.caseToolcallNonstream},
		{ID: "toolcall_stream", Run: r.caseToolcallStream},
		{ID: "anthropic_messages_nonstream", Run: r.caseAnthropicNonstream},
		{ID: "anthropic_messages_stream", Run: r.caseAnthropicStream},
		{ID: "anthropic_count_tokens", Run: r.caseAnthropicCountTokens},
		{ID: "admin_account_test_single", Run: r.caseAdminAccountTest},
		{ID: "concurrency_burst", Run: r.caseConcurrencyBurst},
		{ID: "concurrency_threshold_limit", Run: r.caseConcurrencyThresholdLimit},
		{ID: "stream_abort_release", Run: r.caseStreamAbortRelease},
		{ID: "toolcall_stream_mixed", Run: r.caseToolcallStreamMixed},
		{ID: "sse_json_integrity", Run: r.caseSSEJSONIntegrity},
		{ID: "error_contract_invalid_model", Run: r.caseInvalidModel},
		{ID: "error_contract_missing_messages", Run: r.caseMissingMessages},
		{ID: "admin_unauthorized_contract", Run: r.caseAdminUnauthorized},
		{ID: "config_write_isolated", Run: r.caseConfigWriteIsolated},
		{ID: "token_refresh_managed_account", Run: r.caseTokenRefreshManagedAccount},
		{ID: "error_contract_invalid_key", Run: r.caseInvalidKey},
	}
}
|
||||
97
internal/testsuite/runner_summary.go
Normal file
97
internal/testsuite/runner_summary.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (r *Runner) writeSummary(start, end time.Time) error {
|
||||
passed := 0
|
||||
failed := 0
|
||||
for _, cs := range r.results {
|
||||
if cs.Passed {
|
||||
passed++
|
||||
} else {
|
||||
failed++
|
||||
}
|
||||
}
|
||||
summary := runSummary{
|
||||
RunID: r.runID,
|
||||
StartedAt: start.Format(time.RFC3339Nano),
|
||||
EndedAt: end.Format(time.RFC3339Nano),
|
||||
DurationMS: end.Sub(start).Milliseconds(),
|
||||
Stats: map[string]any{
|
||||
"total": len(r.results),
|
||||
"passed": passed,
|
||||
"failed": failed,
|
||||
},
|
||||
Environment: map[string]any{
|
||||
"go_version": runtime.Version(),
|
||||
"os": runtime.GOOS,
|
||||
"arch": runtime.GOARCH,
|
||||
"base_url": r.baseURL,
|
||||
"config_source": r.originalConfigPath,
|
||||
"config_isolated": r.configCopyPath,
|
||||
"server_log": r.serverLog,
|
||||
"preflight_log": r.preflightLog,
|
||||
"retries": r.opts.Retries,
|
||||
"timeout_seconds": int(r.opts.Timeout.Seconds()),
|
||||
},
|
||||
Cases: r.results,
|
||||
Warnings: r.warnings,
|
||||
}
|
||||
if err := writeJSONFile(filepath.Join(r.runDir, "summary.json"), summary); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filepath.Join(r.runDir, "summary.md"), []byte(r.summaryMarkdown(summary)), 0o644)
|
||||
}
|
||||
|
||||
// summaryMarkdown renders the run summary as markdown: a sensitive-data
// notice, run metadata, warnings, a failed-case section with grep hints into
// the server log, and a table of all cases.
func (r *Runner) summaryMarkdown(s runSummary) string {
	var b strings.Builder
	b.WriteString("# DS2API Live Testsuite Summary\n\n")
	// Artifacts contain raw request/response payloads, hence the warning.
	b.WriteString("**Sensitive Notice:** this run stores full raw request/response logs. Do not share artifacts publicly.\n\n")
	fmt.Fprintf(&b, "- Run ID: `%s`\n", s.RunID)
	fmt.Fprintf(&b, "- Started: `%s`\n", s.StartedAt)
	fmt.Fprintf(&b, "- Ended: `%s`\n", s.EndedAt)
	fmt.Fprintf(&b, "- Duration: `%d ms`\n", s.DurationMS)
	fmt.Fprintf(&b, "- Passed/Failed: `%d/%d`\n\n", s.Stats["passed"], s.Stats["failed"])
	if len(s.Warnings) > 0 {
		b.WriteString("## Warnings\n\n")
		for _, w := range s.Warnings {
			fmt.Fprintf(&b, "- %s\n", w)
		}
		b.WriteString("\n")
	}
	b.WriteString("## Failed Cases\n\n")
	hasFailed := false
	for _, c := range s.Cases {
		if c.Passed {
			continue
		}
		hasFailed = true
		fmt.Fprintf(&b, "- `%s`: %s\n", c.CaseID, c.Error)
		if len(c.TraceIDs) > 0 {
			fmt.Fprintf(&b, "  - trace_ids: `%s`\n", strings.Join(c.TraceIDs, ", "))
			// Ready-to-paste ripgrep command for locating the failure in the server log.
			fmt.Fprintf(&b, "  - grep: `rg \"%s\" %s`\n", c.TraceIDs[0], filepath.Join(r.runDir, "server.log"))
		}
		fmt.Fprintf(&b, "  - artifact: `%s`\n", c.ArtifactPath)
	}
	if !hasFailed {
		b.WriteString("- none\n")
	}
	b.WriteString("\n## Case Table\n\n")
	b.WriteString("| case_id | status | duration_ms | statuses | artifact |\n")
	b.WriteString("|---|---:|---:|---|---|\n")
	for _, c := range s.Cases {
		status := "PASS"
		if !c.Passed {
			status = "FAIL"
		}
		fmt.Fprintf(&b, "| %s | %s | %d | %v | `%s` |\n", c.CaseID, status, c.DurationMS, c.StatusCodes, c.ArtifactPath)
	}
	return b.String()
}
|
||||
202
internal/testsuite/runner_utils.go
Normal file
202
internal/testsuite/runner_utils.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// parseSSEFrames decodes every JSON "data:" frame from an SSE body and
// reports whether a terminating [DONE] sentinel was observed. Frames that do
// not parse as JSON objects are silently skipped.
func parseSSEFrames(body []byte) ([]map[string]any, bool) {
	rows := strings.Split(string(body), "\n")
	frames := make([]map[string]any, 0, len(rows))
	sawDone := false
	for _, row := range rows {
		row = strings.TrimSpace(row)
		if !strings.HasPrefix(row, "data:") {
			continue
		}
		payload := strings.TrimSpace(strings.TrimPrefix(row, "data:"))
		switch {
		case payload == "":
			// Blank data line: ignore.
		case payload == "[DONE]":
			sawDone = true
		default:
			var frame map[string]any
			if err := json.Unmarshal([]byte(payload), &frame); err == nil {
				frames = append(frames, frame)
			}
		}
	}
	return frames, sawDone
}
|
||||
|
||||
func parseClaudeStreamEvents(body []byte) []string {
|
||||
events := []string{}
|
||||
seen := map[string]bool{}
|
||||
lines := strings.Split(string(body), "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" {
|
||||
continue
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal([]byte(payload), &m); err != nil {
|
||||
continue
|
||||
}
|
||||
t := asString(m["type"])
|
||||
if t == "" || seen[t] {
|
||||
continue
|
||||
}
|
||||
seen[t] = true
|
||||
events = append(events, t)
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
func extractModelIDs(body []byte) []string {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil
|
||||
}
|
||||
out := []string{}
|
||||
data, _ := m["data"].([]any)
|
||||
for _, it := range data {
|
||||
item, _ := it.(map[string]any)
|
||||
id := asString(item["id"])
|
||||
if id != "" {
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// withTraceQuery returns rawURL with a __trace_id query parameter set.
func withTraceQuery(rawURL, traceID string) (string, error) {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return "", err
	}
	query := parsed.Query()
	query.Set("__trace_id", traceID)
	parsed.RawQuery = query.Encode()
	return parsed.String(), nil
}
|
||||
|
||||
// writeJSONFile marshals v as two-space-indented JSON and writes it to path
// with mode 0644.
func writeJSONFile(path string, v any) error {
	data, err := json.MarshalIndent(v, "", "  ")
	if err != nil {
		return err
	}
	return os.WriteFile(path, data, 0o644)
}
|
||||
|
||||
// prepareServerEnv returns the base environment with every key present in
// overrides replaced by its override value. Entries lacking '=' are dropped.
func prepareServerEnv(base []string, overrides map[string]string) []string {
	result := make([]string, 0, len(base)+len(overrides))
	for _, entry := range base {
		key, _, ok := strings.Cut(entry, "=")
		if !ok {
			continue
		}
		if _, overridden := overrides[key]; overridden {
			continue
		}
		result = append(result, entry)
	}
	for key, value := range overrides {
		result = append(result, key+"="+value)
	}
	return result
}
|
||||
|
||||
// findFreePort asks the OS for an unused loopback TCP port by binding to
// port 0 and reading back the assigned address.
func findFreePort() (int, error) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return 0, err
	}
	defer listener.Close()
	tcpAddr, ok := listener.Addr().(*net.TCPAddr)
	if !ok {
		return 0, errors.New("failed to detect tcp port")
	}
	return tcpAddr.Port, nil
}
|
||||
|
||||
func uniqueStatusCodes(in []responseLog) []int {
|
||||
set := map[int]struct{}{}
|
||||
for _, it := range in {
|
||||
if it.StatusCode > 0 {
|
||||
set[it.StatusCode] = struct{}{}
|
||||
}
|
||||
}
|
||||
out := make([]int, 0, len(set))
|
||||
for k := range set {
|
||||
out = append(out, k)
|
||||
}
|
||||
sort.Ints(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// has5xx reports one server-error status (>=500) found in a status-code
// histogram, if any. Which one is returned is unspecified (map order).
func has5xx(dist map[int]int) (int, bool) {
	for status := range dist {
		if status < 500 {
			continue
		}
		return status, true
	}
	return 0, false
}
|
||||
|
||||
// idSanitizer maps characters that are unsafe inside trace IDs and artifact
// file names (colon, path separator, space) to underscores in a single pass.
var idSanitizer = strings.NewReplacer(":", "_", "/", "_", " ", "_")

// sanitizeID makes s safe for use in trace IDs and artifact paths by
// replacing ':', '/', and ' ' with '_'.
//
// The previous implementation chained three strings.ReplaceAll calls, which
// allocates two intermediate strings per call; the precompiled Replacer
// performs the same substitution in one pass.
func sanitizeID(s string) string {
	return idSanitizer.Replace(s)
}
|
||||
|
||||
// asString renders v as a whitespace-trimmed string; nil yields "". Values
// that are not strings are formatted with %v before trimming.
func asString(v any) string {
	if v == nil {
		return ""
	}
	if s, ok := v.(string); ok {
		return strings.TrimSpace(s)
	}
	return strings.TrimSpace(fmt.Sprintf("%v", v))
}
|
||||
|
||||
// toInt best-effort converts the numeric types produced by JSON decoding and
// common Go code to int; any other type (including nil) yields 0.
func toInt(v any) int {
	switch n := v.(type) {
	case int:
		return n
	case int64:
		return int(n)
	case float64:
		return int(n)
	case float32:
		return int(n)
	}
	return 0
}
|
||||
|
||||
// contains reports whether target occurs in xs.
func contains(xs []string, target string) bool {
	for i := range xs {
		if xs[i] == target {
			return true
		}
	}
	return false
}
|
||||
138
internal/util/toolcalls_candidates.go
Normal file
138
internal/util/toolcalls_candidates.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// toolCallPattern matches an inline {"tool_calls":[...]} object embedded in
// free text (either quote style around the key).
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)

// fencedJSONPattern captures the body of a ```json ... ``` (or bare ```)
// fenced block; (?s) lets '.' span newlines.
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")

// fencedBlockPattern matches a whole fenced code block, used for stripping.
var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```")
|
||||
|
||||
func buildToolCallCandidates(text string) []string {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
candidates := []string{trimmed}
|
||||
|
||||
// fenced code block candidates: ```json ... ```
|
||||
for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) {
|
||||
if len(match) >= 2 {
|
||||
candidates = append(candidates, strings.TrimSpace(match[1]))
|
||||
}
|
||||
}
|
||||
|
||||
// best-effort extraction around "tool_calls" key in mixed text payloads.
|
||||
candidates = append(candidates, extractToolCallObjects(trimmed)...)
|
||||
|
||||
// best-effort object slice: from first '{' to last '}'
|
||||
first := strings.Index(trimmed, "{")
|
||||
last := strings.LastIndex(trimmed, "}")
|
||||
if first >= 0 && last > first {
|
||||
candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1]))
|
||||
}
|
||||
|
||||
// legacy regex extraction fallback
|
||||
if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
|
||||
candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}")
|
||||
}
|
||||
|
||||
uniq := make([]string, 0, len(candidates))
|
||||
seen := map[string]struct{}{}
|
||||
for _, c := range candidates {
|
||||
if c == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[c]; ok {
|
||||
continue
|
||||
}
|
||||
seen[c] = struct{}{}
|
||||
uniq = append(uniq, c)
|
||||
}
|
||||
return uniq
|
||||
}
|
||||
|
||||
// extractToolCallObjects scans text for each occurrence of "tool_calls"
// (case-insensitive) and tries to recover the smallest enclosing JSON object
// by walking candidate '{' positions backwards from the key until one
// balances. Each recovered object is returned trimmed, in scan order.
func extractToolCallObjects(text string) []string {
	if text == "" {
		return nil
	}
	// Lowercased copy is only used for locating the key; slices are taken
	// from the original text so the extracted objects keep their casing.
	lower := strings.ToLower(text)
	out := []string{}
	offset := 0
	for {
		idx := strings.Index(lower[offset:], "tool_calls")
		if idx < 0 {
			break
		}
		idx += offset
		// Try '{' openers from the nearest one leftwards until one yields a
		// balanced object that spans the key.
		start := strings.LastIndex(text[:idx], "{")
		for start >= 0 {
			candidate, end, ok := extractJSONObject(text, start)
			if ok {
				// Move forward to avoid repeatedly matching the same object.
				offset = end
				out = append(out, strings.TrimSpace(candidate))
				break
			}
			start = strings.LastIndex(text[:start], "{")
		}
		if start < 0 {
			// No balanced object found; skip past this key occurrence.
			offset = idx + len("tool_calls")
		}
	}
	return out
}
|
||||
|
||||
// extractJSONObject extracts the balanced {...} object beginning at
// text[start]. It tracks both double- and single-quoted strings (with
// backslash escapes) so braces inside string literals do not affect the depth
// count. On success it returns the object text, the index just past it, and
// true; otherwise ("", 0, false).
func extractJSONObject(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '{' {
		return "", 0, false
	}
	depth := 0
	quote := byte(0) // currently open quote character; 0 when outside strings
	escaped := false // previous byte inside the string was a backslash
	for i := start; i < len(text); i++ {
		ch := text[i]
		if quote != 0 {
			// Inside a string: only escape handling and the closing quote matter.
			if escaped {
				escaped = false
				continue
			}
			if ch == '\\' {
				escaped = true
				continue
			}
			if ch == quote {
				quote = 0
			}
			continue
		}
		if ch == '"' || ch == '\'' {
			quote = ch
			continue
		}
		if ch == '{' {
			depth++
			continue
		}
		if ch == '}' {
			depth--
			if depth == 0 {
				return text[start : i+1], i + 1, true
			}
		}
	}
	// Ran off the end with an unbalanced brace or unterminated string.
	return "", 0, false
}
|
||||
|
||||
// looksLikeToolExampleContext reports whether non-blank text contains a
// fenced code block, which suggests any tool-call JSON inside is an example
// rather than a real invocation.
func looksLikeToolExampleContext(text string) bool {
	trimmed := strings.ToLower(strings.TrimSpace(text))
	return trimmed != "" && strings.Contains(trimmed, "```")
}
|
||||
|
||||
func stripFencedCodeBlocks(text string) string {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return ""
|
||||
}
|
||||
return fencedBlockPattern.ReplaceAllString(text, " ")
|
||||
}
|
||||
41
internal/util/toolcalls_format.go
Normal file
41
internal/util/toolcalls_format.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for _, c := range calls {
|
||||
args, _ := json.Marshal(c.Input)
|
||||
out = append(out, map[string]any{
|
||||
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": c.Name,
|
||||
"arguments": string(args),
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for i, c := range calls {
|
||||
args, _ := json.Marshal(c.Input)
|
||||
out = append(out, map[string]any{
|
||||
"index": i,
|
||||
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": c.Name,
|
||||
"arguments": string(args),
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -2,16 +2,9 @@ package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
|
||||
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
|
||||
var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```")
|
||||
|
||||
type ParsedToolCall struct {
|
||||
Name string `json:"name"`
|
||||
Input map[string]any `json:"input"`
|
||||
@@ -102,47 +95,6 @@ func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []Par
|
||||
return out
|
||||
}
|
||||
|
||||
func buildToolCallCandidates(text string) []string {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
candidates := []string{trimmed}
|
||||
|
||||
// fenced code block candidates: ```json ... ```
|
||||
for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) {
|
||||
if len(match) >= 2 {
|
||||
candidates = append(candidates, strings.TrimSpace(match[1]))
|
||||
}
|
||||
}
|
||||
|
||||
// best-effort extraction around "tool_calls" key in mixed text payloads.
|
||||
candidates = append(candidates, extractToolCallObjects(trimmed)...)
|
||||
|
||||
// best-effort object slice: from first '{' to last '}'
|
||||
first := strings.Index(trimmed, "{")
|
||||
last := strings.LastIndex(trimmed, "}")
|
||||
if first >= 0 && last > first {
|
||||
candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1]))
|
||||
}
|
||||
|
||||
// legacy regex extraction fallback
|
||||
if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
|
||||
candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}")
|
||||
}
|
||||
|
||||
uniq := make([]string, 0, len(candidates))
|
||||
seen := map[string]struct{}{}
|
||||
for _, c := range candidates {
|
||||
if c == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[c]; ok {
|
||||
continue
|
||||
}
|
||||
seen[c] = struct{}{}
|
||||
uniq = append(uniq, c)
|
||||
}
|
||||
return uniq
|
||||
}
|
||||
|
||||
func parseToolCallsPayload(payload string) []ParsedToolCall {
|
||||
var decoded any
|
||||
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
|
||||
@@ -243,123 +195,3 @@ func parseToolCallInput(v any) map[string]any {
|
||||
return map[string]any{}
|
||||
}
|
||||
}
|
||||
|
||||
// extractToolCallObjects scans text for occurrences of a "tool_calls" key
// (case-insensitive) and returns, for each occurrence, the nearest balanced
// brace-delimited object that encloses it. Object syntax is best-effort:
// single- and double-quoted strings with backslash escapes are honored.
// Returns nil for empty input and an empty slice when no object is found.
func extractToolCallObjects(text string) []string {
	if text == "" {
		return nil
	}
	lower := strings.ToLower(text)
	out := []string{}
	offset := 0
	for {
		idx := strings.Index(lower[offset:], "tool_calls")
		if idx < 0 {
			break
		}
		idx += offset
		// Walk backwards through candidate opening braces until one yields a
		// balanced object that actually extends past idx. An object whose
		// closing brace sits at or before idx cannot contain the key; the old
		// code accepted it anyway, leaving offset <= idx and spinning forever
		// on the same occurrence (e.g. input "{} tool_calls").
		found := false
		for start := strings.LastIndex(text[:idx], "{"); start >= 0; start = strings.LastIndex(text[:start], "{") {
			candidate, end, ok := extractJSONObject(text, start)
			if !ok || end <= idx {
				continue
			}
			out = append(out, strings.TrimSpace(candidate))
			// Resume after the extracted object so it is not matched again.
			offset = end
			found = true
			break
		}
		if !found {
			// No enclosing object; skip past this key occurrence.
			offset = idx + len("tool_calls")
		}
	}
	return out
}

// extractJSONObject returns the balanced brace-delimited object beginning at
// text[start], the index just past its closing brace, and whether a complete
// object was found. Braces inside single- or double-quoted strings (including
// backslash-escaped characters) are ignored. start must point at a '{'.
func extractJSONObject(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '{' {
		return "", 0, false
	}
	depth := 0
	quote := byte(0)
	escaped := false
	for i := start; i < len(text); i++ {
		ch := text[i]
		if quote != 0 {
			// Inside a quoted string: only escapes and the closing quote matter.
			if escaped {
				escaped = false
				continue
			}
			if ch == '\\' {
				escaped = true
				continue
			}
			if ch == quote {
				quote = 0
			}
			continue
		}
		if ch == '"' || ch == '\'' {
			quote = ch
			continue
		}
		if ch == '{' {
			depth++
			continue
		}
		if ch == '}' {
			depth--
			if depth == 0 {
				return text[start : i+1], i + 1, true
			}
		}
	}
	// Ran off the end of text with an unbalanced object.
	return "", 0, false
}
|
||||
|
||||
// looksLikeToolExampleContext reports whether text appears to contain a
// fenced code block — a hint that any tool-call JSON inside it is likely an
// illustrative example rather than an actual invocation. Whitespace-only
// input is never considered example context.
func looksLikeToolExampleContext(text string) bool {
	trimmed := strings.TrimSpace(text)
	if trimmed == "" {
		return false
	}
	return strings.Contains(strings.ToLower(trimmed), "```")
}
|
||||
|
||||
func stripFencedCodeBlocks(text string) string {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return ""
|
||||
}
|
||||
return fencedBlockPattern.ReplaceAllString(text, " ")
|
||||
}
|
||||
|
||||
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for _, c := range calls {
|
||||
args, _ := json.Marshal(c.Input)
|
||||
out = append(out, map[string]any{
|
||||
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": c.Name,
|
||||
"arguments": string(args),
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for i, c := range calls {
|
||||
args, _ := json.Marshal(c.Input)
|
||||
out = append(out, map[string]any{
|
||||
"index": i,
|
||||
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": c.Name,
|
||||
"arguments": string(args),
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user