feat: Implement DeepSeek integration, refactor model adapters for streaming and tool calls, enhance admin and account management, and introduce new UI features for settings, API testing, and Vercel sync.

This commit is contained in:
CJACK
2026-02-22 17:25:48 +08:00
parent 5d3989a9a7
commit 6c48429b90
152 changed files with 13583 additions and 11817 deletions

View File

@@ -1,363 +0,0 @@
package account
import (
"context"
"os"
"sort"
"strconv"
"strings"
"sync"
"ds2api/internal/config"
)
type Pool struct {
store *config.Store
mu sync.Mutex
queue []string
inUse map[string]int
waiters []chan struct{}
maxInflightPerAccount int
recommendedConcurrency int
maxQueueSize int
globalMaxInflight int
}
func NewPool(store *config.Store) *Pool {
maxPer := 2
if store != nil {
maxPer = store.RuntimeAccountMaxInflight()
}
p := &Pool{
store: store,
inUse: map[string]int{},
maxInflightPerAccount: maxPer,
}
p.Reset()
return p
}
func (p *Pool) Reset() {
accounts := p.store.Accounts()
sort.SliceStable(accounts, func(i, j int) bool {
iHas := accounts[i].Token != ""
jHas := accounts[j].Token != ""
if iHas == jHas {
return i < j
}
return iHas
})
ids := make([]string, 0, len(accounts))
for _, a := range accounts {
id := a.Identifier()
if id != "" {
ids = append(ids, id)
}
}
if p.store != nil {
p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight()
} else {
p.maxInflightPerAccount = maxInflightFromEnv()
}
recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
queueLimit := maxQueueFromEnv(recommended)
globalLimit := recommended
if p.store != nil {
queueLimit = p.store.RuntimeAccountMaxQueue(recommended)
globalLimit = p.store.RuntimeGlobalMaxInflight(recommended)
}
p.mu.Lock()
defer p.mu.Unlock()
p.drainWaitersLocked()
p.queue = ids
p.inUse = map[string]int{}
p.recommendedConcurrency = recommended
p.maxQueueSize = queueLimit
p.globalMaxInflight = globalLimit
config.Logger.Info(
"[init_account_queue] initialized",
"total", len(ids),
"max_inflight_per_account", p.maxInflightPerAccount,
"global_max_inflight", p.globalMaxInflight,
"recommended_concurrency", p.recommendedConcurrency,
"max_queue_size", p.maxQueueSize,
)
}
func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
p.mu.Lock()
defer p.mu.Unlock()
return p.acquireLocked(target, normalizeExclude(exclude))
}
func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) {
if ctx == nil {
ctx = context.Background()
}
exclude = normalizeExclude(exclude)
for {
if ctx.Err() != nil {
return config.Account{}, false
}
p.mu.Lock()
if acc, ok := p.acquireLocked(target, exclude); ok {
p.mu.Unlock()
return acc, true
}
if !p.canQueueLocked(target, exclude) {
p.mu.Unlock()
return config.Account{}, false
}
waiter := make(chan struct{})
p.waiters = append(p.waiters, waiter)
p.mu.Unlock()
select {
case <-ctx.Done():
p.mu.Lock()
p.removeWaiterLocked(waiter)
p.mu.Unlock()
return config.Account{}, false
case <-waiter:
}
}
}
func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
if target != "" {
if exclude[target] || !p.canAcquireIDLocked(target) {
return config.Account{}, false
}
acc, ok := p.store.FindAccount(target)
if !ok {
return config.Account{}, false
}
p.inUse[target]++
p.bumpQueue(target)
return acc, true
}
if acc, ok := p.tryAcquire(exclude, true); ok {
return acc, true
}
if acc, ok := p.tryAcquire(exclude, false); ok {
return acc, true
}
return config.Account{}, false
}
func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) {
for i := 0; i < len(p.queue); i++ {
id := p.queue[i]
if exclude[id] || !p.canAcquireIDLocked(id) {
continue
}
acc, ok := p.store.FindAccount(id)
if !ok {
continue
}
if requireToken && acc.Token == "" {
continue
}
p.inUse[id]++
p.bumpQueue(id)
return acc, true
}
return config.Account{}, false
}
func (p *Pool) bumpQueue(accountID string) {
for i, id := range p.queue {
if id != accountID {
continue
}
p.queue = append(p.queue[:i], p.queue[i+1:]...)
p.queue = append(p.queue, accountID)
return
}
}
func (p *Pool) Release(accountID string) {
if accountID == "" {
return
}
p.mu.Lock()
defer p.mu.Unlock()
count := p.inUse[accountID]
if count <= 0 {
return
}
if count == 1 {
delete(p.inUse, accountID)
p.notifyWaiterLocked()
return
}
p.inUse[accountID] = count - 1
p.notifyWaiterLocked()
}
func (p *Pool) Status() map[string]any {
p.mu.Lock()
defer p.mu.Unlock()
available := make([]string, 0, len(p.queue))
inUseAccounts := make([]string, 0, len(p.inUse))
inUseSlots := 0
for _, id := range p.queue {
if p.inUse[id] < p.maxInflightPerAccount {
available = append(available, id)
}
}
for id, count := range p.inUse {
if count > 0 {
inUseAccounts = append(inUseAccounts, id)
inUseSlots += count
}
}
sort.Strings(inUseAccounts)
return map[string]any{
"available": len(available),
"in_use": inUseSlots,
"total": len(p.store.Accounts()),
"available_accounts": available,
"in_use_accounts": inUseAccounts,
"max_inflight_per_account": p.maxInflightPerAccount,
"global_max_inflight": p.globalMaxInflight,
"recommended_concurrency": p.recommendedConcurrency,
"waiting": len(p.waiters),
"max_queue_size": p.maxQueueSize,
}
}
func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) {
if maxInflightPerAccount <= 0 {
maxInflightPerAccount = 1
}
if maxQueueSize < 0 {
maxQueueSize = 0
}
if globalMaxInflight <= 0 {
globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts())
if globalMaxInflight <= 0 {
globalMaxInflight = maxInflightPerAccount
}
}
p.mu.Lock()
defer p.mu.Unlock()
p.maxInflightPerAccount = maxInflightPerAccount
p.maxQueueSize = maxQueueSize
p.globalMaxInflight = globalMaxInflight
p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount)
p.notifyWaiterLocked()
}
func maxInflightFromEnv() int {
for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
continue
}
n, err := strconv.Atoi(raw)
if err == nil && n > 0 {
return n
}
}
return 2
}
func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int {
if accountCount <= 0 {
return 0
}
if maxInflightPerAccount <= 0 {
maxInflightPerAccount = 2
}
return accountCount * maxInflightPerAccount
}
func normalizeExclude(exclude map[string]bool) map[string]bool {
if exclude == nil {
return map[string]bool{}
}
return exclude
}
func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool {
if target != "" {
if exclude[target] {
return false
}
if _, ok := p.store.FindAccount(target); !ok {
return false
}
}
if p.maxQueueSize <= 0 {
return false
}
return len(p.waiters) < p.maxQueueSize
}
func (p *Pool) notifyWaiterLocked() {
if len(p.waiters) == 0 {
return
}
waiter := p.waiters[0]
p.waiters = p.waiters[1:]
close(waiter)
}
func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool {
for i, w := range p.waiters {
if w != waiter {
continue
}
p.waiters = append(p.waiters[:i], p.waiters[i+1:]...)
return true
}
return false
}
func (p *Pool) drainWaitersLocked() {
for _, waiter := range p.waiters {
close(waiter)
}
p.waiters = nil
}
func maxQueueFromEnv(defaultSize int) int {
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
continue
}
n, err := strconv.Atoi(raw)
if err == nil && n >= 0 {
return n
}
}
if defaultSize < 0 {
return 0
}
return defaultSize
}
func (p *Pool) canAcquireIDLocked(accountID string) bool {
if accountID == "" {
return false
}
if p.inUse[accountID] >= p.maxInflightPerAccount {
return false
}
if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight {
return false
}
return true
}
func (p *Pool) currentInUseLocked() int {
total := 0
for _, n := range p.inUse {
total += n
}
return total
}

View File

@@ -0,0 +1,108 @@
package account
import (
"context"
"ds2api/internal/config"
)
func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
p.mu.Lock()
defer p.mu.Unlock()
return p.acquireLocked(target, normalizeExclude(exclude))
}
func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) {
if ctx == nil {
ctx = context.Background()
}
exclude = normalizeExclude(exclude)
for {
if ctx.Err() != nil {
return config.Account{}, false
}
p.mu.Lock()
if acc, ok := p.acquireLocked(target, exclude); ok {
p.mu.Unlock()
return acc, true
}
if !p.canQueueLocked(target, exclude) {
p.mu.Unlock()
return config.Account{}, false
}
waiter := make(chan struct{})
p.waiters = append(p.waiters, waiter)
p.mu.Unlock()
select {
case <-ctx.Done():
p.mu.Lock()
p.removeWaiterLocked(waiter)
p.mu.Unlock()
return config.Account{}, false
case <-waiter:
}
}
}
func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
if target != "" {
if exclude[target] || !p.canAcquireIDLocked(target) {
return config.Account{}, false
}
acc, ok := p.store.FindAccount(target)
if !ok {
return config.Account{}, false
}
p.inUse[target]++
p.bumpQueue(target)
return acc, true
}
if acc, ok := p.tryAcquire(exclude, true); ok {
return acc, true
}
if acc, ok := p.tryAcquire(exclude, false); ok {
return acc, true
}
return config.Account{}, false
}
func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) {
for i := 0; i < len(p.queue); i++ {
id := p.queue[i]
if exclude[id] || !p.canAcquireIDLocked(id) {
continue
}
acc, ok := p.store.FindAccount(id)
if !ok {
continue
}
if requireToken && acc.Token == "" {
continue
}
p.inUse[id]++
p.bumpQueue(id)
return acc, true
}
return config.Account{}, false
}
func (p *Pool) bumpQueue(accountID string) {
for i, id := range p.queue {
if id != accountID {
continue
}
p.queue = append(p.queue[:i], p.queue[i+1:]...)
p.queue = append(p.queue, accountID)
return
}
}
func normalizeExclude(exclude map[string]bool) map[string]bool {
if exclude == nil {
return map[string]bool{}
}
return exclude
}

View File

@@ -0,0 +1,132 @@
package account
import (
"sort"
"sync"
"ds2api/internal/config"
)
type Pool struct {
store *config.Store
mu sync.Mutex
queue []string
inUse map[string]int
waiters []chan struct{}
maxInflightPerAccount int
recommendedConcurrency int
maxQueueSize int
globalMaxInflight int
}
func NewPool(store *config.Store) *Pool {
maxPer := 2
if store != nil {
maxPer = store.RuntimeAccountMaxInflight()
}
p := &Pool{
store: store,
inUse: map[string]int{},
maxInflightPerAccount: maxPer,
}
p.Reset()
return p
}
func (p *Pool) Reset() {
accounts := p.store.Accounts()
sort.SliceStable(accounts, func(i, j int) bool {
iHas := accounts[i].Token != ""
jHas := accounts[j].Token != ""
if iHas == jHas {
return i < j
}
return iHas
})
ids := make([]string, 0, len(accounts))
for _, a := range accounts {
id := a.Identifier()
if id != "" {
ids = append(ids, id)
}
}
if p.store != nil {
p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight()
} else {
p.maxInflightPerAccount = maxInflightFromEnv()
}
recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
queueLimit := maxQueueFromEnv(recommended)
globalLimit := recommended
if p.store != nil {
queueLimit = p.store.RuntimeAccountMaxQueue(recommended)
globalLimit = p.store.RuntimeGlobalMaxInflight(recommended)
}
p.mu.Lock()
defer p.mu.Unlock()
p.drainWaitersLocked()
p.queue = ids
p.inUse = map[string]int{}
p.recommendedConcurrency = recommended
p.maxQueueSize = queueLimit
p.globalMaxInflight = globalLimit
config.Logger.Info(
"[init_account_queue] initialized",
"total", len(ids),
"max_inflight_per_account", p.maxInflightPerAccount,
"global_max_inflight", p.globalMaxInflight,
"recommended_concurrency", p.recommendedConcurrency,
"max_queue_size", p.maxQueueSize,
)
}
func (p *Pool) Release(accountID string) {
if accountID == "" {
return
}
p.mu.Lock()
defer p.mu.Unlock()
count := p.inUse[accountID]
if count <= 0 {
return
}
if count == 1 {
delete(p.inUse, accountID)
p.notifyWaiterLocked()
return
}
p.inUse[accountID] = count - 1
p.notifyWaiterLocked()
}
func (p *Pool) Status() map[string]any {
p.mu.Lock()
defer p.mu.Unlock()
available := make([]string, 0, len(p.queue))
inUseAccounts := make([]string, 0, len(p.inUse))
inUseSlots := 0
for _, id := range p.queue {
if p.inUse[id] < p.maxInflightPerAccount {
available = append(available, id)
}
}
for id, count := range p.inUse {
if count > 0 {
inUseAccounts = append(inUseAccounts, id)
inUseSlots += count
}
}
sort.Strings(inUseAccounts)
return map[string]any{
"available": len(available),
"in_use": inUseSlots,
"total": len(p.store.Accounts()),
"available_accounts": available,
"in_use_accounts": inUseAccounts,
"max_inflight_per_account": p.maxInflightPerAccount,
"global_max_inflight": p.globalMaxInflight,
"recommended_concurrency": p.recommendedConcurrency,
"waiting": len(p.waiters),
"max_queue_size": p.maxQueueSize,
}
}

View File

@@ -0,0 +1,91 @@
package account
import (
"os"
"strconv"
"strings"
)
func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) {
if maxInflightPerAccount <= 0 {
maxInflightPerAccount = 1
}
if maxQueueSize < 0 {
maxQueueSize = 0
}
if globalMaxInflight <= 0 {
globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts())
if globalMaxInflight <= 0 {
globalMaxInflight = maxInflightPerAccount
}
}
p.mu.Lock()
defer p.mu.Unlock()
p.maxInflightPerAccount = maxInflightPerAccount
p.maxQueueSize = maxQueueSize
p.globalMaxInflight = globalMaxInflight
p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount)
p.notifyWaiterLocked()
}
func maxInflightFromEnv() int {
for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
continue
}
n, err := strconv.Atoi(raw)
if err == nil && n > 0 {
return n
}
}
return 2
}
func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int {
if accountCount <= 0 {
return 0
}
if maxInflightPerAccount <= 0 {
maxInflightPerAccount = 2
}
return accountCount * maxInflightPerAccount
}
func maxQueueFromEnv(defaultSize int) int {
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
continue
}
n, err := strconv.Atoi(raw)
if err == nil && n >= 0 {
return n
}
}
if defaultSize < 0 {
return 0
}
return defaultSize
}
func (p *Pool) canAcquireIDLocked(accountID string) bool {
if accountID == "" {
return false
}
if p.inUse[accountID] >= p.maxInflightPerAccount {
return false
}
if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight {
return false
}
return true
}
func (p *Pool) currentInUseLocked() int {
total := 0
for _, n := range p.inUse {
total += n
}
return total
}

View File

@@ -0,0 +1,43 @@
package account
func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool {
if target != "" {
if exclude[target] {
return false
}
if _, ok := p.store.FindAccount(target); !ok {
return false
}
}
if p.maxQueueSize <= 0 {
return false
}
return len(p.waiters) < p.maxQueueSize
}
func (p *Pool) notifyWaiterLocked() {
if len(p.waiters) == 0 {
return
}
waiter := p.waiters[0]
p.waiters = p.waiters[1:]
close(waiter)
}
func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool {
for i, w := range p.waiters {
if w != waiter {
continue
}
p.waiters = append(p.waiters[:i], p.waiters[i+1:]...)
return true
}
return false
}
func (p *Pool) drainWaitersLocked() {
for _, waiter := range p.waiters {
close(waiter)
}
p.waiters = nil
}

View File

@@ -32,4 +32,3 @@ func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) {
t.Fatal("expected param field")
}
}

View File

@@ -1,368 +0,0 @@
package claude
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/deepseek"
claudefmt "ds2api/internal/format/claude"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
)
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
var writeJSON = util.WriteJSON
type Handler struct {
Store ConfigReader
Auth AuthResolver
DS DeepSeekCaller
}
var (
claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second
claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second
claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount
)
func RegisterRoutes(r chi.Router, h *Handler) {
r.Get("/anthropic/v1/models", h.ListModels)
r.Post("/anthropic/v1/messages", h.Messages)
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
r.Post("/v1/messages", h.Messages)
r.Post("/messages", h.Messages)
r.Post("/v1/messages/count_tokens", h.CountTokens)
r.Post("/messages/count_tokens", h.CountTokens)
}
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
}
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
r.Header.Set("anthropic-version", "2023-06-01")
}
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeClaudeError(w, status, detail)
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeClaudeError(w, http.StatusBadRequest, "invalid json")
return
}
norm, err := normalizeClaudeRequest(h.Store, req)
if err != nil {
writeClaudeError(w, http.StatusBadRequest, err.Error())
return
}
stdReq := norm.Standard
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
return
}
requestPayload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
if err != nil {
writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
return
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
writeClaudeError(w, http.StatusInternalServerError, string(body))
return
}
if stdReq.Stream {
h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
result := sse.CollectStream(resp, stdReq.Thinking, true)
respBody := claudefmt.BuildMessageResponse(
fmt.Sprintf("msg_%d", time.Now().UnixNano()),
stdReq.ResponseModel,
norm.NormalizedMessages,
result.Thinking,
result.Text,
stdReq.ToolNames,
)
writeJSON(w, http.StatusOK, respBody)
}
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
a, err := h.Auth.Determine(r)
if err != nil {
writeClaudeError(w, http.StatusUnauthorized, err.Error())
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeClaudeError(w, http.StatusBadRequest, "invalid json")
return
}
model, _ := req["model"].(string)
messages, _ := req["messages"].([]any)
if model == "" || len(messages) == 0 {
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
return
}
inputTokens := 0
if sys, ok := req["system"].(string); ok {
inputTokens += util.EstimateTokens(sys)
}
for _, item := range messages {
msg, ok := item.(map[string]any)
if !ok {
continue
}
inputTokens += 2
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
}
if tools, ok := req["tools"].([]any); ok {
for _, t := range tools {
b, _ := json.Marshal(t)
inputTokens += util.EstimateTokens(string(b))
}
}
if inputTokens < 1 {
inputTokens = 1
}
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
}
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeClaudeError(w, http.StatusInternalServerError, string(body))
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache, no-transform")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
_, canFlush := w.(http.Flusher)
if !canFlush {
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
}
streamRuntime := newClaudeStreamRuntime(
w,
rc,
canFlush,
model,
messages,
thinkingEnabled,
searchEnabled,
toolNames,
)
streamRuntime.sendMessageStart()
initialType := "text"
if thinkingEnabled {
initialType = "thinking"
}
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
Context: r.Context(),
Body: resp.Body,
ThinkingEnabled: thinkingEnabled,
InitialType: initialType,
KeepAliveInterval: claudeStreamPingInterval,
IdleTimeout: claudeStreamIdleTimeout,
MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt,
}, streamengine.ConsumeHooks{
OnKeepAlive: func() {
streamRuntime.sendPing()
},
OnParsed: streamRuntime.onParsed,
OnFinalize: streamRuntime.onFinalize,
})
}
func writeClaudeError(w http.ResponseWriter, status int, message string) {
code := "invalid_request"
switch status {
case http.StatusUnauthorized:
code = "authentication_failed"
case http.StatusTooManyRequests:
code = "rate_limit_exceeded"
case http.StatusNotFound:
code = "not_found"
case http.StatusInternalServerError:
code = "internal_error"
}
writeJSON(w, status, map[string]any{
"error": map[string]any{
"type": "invalid_request_error",
"message": message,
"code": code,
"param": nil,
},
})
}
func normalizeClaudeMessages(messages []any) []any {
out := make([]any, 0, len(messages))
for _, m := range messages {
msg, ok := m.(map[string]any)
if !ok {
continue
}
copied := cloneMap(msg)
switch content := msg["content"].(type) {
case []any:
parts := make([]string, 0, len(content))
for _, block := range content {
b, ok := block.(map[string]any)
if !ok {
continue
}
typeStr, _ := b["type"].(string)
if typeStr == "text" {
if t, ok := b["text"].(string); ok {
parts = append(parts, t)
}
}
if typeStr == "tool_result" {
parts = append(parts, formatClaudeToolResultForPrompt(b))
}
}
copied["content"] = strings.Join(parts, "\n")
}
out = append(out, copied)
}
return out
}
func buildClaudeToolPrompt(tools []any) string {
parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"}
for _, t := range tools {
m, ok := t.(map[string]any)
if !ok {
continue
}
name, _ := m["name"].(string)
desc, _ := m["description"].(string)
schema, _ := json.Marshal(m["input_schema"])
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
}
parts = append(parts,
"When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
)
return strings.Join(parts, "\n\n")
}
func formatClaudeToolResultForPrompt(block map[string]any) string {
if block == nil {
return ""
}
toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
if toolCallID == "" {
toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
}
if toolCallID == "" {
toolCallID = "unknown"
}
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
if name == "" {
name = "unknown"
}
content := strings.TrimSpace(fmt.Sprintf("%v", block["content"]))
if content == "" {
content = "null"
}
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
}
func hasSystemMessage(messages []any) bool {
for _, m := range messages {
msg, ok := m.(map[string]any)
if ok && msg["role"] == "system" {
return true
}
}
return false
}
func extractClaudeToolNames(tools []any) []string {
out := make([]string, 0, len(tools))
for _, t := range tools {
m, ok := t.(map[string]any)
if !ok {
continue
}
if name, ok := m["name"].(string); ok && name != "" {
out = append(out, name)
}
}
return out
}
func toMessageMaps(v any) []map[string]any {
arr, ok := v.([]any)
if !ok {
return nil
}
out := make([]map[string]any, 0, len(arr))
for _, item := range arr {
if m, ok := item.(map[string]any); ok {
out = append(out, m)
}
}
return out
}
func extractMessageContent(v any) string {
switch x := v.(type) {
case string:
return x
case []any:
parts := make([]string, 0, len(x))
for _, it := range x {
parts = append(parts, fmt.Sprintf("%v", it))
}
return strings.Join(parts, "\n")
default:
return fmt.Sprintf("%v", x)
}
}
func cloneMap(in map[string]any) map[string]any {
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}

View File

@@ -0,0 +1,25 @@
package claude
import "net/http"
func writeClaudeError(w http.ResponseWriter, status int, message string) {
code := "invalid_request"
switch status {
case http.StatusUnauthorized:
code = "authentication_failed"
case http.StatusTooManyRequests:
code = "rate_limit_exceeded"
case http.StatusNotFound:
code = "not_found"
case http.StatusInternalServerError:
code = "internal_error"
}
writeJSON(w, status, map[string]any{
"error": map[string]any{
"type": "invalid_request_error",
"message": message,
"code": code,
"param": nil,
},
})
}

View File

@@ -0,0 +1,134 @@
package claude
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"ds2api/internal/auth"
"ds2api/internal/config"
claudefmt "ds2api/internal/format/claude"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
)
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
r.Header.Set("anthropic-version", "2023-06-01")
}
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeClaudeError(w, status, detail)
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeClaudeError(w, http.StatusBadRequest, "invalid json")
return
}
norm, err := normalizeClaudeRequest(h.Store, req)
if err != nil {
writeClaudeError(w, http.StatusBadRequest, err.Error())
return
}
stdReq := norm.Standard
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
return
}
requestPayload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
if err != nil {
writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
return
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
writeClaudeError(w, http.StatusInternalServerError, string(body))
return
}
if stdReq.Stream {
h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
result := sse.CollectStream(resp, stdReq.Thinking, true)
respBody := claudefmt.BuildMessageResponse(
fmt.Sprintf("msg_%d", time.Now().UnixNano()),
stdReq.ResponseModel,
norm.NormalizedMessages,
result.Thinking,
result.Text,
stdReq.ToolNames,
)
writeJSON(w, http.StatusOK, respBody)
}
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeClaudeError(w, http.StatusInternalServerError, string(body))
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache, no-transform")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
_, canFlush := w.(http.Flusher)
if !canFlush {
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
}
streamRuntime := newClaudeStreamRuntime(
w,
rc,
canFlush,
model,
messages,
thinkingEnabled,
searchEnabled,
toolNames,
)
streamRuntime.sendMessageStart()
initialType := "text"
if thinkingEnabled {
initialType = "thinking"
}
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
Context: r.Context(),
Body: resp.Body,
ThinkingEnabled: thinkingEnabled,
InitialType: initialType,
KeepAliveInterval: claudeStreamPingInterval,
IdleTimeout: claudeStreamIdleTimeout,
MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt,
}, streamengine.ConsumeHooks{
OnKeepAlive: func() {
streamRuntime.sendPing()
},
OnParsed: streamRuntime.onParsed,
OnFinalize: streamRuntime.onFinalize,
})
}

View File

@@ -0,0 +1,41 @@
package claude
import (
"net/http"
"time"
"github.com/go-chi/chi/v5"
"ds2api/internal/config"
"ds2api/internal/deepseek"
"ds2api/internal/util"
)
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
var writeJSON = util.WriteJSON
type Handler struct {
Store ConfigReader
Auth AuthResolver
DS DeepSeekCaller
}
var (
claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second
claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second
claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount
)
func RegisterRoutes(r chi.Router, h *Handler) {
r.Get("/anthropic/v1/models", h.ListModels)
r.Post("/anthropic/v1/messages", h.Messages)
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
r.Post("/v1/messages", h.Messages)
r.Post("/messages", h.Messages)
r.Post("/v1/messages/count_tokens", h.CountTokens)
r.Post("/messages/count_tokens", h.CountTokens)
}
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
}

View File

@@ -0,0 +1,51 @@
package claude
import (
"encoding/json"
"net/http"
"ds2api/internal/util"
)
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
a, err := h.Auth.Determine(r)
if err != nil {
writeClaudeError(w, http.StatusUnauthorized, err.Error())
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeClaudeError(w, http.StatusBadRequest, "invalid json")
return
}
model, _ := req["model"].(string)
messages, _ := req["messages"].([]any)
if model == "" || len(messages) == 0 {
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
return
}
inputTokens := 0
if sys, ok := req["system"].(string); ok {
inputTokens += util.EstimateTokens(sys)
}
for _, item := range messages {
msg, ok := item.(map[string]any)
if !ok {
continue
}
inputTokens += 2
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
}
if tools, ok := req["tools"].([]any); ok {
for _, t := range tools {
b, _ := json.Marshal(t)
inputTokens += util.EstimateTokens(string(b))
}
}
if inputTokens < 1 {
inputTokens = 1
}
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
}

View File

@@ -0,0 +1,143 @@
package claude
import (
"encoding/json"
"fmt"
"strings"
)
func normalizeClaudeMessages(messages []any) []any {
out := make([]any, 0, len(messages))
for _, m := range messages {
msg, ok := m.(map[string]any)
if !ok {
continue
}
copied := cloneMap(msg)
switch content := msg["content"].(type) {
case []any:
parts := make([]string, 0, len(content))
for _, block := range content {
b, ok := block.(map[string]any)
if !ok {
continue
}
typeStr, _ := b["type"].(string)
if typeStr == "text" {
if t, ok := b["text"].(string); ok {
parts = append(parts, t)
}
}
if typeStr == "tool_result" {
parts = append(parts, formatClaudeToolResultForPrompt(b))
}
}
copied["content"] = strings.Join(parts, "\n")
}
out = append(out, copied)
}
return out
}
func buildClaudeToolPrompt(tools []any) string {
parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"}
for _, t := range tools {
m, ok := t.(map[string]any)
if !ok {
continue
}
name, _ := m["name"].(string)
desc, _ := m["description"].(string)
schema, _ := json.Marshal(m["input_schema"])
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
}
parts = append(parts,
"When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
)
return strings.Join(parts, "\n\n")
}
func formatClaudeToolResultForPrompt(block map[string]any) string {
if block == nil {
return ""
}
toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
if toolCallID == "" {
toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
}
if toolCallID == "" {
toolCallID = "unknown"
}
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
if name == "" {
name = "unknown"
}
content := strings.TrimSpace(fmt.Sprintf("%v", block["content"]))
if content == "" {
content = "null"
}
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
}
func hasSystemMessage(messages []any) bool {
for _, m := range messages {
msg, ok := m.(map[string]any)
if ok && msg["role"] == "system" {
return true
}
}
return false
}
func extractClaudeToolNames(tools []any) []string {
out := make([]string, 0, len(tools))
for _, t := range tools {
m, ok := t.(map[string]any)
if !ok {
continue
}
if name, ok := m["name"].(string); ok && name != "" {
out = append(out, name)
}
}
return out
}
func toMessageMaps(v any) []map[string]any {
arr, ok := v.([]any)
if !ok {
return nil
}
out := make([]map[string]any, 0, len(arr))
for _, item := range arr {
if m, ok := item.(map[string]any); ok {
out = append(out, m)
}
}
return out
}
func extractMessageContent(v any) string {
switch x := v.(type) {
case string:
return x
case []any:
parts := make([]string, 0, len(x))
for _, it := range x {
parts = append(parts, fmt.Sprintf("%v", it))
}
return strings.Join(parts, "\n")
default:
return fmt.Sprintf("%v", x)
}
}
func cloneMap(in map[string]any) map[string]any {
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}

View File

@@ -1,308 +0,0 @@
package claude
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
)
type claudeStreamRuntime struct {
w http.ResponseWriter
rc *http.ResponseController
canFlush bool
model string
toolNames []string
messages []any
thinkingEnabled bool
searchEnabled bool
bufferToolContent bool
messageID string
thinking strings.Builder
text strings.Builder
nextBlockIndex int
thinkingBlockOpen bool
thinkingBlockIndex int
textBlockOpen bool
textBlockIndex int
ended bool
upstreamErr string
}
func newClaudeStreamRuntime(
w http.ResponseWriter,
rc *http.ResponseController,
canFlush bool,
model string,
messages []any,
thinkingEnabled bool,
searchEnabled bool,
toolNames []string,
) *claudeStreamRuntime {
return &claudeStreamRuntime{
w: w,
rc: rc,
canFlush: canFlush,
model: model,
messages: messages,
thinkingEnabled: thinkingEnabled,
searchEnabled: searchEnabled,
bufferToolContent: len(toolNames) > 0,
toolNames: toolNames,
messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
thinkingBlockIndex: -1,
textBlockIndex: -1,
}
}
func (s *claudeStreamRuntime) send(event string, v any) {
b, _ := json.Marshal(v)
_, _ = s.w.Write([]byte("event: "))
_, _ = s.w.Write([]byte(event))
_, _ = s.w.Write([]byte("\n"))
_, _ = s.w.Write([]byte("data: "))
_, _ = s.w.Write(b)
_, _ = s.w.Write([]byte("\n\n"))
if s.canFlush {
_ = s.rc.Flush()
}
}
func (s *claudeStreamRuntime) sendError(message string) {
msg := strings.TrimSpace(message)
if msg == "" {
msg = "upstream stream error"
}
s.send("error", map[string]any{
"type": "error",
"error": map[string]any{
"type": "api_error",
"message": msg,
"code": "internal_error",
"param": nil,
},
})
}
func (s *claudeStreamRuntime) sendPing() {
s.send("ping", map[string]any{"type": "ping"})
}
func (s *claudeStreamRuntime) sendMessageStart() {
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
s.send("message_start", map[string]any{
"type": "message_start",
"message": map[string]any{
"id": s.messageID,
"type": "message",
"role": "assistant",
"model": s.model,
"content": []any{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
},
})
}
func (s *claudeStreamRuntime) closeThinkingBlock() {
if !s.thinkingBlockOpen {
return
}
s.send("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": s.thinkingBlockIndex,
})
s.thinkingBlockOpen = false
s.thinkingBlockIndex = -1
}
func (s *claudeStreamRuntime) closeTextBlock() {
if !s.textBlockOpen {
return
}
s.send("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": s.textBlockIndex,
})
s.textBlockOpen = false
s.textBlockIndex = -1
}
func (s *claudeStreamRuntime) finalize(stopReason string) {
if s.ended {
return
}
s.ended = true
s.closeThinkingBlock()
s.closeTextBlock()
finalThinking := s.thinking.String()
finalText := s.text.String()
if s.bufferToolContent {
detected := util.ParseToolCalls(finalText, s.toolNames)
if len(detected) > 0 {
stopReason = "tool_use"
for i, tc := range detected {
idx := s.nextBlockIndex + i
s.send("content_block_start", map[string]any{
"type": "content_block_start",
"index": idx,
"content_block": map[string]any{
"type": "tool_use",
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
"name": tc.Name,
"input": tc.Input,
},
})
s.send("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": idx,
})
}
s.nextBlockIndex += len(detected)
} else if finalText != "" {
idx := s.nextBlockIndex
s.nextBlockIndex++
s.send("content_block_start", map[string]any{
"type": "content_block_start",
"index": idx,
"content_block": map[string]any{
"type": "text",
"text": "",
},
})
s.send("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": idx,
"delta": map[string]any{
"type": "text_delta",
"text": finalText,
},
})
s.send("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": idx,
})
}
}
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
s.send("message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": map[string]any{
"output_tokens": outputTokens,
},
})
s.send("message_stop", map[string]any{"type": "message_stop"})
}
func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
if !parsed.Parsed {
return streamengine.ParsedDecision{}
}
if parsed.ErrorMessage != "" {
s.upstreamErr = parsed.ErrorMessage
return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}
}
if parsed.Stop {
return streamengine.ParsedDecision{Stop: true}
}
contentSeen := false
for _, p := range parsed.Parts {
if p.Text == "" {
continue
}
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
continue
}
contentSeen = true
if p.Type == "thinking" {
if !s.thinkingEnabled {
continue
}
s.thinking.WriteString(p.Text)
s.closeTextBlock()
if !s.thinkingBlockOpen {
s.thinkingBlockIndex = s.nextBlockIndex
s.nextBlockIndex++
s.send("content_block_start", map[string]any{
"type": "content_block_start",
"index": s.thinkingBlockIndex,
"content_block": map[string]any{
"type": "thinking",
"thinking": "",
},
})
s.thinkingBlockOpen = true
}
s.send("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": s.thinkingBlockIndex,
"delta": map[string]any{
"type": "thinking_delta",
"thinking": p.Text,
},
})
continue
}
s.text.WriteString(p.Text)
if s.bufferToolContent {
continue
}
s.closeThinkingBlock()
if !s.textBlockOpen {
s.textBlockIndex = s.nextBlockIndex
s.nextBlockIndex++
s.send("content_block_start", map[string]any{
"type": "content_block_start",
"index": s.textBlockIndex,
"content_block": map[string]any{
"type": "text",
"text": "",
},
})
s.textBlockOpen = true
}
s.send("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": s.textBlockIndex,
"delta": map[string]any{
"type": "text_delta",
"text": p.Text,
},
})
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) {
if string(reason) == "upstream_error" {
s.sendError(s.upstreamErr)
return
}
if scannerErr != nil {
s.sendError(scannerErr.Error())
return
}
s.finalize("end_turn")
}

View File

@@ -0,0 +1,146 @@
package claude
import (
"fmt"
"net/http"
"strings"
"time"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
)
type claudeStreamRuntime struct {
w http.ResponseWriter
rc *http.ResponseController
canFlush bool
model string
toolNames []string
messages []any
thinkingEnabled bool
searchEnabled bool
bufferToolContent bool
messageID string
thinking strings.Builder
text strings.Builder
nextBlockIndex int
thinkingBlockOpen bool
thinkingBlockIndex int
textBlockOpen bool
textBlockIndex int
ended bool
upstreamErr string
}
func newClaudeStreamRuntime(
w http.ResponseWriter,
rc *http.ResponseController,
canFlush bool,
model string,
messages []any,
thinkingEnabled bool,
searchEnabled bool,
toolNames []string,
) *claudeStreamRuntime {
return &claudeStreamRuntime{
w: w,
rc: rc,
canFlush: canFlush,
model: model,
messages: messages,
thinkingEnabled: thinkingEnabled,
searchEnabled: searchEnabled,
bufferToolContent: len(toolNames) > 0,
toolNames: toolNames,
messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
thinkingBlockIndex: -1,
textBlockIndex: -1,
}
}
func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
if !parsed.Parsed {
return streamengine.ParsedDecision{}
}
if parsed.ErrorMessage != "" {
s.upstreamErr = parsed.ErrorMessage
return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}
}
if parsed.Stop {
return streamengine.ParsedDecision{Stop: true}
}
contentSeen := false
for _, p := range parsed.Parts {
if p.Text == "" {
continue
}
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
continue
}
contentSeen = true
if p.Type == "thinking" {
if !s.thinkingEnabled {
continue
}
s.thinking.WriteString(p.Text)
s.closeTextBlock()
if !s.thinkingBlockOpen {
s.thinkingBlockIndex = s.nextBlockIndex
s.nextBlockIndex++
s.send("content_block_start", map[string]any{
"type": "content_block_start",
"index": s.thinkingBlockIndex,
"content_block": map[string]any{
"type": "thinking",
"thinking": "",
},
})
s.thinkingBlockOpen = true
}
s.send("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": s.thinkingBlockIndex,
"delta": map[string]any{
"type": "thinking_delta",
"thinking": p.Text,
},
})
continue
}
s.text.WriteString(p.Text)
if s.bufferToolContent {
continue
}
s.closeThinkingBlock()
if !s.textBlockOpen {
s.textBlockIndex = s.nextBlockIndex
s.nextBlockIndex++
s.send("content_block_start", map[string]any{
"type": "content_block_start",
"index": s.textBlockIndex,
"content_block": map[string]any{
"type": "text",
"text": "",
},
})
s.textBlockOpen = true
}
s.send("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": s.textBlockIndex,
"delta": map[string]any{
"type": "text_delta",
"text": p.Text,
},
})
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}
}

View File

@@ -0,0 +1,59 @@
package claude
import (
"encoding/json"
"fmt"
"strings"
"ds2api/internal/util"
)
func (s *claudeStreamRuntime) send(event string, v any) {
b, _ := json.Marshal(v)
_, _ = s.w.Write([]byte("event: "))
_, _ = s.w.Write([]byte(event))
_, _ = s.w.Write([]byte("\n"))
_, _ = s.w.Write([]byte("data: "))
_, _ = s.w.Write(b)
_, _ = s.w.Write([]byte("\n\n"))
if s.canFlush {
_ = s.rc.Flush()
}
}
func (s *claudeStreamRuntime) sendError(message string) {
msg := strings.TrimSpace(message)
if msg == "" {
msg = "upstream stream error"
}
s.send("error", map[string]any{
"type": "error",
"error": map[string]any{
"type": "api_error",
"message": msg,
"code": "internal_error",
"param": nil,
},
})
}
func (s *claudeStreamRuntime) sendPing() {
s.send("ping", map[string]any{"type": "ping"})
}
func (s *claudeStreamRuntime) sendMessageStart() {
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
s.send("message_start", map[string]any{
"type": "message_start",
"message": map[string]any{
"id": s.messageID,
"type": "message",
"role": "assistant",
"model": s.model,
"content": []any{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
},
})
}

View File

@@ -0,0 +1,119 @@
package claude
import (
"fmt"
"time"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
)
func (s *claudeStreamRuntime) closeThinkingBlock() {
if !s.thinkingBlockOpen {
return
}
s.send("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": s.thinkingBlockIndex,
})
s.thinkingBlockOpen = false
s.thinkingBlockIndex = -1
}
func (s *claudeStreamRuntime) closeTextBlock() {
if !s.textBlockOpen {
return
}
s.send("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": s.textBlockIndex,
})
s.textBlockOpen = false
s.textBlockIndex = -1
}
func (s *claudeStreamRuntime) finalize(stopReason string) {
if s.ended {
return
}
s.ended = true
s.closeThinkingBlock()
s.closeTextBlock()
finalThinking := s.thinking.String()
finalText := s.text.String()
if s.bufferToolContent {
detected := util.ParseToolCalls(finalText, s.toolNames)
if len(detected) > 0 {
stopReason = "tool_use"
for i, tc := range detected {
idx := s.nextBlockIndex + i
s.send("content_block_start", map[string]any{
"type": "content_block_start",
"index": idx,
"content_block": map[string]any{
"type": "tool_use",
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
"name": tc.Name,
"input": tc.Input,
},
})
s.send("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": idx,
})
}
s.nextBlockIndex += len(detected)
} else if finalText != "" {
idx := s.nextBlockIndex
s.nextBlockIndex++
s.send("content_block_start", map[string]any{
"type": "content_block_start",
"index": idx,
"content_block": map[string]any{
"type": "text",
"text": "",
},
})
s.send("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": idx,
"delta": map[string]any{
"type": "text_delta",
"text": finalText,
},
})
s.send("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": idx,
})
}
}
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
s.send("message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": map[string]any{
"output_tokens": outputTokens,
},
})
s.send("message_stop", map[string]any{"type": "message_stop"})
}
func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) {
if string(reason) == "upstream_error" {
s.sendError(s.upstreamErr)
return
}
if scannerErr != nil {
s.sendError(scannerErr.Error())
return
}
s.finalize("end_turn")
}

View File

@@ -1,313 +0,0 @@
package gemini
import (
"encoding/json"
"fmt"
"strings"
"ds2api/internal/adapter/openai"
"ds2api/internal/config"
"ds2api/internal/util"
)
func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) {
requestedModel := strings.TrimSpace(routeModel)
if requestedModel == "" {
return util.StandardRequest{}, fmt.Errorf("model is required in request path")
}
resolvedModel, ok := config.ResolveModel(store, requestedModel)
if !ok {
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel)
}
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
messagesRaw := geminiMessagesFromRequest(req)
if len(messagesRaw) == 0 {
return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.")
}
toolsRaw := convertGeminiTools(req["tools"])
finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "")
passThrough := collectGeminiPassThrough(req)
return util.StandardRequest{
Surface: "google_gemini",
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
ResponseModel: requestedModel,
Messages: messagesRaw,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: stream,
Thinking: thinkingEnabled,
Search: searchEnabled,
PassThrough: passThrough,
}, nil
}
func geminiMessagesFromRequest(req map[string]any) []any {
out := make([]any, 0, 8)
if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
out = append(out, map[string]any{
"role": "system",
"content": sys,
})
}
contents, _ := req["contents"].([]any)
for _, item := range contents {
content, ok := item.(map[string]any)
if !ok {
continue
}
role := mapGeminiRole(content["role"])
if role == "" {
role = "user"
}
parts, _ := content["parts"].([]any)
if len(parts) == 0 {
if text := strings.TrimSpace(asString(content["text"])); text != "" {
out = append(out, map[string]any{
"role": role,
"content": text,
})
}
continue
}
textParts := make([]string, 0, len(parts))
flushText := func() {
if len(textParts) == 0 {
return
}
out = append(out, map[string]any{
"role": role,
"content": strings.Join(textParts, "\n"),
})
textParts = textParts[:0]
}
for _, rawPart := range parts {
part, ok := rawPart.(map[string]any)
if !ok {
continue
}
if text := strings.TrimSpace(asString(part["text"])); text != "" {
textParts = append(textParts, text)
continue
}
if fnCall, ok := part["functionCall"].(map[string]any); ok {
flushText()
if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
callID := strings.TrimSpace(asString(fnCall["id"]))
if callID == "" {
callID = "call_gemini"
}
out = append(out, map[string]any{
"role": "assistant",
"tool_calls": []any{
map[string]any{
"id": callID,
"type": "function",
"function": map[string]any{
"name": name,
"arguments": stringifyJSON(fnCall["args"]),
},
},
},
})
}
continue
}
if fnResp, ok := part["functionResponse"].(map[string]any); ok {
flushText()
name := strings.TrimSpace(asString(fnResp["name"]))
callID := strings.TrimSpace(asString(fnResp["id"]))
if callID == "" {
callID = strings.TrimSpace(asString(fnResp["callId"]))
}
if callID == "" {
callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
}
if callID == "" {
callID = "call_gemini"
}
content := fnResp["response"]
if content == nil {
content = fnResp["output"]
}
if content == nil {
content = ""
}
msg := map[string]any{
"role": "tool",
"tool_call_id": callID,
"content": content,
}
if name != "" {
msg["name"] = name
}
out = append(out, msg)
}
}
flushText()
}
return out
}
func normalizeGeminiSystemInstruction(raw any) string {
switch v := raw.(type) {
case string:
return strings.TrimSpace(v)
case map[string]any:
if parts, ok := v["parts"].([]any); ok {
texts := make([]string, 0, len(parts))
for _, item := range parts {
part, ok := item.(map[string]any)
if !ok {
continue
}
if text := strings.TrimSpace(asString(part["text"])); text != "" {
texts = append(texts, text)
}
}
return strings.Join(texts, "\n")
}
if text := strings.TrimSpace(asString(v["text"])); text != "" {
return text
}
}
return ""
}
func mapGeminiRole(v any) string {
switch strings.ToLower(strings.TrimSpace(asString(v))) {
case "user":
return "user"
case "model", "assistant":
return "assistant"
case "system":
return "system"
default:
return ""
}
}
func convertGeminiTools(raw any) []any {
tools, _ := raw.([]any)
if len(tools) == 0 {
return nil
}
out := make([]any, 0, len(tools))
for _, item := range tools {
tool, ok := item.(map[string]any)
if !ok {
continue
}
if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 {
for _, declRaw := range fnDecls {
decl, ok := declRaw.(map[string]any)
if !ok {
continue
}
name := strings.TrimSpace(asString(decl["name"]))
if name == "" {
continue
}
function := map[string]any{
"name": name,
}
if desc := strings.TrimSpace(asString(decl["description"])); desc != "" {
function["description"] = desc
}
if params, ok := decl["parameters"].(map[string]any); ok {
function["parameters"] = params
}
out = append(out, map[string]any{
"type": "function",
"function": function,
})
}
continue
}
// OpenAI-style passthrough fallback.
if _, ok := tool["function"].(map[string]any); ok {
out = append(out, tool)
continue
}
// Loose fallback for flattened function schema objects.
name := strings.TrimSpace(asString(tool["name"]))
if name == "" {
continue
}
fn := map[string]any{"name": name}
if desc := strings.TrimSpace(asString(tool["description"])); desc != "" {
fn["description"] = desc
}
if params, ok := tool["parameters"].(map[string]any); ok {
fn["parameters"] = params
}
out = append(out, map[string]any{
"type": "function",
"function": fn,
})
}
if len(out) == 0 {
return nil
}
return out
}
func collectGeminiPassThrough(req map[string]any) map[string]any {
cfg, _ := req["generationConfig"].(map[string]any)
if len(cfg) == 0 {
return nil
}
out := map[string]any{}
if v, ok := cfg["temperature"]; ok {
out["temperature"] = v
}
if v, ok := cfg["topP"]; ok {
out["top_p"] = v
}
if v, ok := cfg["maxOutputTokens"]; ok {
out["max_tokens"] = v
}
if v, ok := cfg["stopSequences"]; ok {
out["stop"] = v
}
if len(out) == 0 {
return nil
}
return out
}
func asString(v any) string {
s, _ := v.(string)
return s
}
func stringifyJSON(v any) string {
switch x := v.(type) {
case nil:
return "{}"
case string:
s := strings.TrimSpace(x)
if s == "" {
return "{}"
}
return s
default:
b, err := json.Marshal(x)
if err != nil || len(b) == 0 {
return "{}"
}
return string(b)
}
}

View File

@@ -0,0 +1,153 @@
package gemini
import "strings"
func geminiMessagesFromRequest(req map[string]any) []any {
out := make([]any, 0, 8)
if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
out = append(out, map[string]any{
"role": "system",
"content": sys,
})
}
contents, _ := req["contents"].([]any)
for _, item := range contents {
content, ok := item.(map[string]any)
if !ok {
continue
}
role := mapGeminiRole(content["role"])
if role == "" {
role = "user"
}
parts, _ := content["parts"].([]any)
if len(parts) == 0 {
if text := strings.TrimSpace(asString(content["text"])); text != "" {
out = append(out, map[string]any{
"role": role,
"content": text,
})
}
continue
}
textParts := make([]string, 0, len(parts))
flushText := func() {
if len(textParts) == 0 {
return
}
out = append(out, map[string]any{
"role": role,
"content": strings.Join(textParts, "\n"),
})
textParts = textParts[:0]
}
for _, rawPart := range parts {
part, ok := rawPart.(map[string]any)
if !ok {
continue
}
if text := strings.TrimSpace(asString(part["text"])); text != "" {
textParts = append(textParts, text)
continue
}
if fnCall, ok := part["functionCall"].(map[string]any); ok {
flushText()
if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
callID := strings.TrimSpace(asString(fnCall["id"]))
if callID == "" {
callID = "call_gemini"
}
out = append(out, map[string]any{
"role": "assistant",
"tool_calls": []any{
map[string]any{
"id": callID,
"type": "function",
"function": map[string]any{
"name": name,
"arguments": stringifyJSON(fnCall["args"]),
},
},
},
})
}
continue
}
if fnResp, ok := part["functionResponse"].(map[string]any); ok {
flushText()
name := strings.TrimSpace(asString(fnResp["name"]))
callID := strings.TrimSpace(asString(fnResp["id"]))
if callID == "" {
callID = strings.TrimSpace(asString(fnResp["callId"]))
}
if callID == "" {
callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
}
if callID == "" {
callID = "call_gemini"
}
content := fnResp["response"]
if content == nil {
content = fnResp["output"]
}
if content == nil {
content = ""
}
msg := map[string]any{
"role": "tool",
"tool_call_id": callID,
"content": content,
}
if name != "" {
msg["name"] = name
}
out = append(out, msg)
}
}
flushText()
}
return out
}
func normalizeGeminiSystemInstruction(raw any) string {
switch v := raw.(type) {
case string:
return strings.TrimSpace(v)
case map[string]any:
if parts, ok := v["parts"].([]any); ok {
texts := make([]string, 0, len(parts))
for _, item := range parts {
part, ok := item.(map[string]any)
if !ok {
continue
}
if text := strings.TrimSpace(asString(part["text"])); text != "" {
texts = append(texts, text)
}
}
return strings.Join(texts, "\n")
}
if text := strings.TrimSpace(asString(v["text"])); text != "" {
return text
}
}
return ""
}
func mapGeminiRole(v any) string {
switch strings.ToLower(strings.TrimSpace(asString(v))) {
case "user":
return "user"
case "model", "assistant":
return "assistant"
case "system":
return "system"
default:
return ""
}
}

View File

@@ -0,0 +1,54 @@
package gemini
import (
"encoding/json"
"strings"
)
func collectGeminiPassThrough(req map[string]any) map[string]any {
cfg, _ := req["generationConfig"].(map[string]any)
if len(cfg) == 0 {
return nil
}
out := map[string]any{}
if v, ok := cfg["temperature"]; ok {
out["temperature"] = v
}
if v, ok := cfg["topP"]; ok {
out["top_p"] = v
}
if v, ok := cfg["maxOutputTokens"]; ok {
out["max_tokens"] = v
}
if v, ok := cfg["stopSequences"]; ok {
out["stop"] = v
}
if len(out) == 0 {
return nil
}
return out
}
func asString(v any) string {
s, _ := v.(string)
return s
}
func stringifyJSON(v any) string {
switch x := v.(type) {
case nil:
return "{}"
case string:
s := strings.TrimSpace(x)
if s == "" {
return "{}"
}
return s
default:
b, err := json.Marshal(x)
if err != nil || len(b) == 0 {
return "{}"
}
return string(b)
}
}

View File

@@ -0,0 +1,46 @@
package gemini
import (
"fmt"
"strings"
"ds2api/internal/adapter/openai"
"ds2api/internal/config"
"ds2api/internal/util"
)
func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) {
requestedModel := strings.TrimSpace(routeModel)
if requestedModel == "" {
return util.StandardRequest{}, fmt.Errorf("model is required in request path")
}
resolvedModel, ok := config.ResolveModel(store, requestedModel)
if !ok {
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel)
}
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
messagesRaw := geminiMessagesFromRequest(req)
if len(messagesRaw) == 0 {
return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.")
}
toolsRaw := convertGeminiTools(req["tools"])
finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "")
passThrough := collectGeminiPassThrough(req)
return util.StandardRequest{
Surface: "google_gemini",
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
ResponseModel: requestedModel,
Messages: messagesRaw,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: stream,
Thinking: thinkingEnabled,
Search: searchEnabled,
PassThrough: passThrough,
}, nil
}

View File

@@ -0,0 +1,71 @@
package gemini
import "strings"
func convertGeminiTools(raw any) []any {
tools, _ := raw.([]any)
if len(tools) == 0 {
return nil
}
out := make([]any, 0, len(tools))
for _, item := range tools {
tool, ok := item.(map[string]any)
if !ok {
continue
}
if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 {
for _, declRaw := range fnDecls {
decl, ok := declRaw.(map[string]any)
if !ok {
continue
}
name := strings.TrimSpace(asString(decl["name"]))
if name == "" {
continue
}
function := map[string]any{
"name": name,
}
if desc := strings.TrimSpace(asString(decl["description"])); desc != "" {
function["description"] = desc
}
if params, ok := decl["parameters"].(map[string]any); ok {
function["parameters"] = params
}
out = append(out, map[string]any{
"type": "function",
"function": function,
})
}
continue
}
// OpenAI-style passthrough fallback.
if _, ok := tool["function"].(map[string]any); ok {
out = append(out, tool)
continue
}
// Loose fallback for flattened function schema objects.
name := strings.TrimSpace(asString(tool["name"]))
if name == "" {
continue
}
fn := map[string]any{"name": name}
if desc := strings.TrimSpace(asString(tool["description"])); desc != "" {
fn["description"] = desc
}
if params, ok := tool["parameters"].(map[string]any); ok {
fn["parameters"] = params
}
out = append(out, map[string]any{
"type": "function",
"function": fn,
})
}
if len(out) == 0 {
return nil
}
return out
}

View File

@@ -1,348 +0,0 @@
package gemini
import (
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"ds2api/internal/auth"
"ds2api/internal/deepseek"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
)
var writeJSON = util.WriteJSON
type Handler struct {
Store ConfigReader
Auth AuthResolver
DS DeepSeekCaller
}
func RegisterRoutes(r chi.Router, h *Handler) {
r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent)
r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent)
r.Post("/v1/models/{model}:generateContent", h.GenerateContent)
r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent)
}
func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) {
h.handleGenerateContent(w, r, false)
}
func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) {
h.handleGenerateContent(w, r, true)
}
func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeGeminiError(w, status, detail)
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeGeminiError(w, http.StatusBadRequest, "invalid json")
return
}
routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
if err != nil {
writeGeminiError(w, http.StatusBadRequest, err.Error())
return
}
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
if a.UseConfigToken {
writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
} else {
writeGeminiError(w, http.StatusUnauthorized, "Invalid token.")
}
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.")
return
}
if stream {
h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
return
}
result := sse.CollectStream(resp, thinkingEnabled, true)
writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
}
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache, no-transform")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
_, canFlush := w.(http.Flusher)
runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
initialType := "text"
if thinkingEnabled {
initialType = "thinking"
}
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
Context: r.Context(),
Body: resp.Body,
ThinkingEnabled: thinkingEnabled,
InitialType: initialType,
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
}, streamengine.ConsumeHooks{
OnParsed: runtime.onParsed,
OnFinalize: func(_ streamengine.StopReason, _ error) {
runtime.finalize()
},
})
}
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
return map[string]any{
"candidates": []map[string]any{
{
"index": 0,
"content": map[string]any{
"role": "model",
"parts": parts,
},
"finishReason": "STOP",
},
},
"modelVersion": model,
"usageMetadata": usage,
}
}
func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
return map[string]any{
"promptTokenCount": promptTokens,
"candidatesTokenCount": reasoningTokens + completionTokens,
"totalTokenCount": promptTokens + reasoningTokens + completionTokens,
}
}
func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any {
detected := util.ParseToolCalls(finalText, toolNames)
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
detected = util.ParseToolCalls(finalThinking, toolNames)
}
if len(detected) > 0 {
parts := make([]map[string]any, 0, len(detected))
for _, tc := range detected {
parts = append(parts, map[string]any{
"functionCall": map[string]any{
"name": tc.Name,
"args": tc.Input,
},
})
}
return parts
}
text := finalText
if strings.TrimSpace(text) == "" {
text = finalThinking
}
return []map[string]any{{"text": text}}
}
type geminiStreamRuntime struct {
w http.ResponseWriter
rc *http.ResponseController
canFlush bool
model string
finalPrompt string
thinkingEnabled bool
searchEnabled bool
bufferContent bool
toolNames []string
thinking strings.Builder
text strings.Builder
}
func newGeminiStreamRuntime(
w http.ResponseWriter,
rc *http.ResponseController,
canFlush bool,
model string,
finalPrompt string,
thinkingEnabled bool,
searchEnabled bool,
toolNames []string,
) *geminiStreamRuntime {
return &geminiStreamRuntime{
w: w,
rc: rc,
canFlush: canFlush,
model: model,
finalPrompt: finalPrompt,
thinkingEnabled: thinkingEnabled,
searchEnabled: searchEnabled,
bufferContent: len(toolNames) > 0,
toolNames: toolNames,
}
}
func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
b, _ := json.Marshal(payload)
_, _ = s.w.Write([]byte("data: "))
_, _ = s.w.Write(b)
_, _ = s.w.Write([]byte("\n\n"))
if s.canFlush {
_ = s.rc.Flush()
}
}
func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
if !parsed.Parsed {
return streamengine.ParsedDecision{}
}
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
return streamengine.ParsedDecision{Stop: true}
}
contentSeen := false
for _, p := range parsed.Parts {
if p.Text == "" {
continue
}
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
continue
}
contentSeen = true
if p.Type == "thinking" {
if s.thinkingEnabled {
s.thinking.WriteString(p.Text)
}
continue
}
s.text.WriteString(p.Text)
if s.bufferContent {
continue
}
s.sendChunk(map[string]any{
"candidates": []map[string]any{
{
"index": 0,
"content": map[string]any{
"role": "model",
"parts": []map[string]any{{"text": p.Text}},
},
},
},
"modelVersion": s.model,
})
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
func (s *geminiStreamRuntime) finalize() {
finalThinking := s.thinking.String()
finalText := s.text.String()
if s.bufferContent {
parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
s.sendChunk(map[string]any{
"candidates": []map[string]any{
{
"index": 0,
"content": map[string]any{
"role": "model",
"parts": parts,
},
},
},
"modelVersion": s.model,
})
}
s.sendChunk(map[string]any{
"candidates": []map[string]any{
{
"index": 0,
"finishReason": "STOP",
},
},
"modelVersion": s.model,
"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
})
}
func writeGeminiError(w http.ResponseWriter, status int, message string) {
errorStatus := "INVALID_ARGUMENT"
switch status {
case http.StatusUnauthorized:
errorStatus = "UNAUTHENTICATED"
case http.StatusForbidden:
errorStatus = "PERMISSION_DENIED"
case http.StatusTooManyRequests:
errorStatus = "RESOURCE_EXHAUSTED"
case http.StatusNotFound:
errorStatus = "NOT_FOUND"
default:
if status >= 500 {
errorStatus = "INTERNAL"
}
}
writeJSON(w, status, map[string]any{
"error": map[string]any{
"code": status,
"message": message,
"status": errorStatus,
},
})
}

View File

@@ -0,0 +1,28 @@
package gemini
import "net/http"
func writeGeminiError(w http.ResponseWriter, status int, message string) {
errorStatus := "INVALID_ARGUMENT"
switch status {
case http.StatusUnauthorized:
errorStatus = "UNAUTHENTICATED"
case http.StatusForbidden:
errorStatus = "PERMISSION_DENIED"
case http.StatusTooManyRequests:
errorStatus = "RESOURCE_EXHAUSTED"
case http.StatusNotFound:
errorStatus = "NOT_FOUND"
default:
if status >= 500 {
errorStatus = "INTERNAL"
}
}
writeJSON(w, status, map[string]any{
"error": map[string]any{
"code": status,
"message": message,
"status": errorStatus,
},
})
}

View File

@@ -0,0 +1,135 @@
package gemini
import (
"encoding/json"
"io"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
"ds2api/internal/auth"
"ds2api/internal/sse"
"ds2api/internal/util"
)
func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeGeminiError(w, status, detail)
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeGeminiError(w, http.StatusBadRequest, "invalid json")
return
}
routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
if err != nil {
writeGeminiError(w, http.StatusBadRequest, err.Error())
return
}
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
if a.UseConfigToken {
writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
} else {
writeGeminiError(w, http.StatusUnauthorized, "Invalid token.")
}
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.")
return
}
if stream {
h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
return
}
result := sse.CollectStream(resp, thinkingEnabled, true)
writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
}
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
return map[string]any{
"candidates": []map[string]any{
{
"index": 0,
"content": map[string]any{
"role": "model",
"parts": parts,
},
"finishReason": "STOP",
},
},
"modelVersion": model,
"usageMetadata": usage,
}
}
func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
return map[string]any{
"promptTokenCount": promptTokens,
"candidatesTokenCount": reasoningTokens + completionTokens,
"totalTokenCount": promptTokens + reasoningTokens + completionTokens,
}
}
func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any {
detected := util.ParseToolCalls(finalText, toolNames)
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
detected = util.ParseToolCalls(finalThinking, toolNames)
}
if len(detected) > 0 {
parts := make([]map[string]any, 0, len(detected))
for _, tc := range detected {
parts = append(parts, map[string]any{
"functionCall": map[string]any{
"name": tc.Name,
"args": tc.Input,
},
})
}
return parts
}
text := finalText
if strings.TrimSpace(text) == "" {
text = finalThinking
}
return []map[string]any{{"text": text}}
}

View File

@@ -0,0 +1,32 @@
package gemini
import (
"net/http"
"github.com/go-chi/chi/v5"
"ds2api/internal/util"
)
var writeJSON = util.WriteJSON
type Handler struct {
Store ConfigReader
Auth AuthResolver
DS DeepSeekCaller
}
func RegisterRoutes(r chi.Router, h *Handler) {
r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent)
r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent)
r.Post("/v1/models/{model}:generateContent", h.GenerateContent)
r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent)
}
func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) {
h.handleGenerateContent(w, r, false)
}
func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) {
h.handleGenerateContent(w, r, true)
}

View File

@@ -0,0 +1,175 @@
package gemini
import (
"encoding/json"
"io"
"net/http"
"strings"
"time"
"ds2api/internal/deepseek"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
)
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache, no-transform")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
_, canFlush := w.(http.Flusher)
runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
initialType := "text"
if thinkingEnabled {
initialType = "thinking"
}
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
Context: r.Context(),
Body: resp.Body,
ThinkingEnabled: thinkingEnabled,
InitialType: initialType,
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
}, streamengine.ConsumeHooks{
OnParsed: runtime.onParsed,
OnFinalize: func(_ streamengine.StopReason, _ error) {
runtime.finalize()
},
})
}
type geminiStreamRuntime struct {
w http.ResponseWriter
rc *http.ResponseController
canFlush bool
model string
finalPrompt string
thinkingEnabled bool
searchEnabled bool
bufferContent bool
toolNames []string
thinking strings.Builder
text strings.Builder
}
func newGeminiStreamRuntime(
w http.ResponseWriter,
rc *http.ResponseController,
canFlush bool,
model string,
finalPrompt string,
thinkingEnabled bool,
searchEnabled bool,
toolNames []string,
) *geminiStreamRuntime {
return &geminiStreamRuntime{
w: w,
rc: rc,
canFlush: canFlush,
model: model,
finalPrompt: finalPrompt,
thinkingEnabled: thinkingEnabled,
searchEnabled: searchEnabled,
bufferContent: len(toolNames) > 0,
toolNames: toolNames,
}
}
func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
b, _ := json.Marshal(payload)
_, _ = s.w.Write([]byte("data: "))
_, _ = s.w.Write(b)
_, _ = s.w.Write([]byte("\n\n"))
if s.canFlush {
_ = s.rc.Flush()
}
}
func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
if !parsed.Parsed {
return streamengine.ParsedDecision{}
}
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
return streamengine.ParsedDecision{Stop: true}
}
contentSeen := false
for _, p := range parsed.Parts {
if p.Text == "" {
continue
}
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
continue
}
contentSeen = true
if p.Type == "thinking" {
if s.thinkingEnabled {
s.thinking.WriteString(p.Text)
}
continue
}
s.text.WriteString(p.Text)
if s.bufferContent {
continue
}
s.sendChunk(map[string]any{
"candidates": []map[string]any{
{
"index": 0,
"content": map[string]any{
"role": "model",
"parts": []map[string]any{{"text": p.Text}},
},
},
},
"modelVersion": s.model,
})
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
func (s *geminiStreamRuntime) finalize() {
finalThinking := s.thinking.String()
finalText := s.text.String()
if s.bufferContent {
parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
s.sendChunk(map[string]any{
"candidates": []map[string]any{
{
"index": 0,
"content": map[string]any{
"role": "model",
"parts": parts,
},
},
},
"modelVersion": s.model,
})
}
s.sendChunk(map[string]any{
"candidates": []map[string]any{
{
"index": 0,
"finishReason": "STOP",
},
},
"modelVersion": s.model,
"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
})
}

View File

@@ -32,4 +32,3 @@ func TestWriteOpenAIErrorIncludesUnifiedFields(t *testing.T) {
t.Fatal("expected param field")
}
}

View File

@@ -1,386 +0,0 @@
package openai
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/deepseek"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
)
// writeJSON is a package-internal alias kept to avoid mass-renaming across
// every call-site in this file. It delegates to the shared util version.
var writeJSON = util.WriteJSON
type Handler struct {
Store ConfigReader
Auth AuthResolver
DS DeepSeekCaller
leaseMu sync.Mutex
streamLeases map[string]streamLease
responsesMu sync.Mutex
responses *responseStore
}
type streamLease struct {
Auth *auth.RequestAuth
ExpiresAt time.Time
}
func RegisterRoutes(r chi.Router, h *Handler) {
r.Get("/v1/models", h.ListModels)
r.Get("/v1/models/{model_id}", h.GetModel)
r.Post("/v1/chat/completions", h.ChatCompletions)
r.Post("/v1/responses", h.Responses)
r.Get("/v1/responses/{response_id}", h.GetResponseByID)
r.Post("/v1/embeddings", h.Embeddings)
}
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
}
func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) {
modelID := strings.TrimSpace(chi.URLParam(r, "model_id"))
model, ok := config.OpenAIModelByID(h.Store, modelID)
if !ok {
writeOpenAIError(w, http.StatusNotFound, "Model not found.")
return
}
writeJSON(w, http.StatusOK, model)
}
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
if isVercelStreamReleaseRequest(r) {
h.handleVercelStreamRelease(w, r)
return
}
if isVercelStreamPrepareRequest(r) {
h.handleVercelStreamPrepare(w, r)
return
}
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeOpenAIError(w, status, detail)
return
}
defer h.Auth.Release(a)
r = r.WithContext(auth.WithAuth(r.Context(), a))
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
}
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
if a.UseConfigToken {
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
} else {
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
}
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
return
}
if stdReq.Stream {
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
writeOpenAIError(w, resp.StatusCode, string(body))
return
}
_ = ctx
result := sse.CollectStream(resp, thinkingEnabled, true)
finalThinking := result.Thinking
finalText := result.Text
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
writeJSON(w, http.StatusOK, respBody)
}
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeOpenAIError(w, resp.StatusCode, string(body))
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache, no-transform")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
_, canFlush := w.(http.Flusher)
if !canFlush {
config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
}
created := time.Now().Unix()
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
initialType := "text"
if thinkingEnabled {
initialType = "thinking"
}
streamRuntime := newChatStreamRuntime(
w,
rc,
canFlush,
completionID,
created,
model,
finalPrompt,
thinkingEnabled,
searchEnabled,
toolNames,
bufferToolContent,
emitEarlyToolDeltas,
)
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
Context: r.Context(),
Body: resp.Body,
ThinkingEnabled: thinkingEnabled,
InitialType: initialType,
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
}, streamengine.ConsumeHooks{
OnKeepAlive: func() {
streamRuntime.sendKeepAlive()
},
OnParsed: streamRuntime.onParsed,
OnFinalize: func(reason streamengine.StopReason, _ error) {
if string(reason) == "content_filter" {
streamRuntime.finalize("content_filter")
return
}
streamRuntime.finalize("stop")
},
})
}
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
toolSchemas := make([]string, 0, len(tools))
names := make([]string, 0, len(tools))
for _, t := range tools {
tool, ok := t.(map[string]any)
if !ok {
continue
}
fn, _ := tool["function"].(map[string]any)
if len(fn) == 0 {
fn = tool
}
name, _ := fn["name"].(string)
desc, _ := fn["description"].(string)
schema, _ := fn["parameters"].(map[string]any)
if name == "" {
name = "unknown"
}
names = append(names, name)
if desc == "" {
desc = "No description available"
}
b, _ := json.Marshal(schema)
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
}
if len(toolSchemas) == 0 {
return messages, names
}
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
for i := range messages {
if messages[i]["role"] == "system" {
old, _ := messages[i]["content"].(string)
messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
return messages, names
}
}
messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
return messages, names
}
func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
if len(deltas) == 0 {
return nil
}
out := make([]map[string]any, 0, len(deltas))
for _, d := range deltas {
if d.Name == "" && d.Arguments == "" {
continue
}
callID, ok := ids[d.Index]
if !ok || callID == "" {
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
ids[d.Index] = callID
}
item := map[string]any{
"index": d.Index,
"id": callID,
"type": "function",
}
fn := map[string]any{}
if d.Name != "" {
fn["name"] = d.Name
}
if d.Arguments != "" {
fn["arguments"] = d.Arguments
}
if len(fn) > 0 {
item["function"] = fn
}
out = append(out, item)
}
return out
}
func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
if len(calls) == 0 {
return nil
}
out := make([]map[string]any, 0, len(calls))
for i, c := range calls {
callID := ""
if ids != nil {
callID = strings.TrimSpace(ids[i])
}
if callID == "" {
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if ids != nil {
ids[i] = callID
}
}
args, _ := json.Marshal(c.Input)
out = append(out, map[string]any{
"index": i,
"id": callID,
"type": "function",
"function": map[string]any{
"name": c.Name,
"arguments": string(args),
},
})
}
return out
}
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, map[string]any{
"error": map[string]any{
"message": message,
"type": openAIErrorType(status),
"code": openAIErrorCode(status),
"param": nil,
},
})
}
func openAIErrorType(status int) string {
switch status {
case http.StatusBadRequest:
return "invalid_request_error"
case http.StatusUnauthorized:
return "authentication_error"
case http.StatusForbidden:
return "permission_error"
case http.StatusTooManyRequests:
return "rate_limit_error"
case http.StatusServiceUnavailable:
return "service_unavailable_error"
default:
if status >= 500 {
return "api_error"
}
return "invalid_request_error"
}
}
func openAIErrorCode(status int) string {
switch status {
case http.StatusBadRequest:
return "invalid_request"
case http.StatusUnauthorized:
return "authentication_failed"
case http.StatusForbidden:
return "forbidden"
case http.StatusTooManyRequests:
return "rate_limit_exceeded"
case http.StatusNotFound:
return "not_found"
case http.StatusServiceUnavailable:
return "service_unavailable"
default:
if status >= 500 {
return "internal_error"
}
return "invalid_request"
}
}
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
for k, v := range collectOpenAIChatPassThrough(req) {
payload[k] = v
}
}
func (h *Handler) toolcallFeatureMatchEnabled() bool {
if h == nil || h.Store == nil {
return true
}
mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode()))
return mode == "" || mode == "feature_match"
}
func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
if h == nil || h.Store == nil {
return true
}
level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence()))
return level == "" || level == "high"
}

View File

@@ -0,0 +1,156 @@
package openai
import (
"context"
"encoding/json"
"io"
"net/http"
"time"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/deepseek"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
)
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
if isVercelStreamReleaseRequest(r) {
h.handleVercelStreamRelease(w, r)
return
}
if isVercelStreamPrepareRequest(r) {
h.handleVercelStreamPrepare(w, r)
return
}
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeOpenAIError(w, status, detail)
return
}
defer h.Auth.Release(a)
r = r.WithContext(auth.WithAuth(r.Context(), a))
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
}
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
if a.UseConfigToken {
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
} else {
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
}
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
return
}
if stdReq.Stream {
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
writeOpenAIError(w, resp.StatusCode, string(body))
return
}
_ = ctx
result := sse.CollectStream(resp, thinkingEnabled, true)
finalThinking := result.Thinking
finalText := result.Text
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
writeJSON(w, http.StatusOK, respBody)
}
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeOpenAIError(w, resp.StatusCode, string(body))
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache, no-transform")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
_, canFlush := w.(http.Flusher)
if !canFlush {
config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
}
created := time.Now().Unix()
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
initialType := "text"
if thinkingEnabled {
initialType = "thinking"
}
streamRuntime := newChatStreamRuntime(
w,
rc,
canFlush,
completionID,
created,
model,
finalPrompt,
thinkingEnabled,
searchEnabled,
toolNames,
bufferToolContent,
emitEarlyToolDeltas,
)
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
Context: r.Context(),
Body: resp.Body,
ThinkingEnabled: thinkingEnabled,
InitialType: initialType,
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
}, streamengine.ConsumeHooks{
OnKeepAlive: func() {
streamRuntime.sendKeepAlive()
},
OnParsed: streamRuntime.onParsed,
OnFinalize: func(reason streamengine.StopReason, _ error) {
if string(reason) == "content_filter" {
streamRuntime.finalize("content_filter")
return
}
streamRuntime.finalize("stop")
},
})
}

View File

@@ -0,0 +1,56 @@
package openai
import "net/http"
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, map[string]any{
"error": map[string]any{
"message": message,
"type": openAIErrorType(status),
"code": openAIErrorCode(status),
"param": nil,
},
})
}
func openAIErrorType(status int) string {
switch status {
case http.StatusBadRequest:
return "invalid_request_error"
case http.StatusUnauthorized:
return "authentication_error"
case http.StatusForbidden:
return "permission_error"
case http.StatusTooManyRequests:
return "rate_limit_error"
case http.StatusServiceUnavailable:
return "service_unavailable_error"
default:
if status >= 500 {
return "api_error"
}
return "invalid_request_error"
}
}
func openAIErrorCode(status int) string {
switch status {
case http.StatusBadRequest:
return "invalid_request"
case http.StatusUnauthorized:
return "authentication_failed"
case http.StatusForbidden:
return "forbidden"
case http.StatusTooManyRequests:
return "rate_limit_exceeded"
case http.StatusNotFound:
return "not_found"
case http.StatusServiceUnavailable:
return "service_unavailable"
default:
if status >= 500 {
return "internal_error"
}
return "invalid_request"
}
}

View File

@@ -0,0 +1,57 @@
package openai
import (
"net/http"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/util"
)
// writeJSON is a package-internal alias kept to avoid mass-renaming across
// every call-site in this package.
var writeJSON = util.WriteJSON
type Handler struct {
Store ConfigReader
Auth AuthResolver
DS DeepSeekCaller
leaseMu sync.Mutex
streamLeases map[string]streamLease
responsesMu sync.Mutex
responses *responseStore
}
type streamLease struct {
Auth *auth.RequestAuth
ExpiresAt time.Time
}
func RegisterRoutes(r chi.Router, h *Handler) {
r.Get("/v1/models", h.ListModels)
r.Get("/v1/models/{model_id}", h.GetModel)
r.Post("/v1/chat/completions", h.ChatCompletions)
r.Post("/v1/responses", h.Responses)
r.Get("/v1/responses/{response_id}", h.GetResponseByID)
r.Post("/v1/embeddings", h.Embeddings)
}
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
}
func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) {
modelID := strings.TrimSpace(chi.URLParam(r, "model_id"))
model, ok := config.OpenAIModelByID(h.Store, modelID)
if !ok {
writeOpenAIError(w, http.StatusNotFound, "Model not found.")
return
}
writeJSON(w, http.StatusOK, model)
}

View File

@@ -0,0 +1,116 @@
package openai
import (
"encoding/json"
"fmt"
"strings"
"github.com/google/uuid"
"ds2api/internal/util"
)
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
toolSchemas := make([]string, 0, len(tools))
names := make([]string, 0, len(tools))
for _, t := range tools {
tool, ok := t.(map[string]any)
if !ok {
continue
}
fn, _ := tool["function"].(map[string]any)
if len(fn) == 0 {
fn = tool
}
name, _ := fn["name"].(string)
desc, _ := fn["description"].(string)
schema, _ := fn["parameters"].(map[string]any)
if name == "" {
name = "unknown"
}
names = append(names, name)
if desc == "" {
desc = "No description available"
}
b, _ := json.Marshal(schema)
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
}
if len(toolSchemas) == 0 {
return messages, names
}
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
for i := range messages {
if messages[i]["role"] == "system" {
old, _ := messages[i]["content"].(string)
messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
return messages, names
}
}
messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
return messages, names
}
func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
if len(deltas) == 0 {
return nil
}
out := make([]map[string]any, 0, len(deltas))
for _, d := range deltas {
if d.Name == "" && d.Arguments == "" {
continue
}
callID, ok := ids[d.Index]
if !ok || callID == "" {
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
ids[d.Index] = callID
}
item := map[string]any{
"index": d.Index,
"id": callID,
"type": "function",
}
fn := map[string]any{}
if d.Name != "" {
fn["name"] = d.Name
}
if d.Arguments != "" {
fn["arguments"] = d.Arguments
}
if len(fn) > 0 {
item["function"] = fn
}
out = append(out, item)
}
return out
}
func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
if len(calls) == 0 {
return nil
}
out := make([]map[string]any, 0, len(calls))
for i, c := range calls {
callID := ""
if ids != nil {
callID = strings.TrimSpace(ids[i])
}
if callID == "" {
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if ids != nil {
ids[i] = callID
}
}
args, _ := json.Marshal(c.Input)
out = append(out, map[string]any{
"index": i,
"id": callID,
"type": "function",
"function": map[string]any{
"name": c.Name,
"arguments": string(args),
},
})
}
return out
}

View File

@@ -0,0 +1,25 @@
package openai
import "strings"
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
for k, v := range collectOpenAIChatPassThrough(req) {
payload[k] = v
}
}
func (h *Handler) toolcallFeatureMatchEnabled() bool {
if h == nil || h.Store == nil {
return true
}
mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode()))
return mode == "" || mode == "feature_match"
}
func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
if h == nil || h.Store == nil {
return true
}
level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence()))
return level == "" || level == "high"
}

View File

@@ -2,7 +2,6 @@ package openai
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
@@ -170,264 +169,3 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
},
})
}
func responsesMessagesFromRequest(req map[string]any) []any {
if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
return prependInstructionMessage(msgs, req["instructions"])
}
if rawInput, ok := req["input"]; ok {
if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 {
return prependInstructionMessage(msgs, req["instructions"])
}
}
return nil
}
func prependInstructionMessage(messages []any, instructions any) []any {
sys, _ := instructions.(string)
sys = strings.TrimSpace(sys)
if sys == "" {
return messages
}
out := make([]any, 0, len(messages)+1)
out = append(out, map[string]any{"role": "system", "content": sys})
out = append(out, messages...)
return out
}
func normalizeResponsesInputAsMessages(input any) []any {
switch v := input.(type) {
case string:
if strings.TrimSpace(v) == "" {
return nil
}
return []any{map[string]any{"role": "user", "content": v}}
case []any:
return normalizeResponsesInputArray(v)
case map[string]any:
if msg := normalizeResponsesInputItem(v); msg != nil {
return []any{msg}
}
if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
return []any{map[string]any{"role": "user", "content": txt}}
}
if content, ok := v["content"]; ok {
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
return []any{map[string]any{"role": "user", "content": content}}
}
}
}
return nil
}
func normalizeResponsesInputArray(items []any) []any {
if len(items) == 0 {
return nil
}
out := make([]any, 0, len(items))
fallbackParts := make([]string, 0, len(items))
flushFallback := func() {
if len(fallbackParts) == 0 {
return
}
out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
fallbackParts = fallbackParts[:0]
}
for _, item := range items {
switch x := item.(type) {
case map[string]any:
if msg := normalizeResponsesInputItem(x); msg != nil {
flushFallback()
out = append(out, msg)
continue
}
if s := normalizeResponsesFallbackPart(x); s != "" {
fallbackParts = append(fallbackParts, s)
}
default:
if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
fallbackParts = append(fallbackParts, s)
}
}
}
flushFallback()
if len(out) == 0 {
return nil
}
return out
}
func normalizeResponsesInputItem(m map[string]any) map[string]any {
if m == nil {
return nil
}
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
if role != "" {
content := m["content"]
if content == nil {
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
content = txt
}
}
if content == nil {
return nil
}
return map[string]any{
"role": role,
"content": content,
}
}
itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
switch itemType {
case "message", "input_message":
content := m["content"]
if content == nil {
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
content = txt
}
}
if content == nil {
return nil
}
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
if role == "" {
role = "user"
}
return map[string]any{
"role": role,
"content": content,
}
case "function_call_output", "tool_result":
content := m["output"]
if content == nil {
content = m["content"]
}
if content == nil {
content = ""
}
out := map[string]any{
"role": "tool",
"content": content,
}
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
out["tool_call_id"] = callID
} else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
out["tool_call_id"] = callID
}
if name := strings.TrimSpace(asString(m["name"])); name != "" {
out["name"] = name
} else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
out["name"] = name
}
return out
case "function_call", "tool_call":
name := strings.TrimSpace(asString(m["name"]))
var fn map[string]any
if rawFn, ok := m["function"].(map[string]any); ok {
fn = rawFn
if name == "" {
name = strings.TrimSpace(asString(fn["name"]))
}
}
if name == "" {
return nil
}
var argsRaw any
if v, ok := m["arguments"]; ok {
argsRaw = v
} else if v, ok := m["input"]; ok {
argsRaw = v
}
if argsRaw == nil && fn != nil {
if v, ok := fn["arguments"]; ok {
argsRaw = v
} else if v, ok := fn["input"]; ok {
argsRaw = v
}
}
functionPayload := map[string]any{
"name": name,
"arguments": stringifyToolCallArguments(argsRaw),
}
call := map[string]any{
"type": "function",
"function": functionPayload,
}
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
call["id"] = callID
} else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
call["id"] = callID
}
return map[string]any{
"role": "assistant",
"tool_calls": []any{call},
}
case "input_text":
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
return map[string]any{
"role": "user",
"content": txt,
}
}
}
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
return map[string]any{
"role": "user",
"content": txt,
}
}
if content, ok := m["content"]; ok {
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
return map[string]any{
"role": "user",
"content": content,
}
}
}
return nil
}
func normalizeResponsesFallbackPart(m map[string]any) string {
if m == nil {
return ""
}
if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
return txt
}
}
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
return txt
}
if content, ok := m["content"]; ok {
if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" {
return normalized
}
}
return strings.TrimSpace(fmt.Sprintf("%v", m))
}
func stringifyToolCallArguments(v any) string {
switch x := v.(type) {
case nil:
return "{}"
case string:
s := strings.TrimSpace(x)
if s == "" {
return "{}"
}
return s
default:
b, err := json.Marshal(x)
if err != nil || len(b) == 0 {
return "{}"
}
return string(b)
}
}

View File

@@ -0,0 +1,181 @@
package openai
import (
"encoding/json"
"fmt"
"strings"
)
func normalizeResponsesInputItem(m map[string]any) map[string]any {
if m == nil {
return nil
}
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
if role != "" {
content := m["content"]
if content == nil {
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
content = txt
}
}
if content == nil {
return nil
}
return map[string]any{
"role": role,
"content": content,
}
}
itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
switch itemType {
case "message", "input_message":
content := m["content"]
if content == nil {
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
content = txt
}
}
if content == nil {
return nil
}
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
if role == "" {
role = "user"
}
return map[string]any{
"role": role,
"content": content,
}
case "function_call_output", "tool_result":
content := m["output"]
if content == nil {
content = m["content"]
}
if content == nil {
content = ""
}
out := map[string]any{
"role": "tool",
"content": content,
}
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
out["tool_call_id"] = callID
} else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
out["tool_call_id"] = callID
}
if name := strings.TrimSpace(asString(m["name"])); name != "" {
out["name"] = name
} else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
out["name"] = name
}
return out
case "function_call", "tool_call":
name := strings.TrimSpace(asString(m["name"]))
var fn map[string]any
if rawFn, ok := m["function"].(map[string]any); ok {
fn = rawFn
if name == "" {
name = strings.TrimSpace(asString(fn["name"]))
}
}
if name == "" {
return nil
}
var argsRaw any
if v, ok := m["arguments"]; ok {
argsRaw = v
} else if v, ok := m["input"]; ok {
argsRaw = v
}
if argsRaw == nil && fn != nil {
if v, ok := fn["arguments"]; ok {
argsRaw = v
} else if v, ok := fn["input"]; ok {
argsRaw = v
}
}
functionPayload := map[string]any{
"name": name,
"arguments": stringifyToolCallArguments(argsRaw),
}
call := map[string]any{
"type": "function",
"function": functionPayload,
}
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
call["id"] = callID
} else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
call["id"] = callID
}
return map[string]any{
"role": "assistant",
"tool_calls": []any{call},
}
case "input_text":
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
return map[string]any{
"role": "user",
"content": txt,
}
}
}
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
return map[string]any{
"role": "user",
"content": txt,
}
}
if content, ok := m["content"]; ok {
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
return map[string]any{
"role": "user",
"content": content,
}
}
}
return nil
}
func normalizeResponsesFallbackPart(m map[string]any) string {
if m == nil {
return ""
}
if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
return txt
}
}
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
return txt
}
if content, ok := m["content"]; ok {
if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" {
return normalized
}
}
return strings.TrimSpace(fmt.Sprintf("%v", m))
}
func stringifyToolCallArguments(v any) string {
switch x := v.(type) {
case nil:
return "{}"
case string:
s := strings.TrimSpace(x)
if s == "" {
return "{}"
}
return s
default:
b, err := json.Marshal(x)
if err != nil || len(b) == 0 {
return "{}"
}
return string(b)
}
}

View File

@@ -0,0 +1,93 @@
package openai
import (
"fmt"
"strings"
)
func responsesMessagesFromRequest(req map[string]any) []any {
if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
return prependInstructionMessage(msgs, req["instructions"])
}
if rawInput, ok := req["input"]; ok {
if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 {
return prependInstructionMessage(msgs, req["instructions"])
}
}
return nil
}
func prependInstructionMessage(messages []any, instructions any) []any {
sys, _ := instructions.(string)
sys = strings.TrimSpace(sys)
if sys == "" {
return messages
}
out := make([]any, 0, len(messages)+1)
out = append(out, map[string]any{"role": "system", "content": sys})
out = append(out, messages...)
return out
}
func normalizeResponsesInputAsMessages(input any) []any {
switch v := input.(type) {
case string:
if strings.TrimSpace(v) == "" {
return nil
}
return []any{map[string]any{"role": "user", "content": v}}
case []any:
return normalizeResponsesInputArray(v)
case map[string]any:
if msg := normalizeResponsesInputItem(v); msg != nil {
return []any{msg}
}
if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
return []any{map[string]any{"role": "user", "content": txt}}
}
if content, ok := v["content"]; ok {
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
return []any{map[string]any{"role": "user", "content": content}}
}
}
}
return nil
}
func normalizeResponsesInputArray(items []any) []any {
if len(items) == 0 {
return nil
}
out := make([]any, 0, len(items))
fallbackParts := make([]string, 0, len(items))
flushFallback := func() {
if len(fallbackParts) == 0 {
return
}
out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
fallbackParts = fallbackParts[:0]
}
for _, item := range items {
switch x := item.(type) {
case map[string]any:
if msg := normalizeResponsesInputItem(x); msg != nil {
flushFallback()
out = append(out, msg)
continue
}
if s := normalizeResponsesFallbackPart(x); s != "" {
fallbackParts = append(fallbackParts, s)
}
default:
if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
fallbackParts = append(fallbackParts, s)
}
}
}
flushFallback()
if len(out) == 0 {
return nil
}
return out
}

View File

@@ -1,366 +0,0 @@
package openai
import (
"encoding/json"
"net/http"
"sort"
"strings"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
"github.com/google/uuid"
)
type responsesStreamRuntime struct {
w http.ResponseWriter
rc *http.ResponseController
canFlush bool
responseID string
model string
finalPrompt string
toolNames []string
thinkingEnabled bool
searchEnabled bool
bufferToolContent bool
emitEarlyToolDeltas bool
toolCallsEmitted bool
toolCallsDoneEmitted bool
sieve toolStreamSieveState
thinkingSieve toolStreamSieveState
thinking strings.Builder
text strings.Builder
streamToolCallIDs map[int]string
streamFunctionIDs map[int]string
functionDone map[int]bool
toolCallsDoneSigs map[string]bool
reasoningItemID string
persistResponse func(obj map[string]any)
}
func newResponsesStreamRuntime(
w http.ResponseWriter,
rc *http.ResponseController,
canFlush bool,
responseID string,
model string,
finalPrompt string,
thinkingEnabled bool,
searchEnabled bool,
toolNames []string,
bufferToolContent bool,
emitEarlyToolDeltas bool,
persistResponse func(obj map[string]any),
) *responsesStreamRuntime {
return &responsesStreamRuntime{
w: w,
rc: rc,
canFlush: canFlush,
responseID: responseID,
model: model,
finalPrompt: finalPrompt,
thinkingEnabled: thinkingEnabled,
searchEnabled: searchEnabled,
toolNames: toolNames,
bufferToolContent: bufferToolContent,
emitEarlyToolDeltas: emitEarlyToolDeltas,
streamToolCallIDs: map[int]string{},
streamFunctionIDs: map[int]string{},
functionDone: map[int]bool{},
toolCallsDoneSigs: map[string]bool{},
persistResponse: persistResponse,
}
}
func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) {
b, _ := json.Marshal(payload)
_, _ = s.w.Write([]byte("event: " + event + "\n"))
_, _ = s.w.Write([]byte("data: "))
_, _ = s.w.Write(b)
_, _ = s.w.Write([]byte("\n\n"))
if s.canFlush {
_ = s.rc.Flush()
}
}
func (s *responsesStreamRuntime) sendCreated() {
s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model))
}
func (s *responsesStreamRuntime) sendDone() {
_, _ = s.w.Write([]byte("data: [DONE]\n\n"))
if s.canFlush {
_ = s.rc.Flush()
}
}
func (s *responsesStreamRuntime) finalize() {
finalThinking := s.thinking.String()
finalText := s.text.String()
if strings.TrimSpace(finalThinking) != "" {
s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
}
if s.bufferToolContent {
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
}
// Compatibility fallback: some streams only emit incremental tool deltas.
// Ensure final function_call_arguments.done is emitted at least once.
if s.toolCallsEmitted {
detected := util.ParseToolCalls(finalText, s.toolNames)
if len(detected) == 0 {
detected = util.ParseToolCalls(finalThinking, s.toolNames)
}
if len(detected) > 0 {
if !s.toolCallsDoneEmitted {
s.emitToolCallsDone(detected)
} else {
s.emitFunctionCallDoneEvents(detected)
}
}
}
obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
if s.toolCallsEmitted {
s.alignCompletedOutputCallIDs(obj)
}
if s.toolCallsEmitted {
obj["status"] = "completed"
}
if s.persistResponse != nil {
s.persistResponse(obj)
}
s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj))
s.sendDone()
}
func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
if !parsed.Parsed {
return streamengine.ParsedDecision{}
}
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
return streamengine.ParsedDecision{Stop: true}
}
contentSeen := false
for _, p := range parsed.Parts {
if p.Text == "" {
continue
}
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
continue
}
contentSeen = true
if p.Type == "thinking" {
if !s.thinkingEnabled {
continue
}
s.thinking.WriteString(p.Text)
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
if s.bufferToolContent {
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
}
continue
}
s.text.WriteString(p.Text)
if !s.bufferToolContent {
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
continue
}
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
for _, evt := range events {
if emitContent && evt.Content != "" {
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
}
if len(evt.ToolCallDeltas) > 0 {
if !s.emitEarlyToolDeltas {
continue
}
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
if len(formatted) == 0 {
continue
}
s.toolCallsEmitted = true
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
}
if len(evt.ToolCalls) > 0 {
s.emitToolCallsDone(evt.ToolCalls)
}
}
}
func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
if len(calls) == 0 {
return
}
sig := toolCallListSignature(calls)
if sig != "" && s.toolCallsDoneSigs[sig] {
return
}
if sig != "" {
s.toolCallsDoneSigs[sig] = true
}
formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
if len(formatted) == 0 {
return
}
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
s.emitFunctionCallDoneEvents(calls)
}
func (s *responsesStreamRuntime) ensureReasoningItemID() string {
if strings.TrimSpace(s.reasoningItemID) != "" {
return s.reasoningItemID
}
s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "")
return s.reasoningItemID
}
func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string {
if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" {
return id
}
id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "")
s.streamFunctionIDs[index] = id
return id
}
func (s *responsesStreamRuntime) ensureToolCallID(index int) string {
if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" {
return id
}
id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
s.streamToolCallIDs[index] = id
return id
}
func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
if strings.TrimSpace(s.thinking.String()) != "" {
return 1
}
return 0
}
func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
for _, d := range deltas {
if strings.TrimSpace(d.Arguments) == "" {
continue
}
outputIndex := s.functionOutputBaseIndex() + d.Index
itemID := s.ensureFunctionItemID(outputIndex)
callID := s.ensureToolCallID(d.Index)
s.sendEvent(
"response.function_call_arguments.delta",
openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments),
)
}
}
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) {
base := s.functionOutputBaseIndex()
for idx, tc := range calls {
if strings.TrimSpace(tc.Name) == "" {
continue
}
outputIndex := base + idx
if s.functionDone[outputIndex] {
continue
}
itemID := s.ensureFunctionItemID(outputIndex)
callID := s.ensureToolCallID(idx)
argsBytes, _ := json.Marshal(tc.Input)
s.sendEvent(
"response.function_call_arguments.done",
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)),
)
s.functionDone[outputIndex] = true
}
}
func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) {
if obj == nil || len(s.streamToolCallIDs) == 0 {
return
}
output, _ := obj["output"].([]any)
if len(output) == 0 {
return
}
indices := make([]int, 0, len(s.streamToolCallIDs))
for idx := range s.streamToolCallIDs {
indices = append(indices, idx)
}
sort.Ints(indices)
ordered := make([]string, 0, len(indices))
for _, idx := range indices {
id := strings.TrimSpace(s.streamToolCallIDs[idx])
if id == "" {
continue
}
ordered = append(ordered, id)
}
if len(ordered) == 0 {
return
}
functionIdx := 0
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil {
continue
}
typ, _ := m["type"].(string)
switch typ {
case "function_call":
if functionIdx < len(ordered) {
m["call_id"] = ordered[functionIdx]
functionIdx++
}
case "tool_calls":
tcArr, _ := m["tool_calls"].([]any)
for i, raw := range tcArr {
tc, _ := raw.(map[string]any)
if tc == nil {
continue
}
if i < len(ordered) {
tc["id"] = ordered[i]
}
}
}
}
}
func toolCallListSignature(calls []util.ParsedToolCall) string {
if len(calls) == 0 {
return ""
}
var b strings.Builder
for i, tc := range calls {
if i > 0 {
b.WriteString("|")
}
b.WriteString(strings.TrimSpace(tc.Name))
b.WriteString(":")
args, _ := json.Marshal(tc.Input)
b.Write(args)
}
return b.String()
}

View File

@@ -0,0 +1,157 @@
package openai
import (
"net/http"
"strings"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
)
type responsesStreamRuntime struct {
w http.ResponseWriter
rc *http.ResponseController
canFlush bool
responseID string
model string
finalPrompt string
toolNames []string
thinkingEnabled bool
searchEnabled bool
bufferToolContent bool
emitEarlyToolDeltas bool
toolCallsEmitted bool
toolCallsDoneEmitted bool
sieve toolStreamSieveState
thinkingSieve toolStreamSieveState
thinking strings.Builder
text strings.Builder
streamToolCallIDs map[int]string
streamFunctionIDs map[int]string
functionDone map[int]bool
toolCallsDoneSigs map[string]bool
reasoningItemID string
persistResponse func(obj map[string]any)
}
func newResponsesStreamRuntime(
w http.ResponseWriter,
rc *http.ResponseController,
canFlush bool,
responseID string,
model string,
finalPrompt string,
thinkingEnabled bool,
searchEnabled bool,
toolNames []string,
bufferToolContent bool,
emitEarlyToolDeltas bool,
persistResponse func(obj map[string]any),
) *responsesStreamRuntime {
return &responsesStreamRuntime{
w: w,
rc: rc,
canFlush: canFlush,
responseID: responseID,
model: model,
finalPrompt: finalPrompt,
thinkingEnabled: thinkingEnabled,
searchEnabled: searchEnabled,
toolNames: toolNames,
bufferToolContent: bufferToolContent,
emitEarlyToolDeltas: emitEarlyToolDeltas,
streamToolCallIDs: map[int]string{},
streamFunctionIDs: map[int]string{},
functionDone: map[int]bool{},
toolCallsDoneSigs: map[string]bool{},
persistResponse: persistResponse,
}
}
func (s *responsesStreamRuntime) finalize() {
finalThinking := s.thinking.String()
finalText := s.text.String()
if strings.TrimSpace(finalThinking) != "" {
s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
}
if s.bufferToolContent {
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
}
// Compatibility fallback: some streams only emit incremental tool deltas.
// Ensure final function_call_arguments.done is emitted at least once.
if s.toolCallsEmitted {
detected := util.ParseToolCalls(finalText, s.toolNames)
if len(detected) == 0 {
detected = util.ParseToolCalls(finalThinking, s.toolNames)
}
if len(detected) > 0 {
if !s.toolCallsDoneEmitted {
s.emitToolCallsDone(detected)
} else {
s.emitFunctionCallDoneEvents(detected)
}
}
}
obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
if s.toolCallsEmitted {
s.alignCompletedOutputCallIDs(obj)
}
if s.toolCallsEmitted {
obj["status"] = "completed"
}
if s.persistResponse != nil {
s.persistResponse(obj)
}
s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj))
s.sendDone()
}
func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
if !parsed.Parsed {
return streamengine.ParsedDecision{}
}
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
return streamengine.ParsedDecision{Stop: true}
}
contentSeen := false
for _, p := range parsed.Parts {
if p.Text == "" {
continue
}
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
continue
}
contentSeen = true
if p.Type == "thinking" {
if !s.thinkingEnabled {
continue
}
s.thinking.WriteString(p.Text)
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
if s.bufferToolContent {
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
}
continue
}
s.text.WriteString(p.Text)
if !s.bufferToolContent {
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
continue
}
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}
}

View File

@@ -0,0 +1,52 @@
package openai
import (
"encoding/json"
openaifmt "ds2api/internal/format/openai"
)
func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) {
b, _ := json.Marshal(payload)
_, _ = s.w.Write([]byte("event: " + event + "\n"))
_, _ = s.w.Write([]byte("data: "))
_, _ = s.w.Write(b)
_, _ = s.w.Write([]byte("\n\n"))
if s.canFlush {
_ = s.rc.Flush()
}
}
func (s *responsesStreamRuntime) sendCreated() {
s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model))
}
func (s *responsesStreamRuntime) sendDone() {
_, _ = s.w.Write([]byte("data: [DONE]\n\n"))
if s.canFlush {
_ = s.rc.Flush()
}
}
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
for _, evt := range events {
if emitContent && evt.Content != "" {
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
}
if len(evt.ToolCallDeltas) > 0 {
if !s.emitEarlyToolDeltas {
continue
}
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
if len(formatted) == 0 {
continue
}
s.toolCallsEmitted = true
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
}
if len(evt.ToolCalls) > 0 {
s.emitToolCallsDone(evt.ToolCalls)
}
}
}

View File

@@ -0,0 +1,172 @@
package openai
import (
"encoding/json"
"sort"
"strings"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/util"
"github.com/google/uuid"
)
func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
if len(calls) == 0 {
return
}
sig := toolCallListSignature(calls)
if sig != "" && s.toolCallsDoneSigs[sig] {
return
}
if sig != "" {
s.toolCallsDoneSigs[sig] = true
}
formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
if len(formatted) == 0 {
return
}
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
s.emitFunctionCallDoneEvents(calls)
}
func (s *responsesStreamRuntime) ensureReasoningItemID() string {
if strings.TrimSpace(s.reasoningItemID) != "" {
return s.reasoningItemID
}
s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "")
return s.reasoningItemID
}
func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string {
if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" {
return id
}
id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "")
s.streamFunctionIDs[index] = id
return id
}
func (s *responsesStreamRuntime) ensureToolCallID(index int) string {
if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" {
return id
}
id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
s.streamToolCallIDs[index] = id
return id
}
func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
if strings.TrimSpace(s.thinking.String()) != "" {
return 1
}
return 0
}
func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
for _, d := range deltas {
if strings.TrimSpace(d.Arguments) == "" {
continue
}
outputIndex := s.functionOutputBaseIndex() + d.Index
itemID := s.ensureFunctionItemID(outputIndex)
callID := s.ensureToolCallID(d.Index)
s.sendEvent(
"response.function_call_arguments.delta",
openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments),
)
}
}
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) {
base := s.functionOutputBaseIndex()
for idx, tc := range calls {
if strings.TrimSpace(tc.Name) == "" {
continue
}
outputIndex := base + idx
if s.functionDone[outputIndex] {
continue
}
itemID := s.ensureFunctionItemID(outputIndex)
callID := s.ensureToolCallID(idx)
argsBytes, _ := json.Marshal(tc.Input)
s.sendEvent(
"response.function_call_arguments.done",
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)),
)
s.functionDone[outputIndex] = true
}
}
func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) {
if obj == nil || len(s.streamToolCallIDs) == 0 {
return
}
output, _ := obj["output"].([]any)
if len(output) == 0 {
return
}
indices := make([]int, 0, len(s.streamToolCallIDs))
for idx := range s.streamToolCallIDs {
indices = append(indices, idx)
}
sort.Ints(indices)
ordered := make([]string, 0, len(indices))
for _, idx := range indices {
id := strings.TrimSpace(s.streamToolCallIDs[idx])
if id == "" {
continue
}
ordered = append(ordered, id)
}
if len(ordered) == 0 {
return
}
functionIdx := 0
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil {
continue
}
typ, _ := m["type"].(string)
switch typ {
case "function_call":
if functionIdx < len(ordered) {
m["call_id"] = ordered[functionIdx]
functionIdx++
}
case "tool_calls":
tcArr, _ := m["tool_calls"].([]any)
for i, raw := range tcArr {
tc, _ := raw.(map[string]any)
if tc == nil {
continue
}
if i < len(ordered) {
tc["id"] = ordered[i]
}
}
}
}
}
func toolCallListSignature(calls []util.ParsedToolCall) string {
if len(calls) == 0 {
return ""
}
var b strings.Builder
for i, tc := range calls {
if i > 0 {
b.WriteString("|")
}
b.WriteString(strings.TrimSpace(tc.Name))
b.WriteString(":")
args, _ := json.Marshal(tc.Input)
b.Write(args)
}
return b.String()
}

View File

@@ -1,713 +0,0 @@
package openai
import (
"strings"
"ds2api/internal/util"
)
type toolStreamSieveState struct {
pending strings.Builder
capture strings.Builder
capturing bool
recentTextTail string
disableDeltas bool
toolNameSent bool
toolName string
toolArgsStart int
toolArgsSent int
toolArgsString bool
toolArgsDone bool
}
type toolStreamEvent struct {
Content string
ToolCalls []util.ParsedToolCall
ToolCallDeltas []toolCallDelta
}
type toolCallDelta struct {
Index int
Name string
Arguments string
}
const toolSieveCaptureLimit = 8 * 1024
const toolSieveContextTailLimit = 256
func (s *toolStreamSieveState) resetIncrementalToolState() {
s.disableDeltas = false
s.toolNameSent = false
s.toolName = ""
s.toolArgsStart = -1
s.toolArgsSent = -1
s.toolArgsString = false
s.toolArgsDone = false
}
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
if state == nil {
return nil
}
if chunk != "" {
state.pending.WriteString(chunk)
}
events := make([]toolStreamEvent, 0, 2)
for {
if state.capturing {
if state.pending.Len() > 0 {
state.capture.WriteString(state.pending.String())
state.pending.Reset()
}
if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
}
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
if !ready {
if state.capture.Len() > toolSieveCaptureLimit {
content := state.capture.String()
state.capture.Reset()
state.capturing = false
state.resetIncrementalToolState()
state.noteText(content)
events = append(events, toolStreamEvent{Content: content})
continue
}
break
}
state.capture.Reset()
state.capturing = false
state.resetIncrementalToolState()
if prefix != "" {
state.noteText(prefix)
events = append(events, toolStreamEvent{Content: prefix})
}
if len(calls) > 0 {
events = append(events, toolStreamEvent{ToolCalls: calls})
}
if suffix != "" {
state.pending.WriteString(suffix)
}
continue
}
pending := state.pending.String()
if pending == "" {
break
}
start := findToolSegmentStart(pending)
if start >= 0 {
prefix := pending[:start]
if prefix != "" {
state.noteText(prefix)
events = append(events, toolStreamEvent{Content: prefix})
}
state.pending.Reset()
state.capture.WriteString(pending[start:])
state.capturing = true
state.resetIncrementalToolState()
continue
}
safe, hold := splitSafeContentForToolDetection(pending)
if safe == "" {
break
}
state.pending.Reset()
state.pending.WriteString(hold)
state.noteText(safe)
events = append(events, toolStreamEvent{Content: safe})
}
return events
}
func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent {
if state == nil {
return nil
}
events := processToolSieveChunk(state, "", toolNames)
if state.capturing {
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
if ready {
if consumedPrefix != "" {
state.noteText(consumedPrefix)
events = append(events, toolStreamEvent{Content: consumedPrefix})
}
if len(consumedCalls) > 0 {
events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
}
if consumedSuffix != "" {
state.noteText(consumedSuffix)
events = append(events, toolStreamEvent{Content: consumedSuffix})
}
} else {
content := state.capture.String()
if content != "" {
state.noteText(content)
events = append(events, toolStreamEvent{Content: content})
}
}
state.capture.Reset()
state.capturing = false
state.resetIncrementalToolState()
}
if state.pending.Len() > 0 {
content := state.pending.String()
state.noteText(content)
events = append(events, toolStreamEvent{Content: content})
state.pending.Reset()
}
return events
}
func splitSafeContentForToolDetection(s string) (safe, hold string) {
if s == "" {
return "", ""
}
suspiciousStart := findSuspiciousPrefixStart(s)
if suspiciousStart < 0 {
return s, ""
}
if suspiciousStart > 0 {
return s[:suspiciousStart], s[suspiciousStart:]
}
// If suspicious content starts at position 0, keep holding until we can
// parse a complete tool JSON block or reach stream flush.
return "", s
}
func findSuspiciousPrefixStart(s string) int {
start := -1
indices := []int{
strings.LastIndex(s, "{"),
strings.LastIndex(s, "["),
strings.LastIndex(s, "```"),
}
for _, idx := range indices {
if idx > start {
start = idx
}
}
return start
}
func findToolSegmentStart(s string) int {
if s == "" {
return -1
}
lower := strings.ToLower(s)
offset := 0
for {
keyRel := strings.Index(lower[offset:], "tool_calls")
if keyRel < 0 {
return -1
}
keyIdx := offset + keyRel
start := strings.LastIndex(s[:keyIdx], "{")
if start < 0 {
start = keyIdx
}
if !insideCodeFence(s[:start]) {
return start
}
offset = keyIdx + len("tool_calls")
}
}
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
captured := state.capture.String()
if captured == "" {
return "", nil, "", false
}
lower := strings.ToLower(captured)
keyIdx := strings.Index(lower, "tool_calls")
if keyIdx < 0 {
return "", nil, "", false
}
start := strings.LastIndex(captured[:keyIdx], "{")
if start < 0 {
return "", nil, "", false
}
obj, end, ok := extractJSONObjectFrom(captured, start)
if !ok {
return "", nil, "", false
}
prefixPart := captured[:start]
suffixPart := captured[end:]
if insideCodeFence(state.recentTextTail + prefixPart) {
return captured, nil, "", true
}
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
if len(parsed) == 0 {
return captured, nil, "", true
}
return prefixPart, parsed, suffixPart, true
}
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
if start < 0 || start >= len(text) || text[start] != '{' {
return "", 0, false
}
depth := 0
quote := byte(0)
escaped := false
for i := start; i < len(text); i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
quote = ch
continue
}
if ch == '{' {
depth++
continue
}
if ch == '}' {
depth--
if depth == 0 {
end := i + 1
return text[start:end], end, true
}
}
}
return "", 0, false
}
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
if state.disableDeltas {
return nil
}
captured := state.capture.String()
if captured == "" {
return nil
}
lower := strings.ToLower(captured)
keyIdx := strings.Index(lower, "tool_calls")
if keyIdx < 0 {
return nil
}
start := strings.LastIndex(captured[:keyIdx], "{")
if start < 0 {
return nil
}
if insideCodeFence(state.recentTextTail + captured[:start]) {
return nil
}
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
if hasMultiple {
state.disableDeltas = true
return nil
}
if !certainSingle {
// In uncertain phases (e.g. first call arrived but array not closed yet),
// avoid speculative deltas and wait for final parsed tool_calls payload.
return nil
}
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
if !ok {
return nil
}
deltas := make([]toolCallDelta, 0, 2)
if state.toolName == "" {
name, ok := extractToolCallName(captured, callStart)
if !ok || name == "" {
return nil
}
state.toolName = name
}
if state.toolArgsStart < 0 {
argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
if ok {
state.toolArgsString = stringMode
if stringMode {
state.toolArgsStart = argsStart + 1
} else {
state.toolArgsStart = argsStart
}
state.toolArgsSent = state.toolArgsStart
}
}
if !state.toolNameSent {
if state.toolArgsStart < 0 {
return nil
}
state.toolNameSent = true
deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
}
if state.toolArgsStart < 0 || state.toolArgsDone {
return deltas
}
end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
if !ok {
return deltas
}
if end > state.toolArgsSent {
deltas = append(deltas, toolCallDelta{
Index: 0,
Arguments: captured[state.toolArgsSent:end],
})
state.toolArgsSent = end
}
if complete {
state.toolArgsDone = true
}
return deltas
}
func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
if !ok {
return false, false
}
i := skipSpaces(text, arrStart+1)
if i >= len(text) || text[i] != '{' {
return false, false
}
count := 0
depth := 0
quote := byte(0)
escaped := false
for ; i < len(text); i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
quote = ch
continue
}
if ch == '{' {
if depth == 0 {
count++
if count > 1 {
return false, true
}
}
depth++
continue
}
if ch == '}' {
if depth > 0 {
depth--
}
continue
}
if ch == ',' && depth == 0 {
// top-level separator means at least one more tool call exists
// (or is expected). Treat as multi-call and stop incremental deltas.
return false, true
}
if ch == ']' && depth == 0 {
return count == 1, false
}
}
// array not closed yet: still uncertain whether more calls will appear
return false, false
}
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
if !ok {
return -1, false
}
i := skipSpaces(text, arrStart+1)
if i >= len(text) || text[i] != '{' {
return -1, false
}
return i, true
}
func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
i := keyIdx + len("tool_calls")
for i < len(text) && text[i] != ':' {
i++
}
if i >= len(text) {
return -1, false
}
i = skipSpaces(text, i+1)
if i >= len(text) || text[i] != '[' {
return -1, false
}
return i, true
}
func extractToolCallName(text string, callStart int) (string, bool) {
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"})
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
fnStart, fnOK := findFunctionObjectStart(text, callStart)
if !fnOK {
return "", false
}
valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"})
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
return "", false
}
}
name, _, ok := parseJSONStringLiteral(text, valueStart)
if !ok {
return "", false
}
return name, true
}
func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
keys := []string{"input", "arguments", "args", "parameters", "params"}
valueStart, ok := findObjectFieldValueStart(text, callStart, keys)
if !ok {
fnStart, fnOK := findFunctionObjectStart(text, callStart)
if !fnOK {
return -1, false, false
}
valueStart, ok = findObjectFieldValueStart(text, fnStart, keys)
if !ok {
return -1, false, false
}
}
if valueStart >= len(text) {
return -1, false, false
}
ch := text[valueStart]
if ch == '{' || ch == '[' {
return valueStart, false, true
}
if ch == '"' {
return valueStart, true, true
}
return -1, false, false
}
func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
if start < 0 || start > len(text) {
return 0, false, false
}
if stringMode {
escaped := false
for i := start; i < len(text); i++ {
ch := text[i]
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == '"' {
return i, true, true
}
}
return len(text), false, true
}
if start >= len(text) {
return start, false, false
}
if text[start] != '{' && text[start] != '[' {
return 0, false, false
}
depth := 0
quote := byte(0)
escaped := false
for i := start; i < len(text); i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
quote = ch
continue
}
if ch == '{' || ch == '[' {
depth++
continue
}
if ch == '}' || ch == ']' {
depth--
if depth == 0 {
return i + 1, true, true
}
}
}
return len(text), false, true
}
func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
return 0, false
}
depth := 0
quote := byte(0)
escaped := false
for i := objStart; i < len(text); i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
if depth == 1 {
key, end, ok := parseJSONStringLiteral(text, i)
if !ok {
return 0, false
}
j := skipSpaces(text, end)
if j >= len(text) || text[j] != ':' {
i = end - 1
continue
}
j = skipSpaces(text, j+1)
if j >= len(text) {
return 0, false
}
if containsKey(keys, key) {
return j, true
}
i = j - 1
continue
}
quote = ch
continue
}
if ch == '{' {
depth++
continue
}
if ch == '}' {
depth--
if depth == 0 {
break
}
}
}
return 0, false
}
func findFunctionObjectStart(text string, callStart int) (int, bool) {
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"})
if !ok || valueStart >= len(text) || text[valueStart] != '{' {
return -1, false
}
return valueStart, true
}
func parseJSONStringLiteral(text string, start int) (string, int, bool) {
if start < 0 || start >= len(text) || text[start] != '"' {
return "", 0, false
}
var b strings.Builder
escaped := false
for i := start + 1; i < len(text); i++ {
ch := text[i]
if escaped {
b.WriteByte(ch)
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == '"' {
return b.String(), i + 1, true
}
b.WriteByte(ch)
}
return "", 0, false
}
func containsKey(keys []string, value string) bool {
for _, k := range keys {
if k == value {
return true
}
}
return false
}
func skipSpaces(text string, i int) int {
for i < len(text) {
switch text[i] {
case ' ', '\t', '\n', '\r':
i++
default:
return i
}
}
return i
}
func (s *toolStreamSieveState) noteText(content string) {
if strings.TrimSpace(content) == "" {
return
}
s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
}
func appendTail(prev, next string, max int) string {
if max <= 0 {
return ""
}
combined := prev + next
if len(combined) <= max {
return combined
}
return combined[len(combined)-max:]
}
func looksLikeToolExampleContext(text string) bool {
return insideCodeFence(text)
}
func insideCodeFence(text string) bool {
if text == "" {
return false
}
return strings.Count(text, "```")%2 == 1
}

View File

@@ -0,0 +1,208 @@
package openai
import (
"strings"
"ds2api/internal/util"
)
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
if state == nil {
return nil
}
if chunk != "" {
state.pending.WriteString(chunk)
}
events := make([]toolStreamEvent, 0, 2)
for {
if state.capturing {
if state.pending.Len() > 0 {
state.capture.WriteString(state.pending.String())
state.pending.Reset()
}
if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
}
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
if !ready {
if state.capture.Len() > toolSieveCaptureLimit {
content := state.capture.String()
state.capture.Reset()
state.capturing = false
state.resetIncrementalToolState()
state.noteText(content)
events = append(events, toolStreamEvent{Content: content})
continue
}
break
}
state.capture.Reset()
state.capturing = false
state.resetIncrementalToolState()
if prefix != "" {
state.noteText(prefix)
events = append(events, toolStreamEvent{Content: prefix})
}
if len(calls) > 0 {
events = append(events, toolStreamEvent{ToolCalls: calls})
}
if suffix != "" {
state.pending.WriteString(suffix)
}
continue
}
pending := state.pending.String()
if pending == "" {
break
}
start := findToolSegmentStart(pending)
if start >= 0 {
prefix := pending[:start]
if prefix != "" {
state.noteText(prefix)
events = append(events, toolStreamEvent{Content: prefix})
}
state.pending.Reset()
state.capture.WriteString(pending[start:])
state.capturing = true
state.resetIncrementalToolState()
continue
}
safe, hold := splitSafeContentForToolDetection(pending)
if safe == "" {
break
}
state.pending.Reset()
state.pending.WriteString(hold)
state.noteText(safe)
events = append(events, toolStreamEvent{Content: safe})
}
return events
}
func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent {
if state == nil {
return nil
}
events := processToolSieveChunk(state, "", toolNames)
if state.capturing {
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
if ready {
if consumedPrefix != "" {
state.noteText(consumedPrefix)
events = append(events, toolStreamEvent{Content: consumedPrefix})
}
if len(consumedCalls) > 0 {
events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
}
if consumedSuffix != "" {
state.noteText(consumedSuffix)
events = append(events, toolStreamEvent{Content: consumedSuffix})
}
} else {
content := state.capture.String()
if content != "" {
state.noteText(content)
events = append(events, toolStreamEvent{Content: content})
}
}
state.capture.Reset()
state.capturing = false
state.resetIncrementalToolState()
}
if state.pending.Len() > 0 {
content := state.pending.String()
state.noteText(content)
events = append(events, toolStreamEvent{Content: content})
state.pending.Reset()
}
return events
}
func splitSafeContentForToolDetection(s string) (safe, hold string) {
if s == "" {
return "", ""
}
suspiciousStart := findSuspiciousPrefixStart(s)
if suspiciousStart < 0 {
return s, ""
}
if suspiciousStart > 0 {
return s[:suspiciousStart], s[suspiciousStart:]
}
// If suspicious content starts at position 0, keep holding until we can
// parse a complete tool JSON block or reach stream flush.
return "", s
}
func findSuspiciousPrefixStart(s string) int {
start := -1
indices := []int{
strings.LastIndex(s, "{"),
strings.LastIndex(s, "["),
strings.LastIndex(s, "```"),
}
for _, idx := range indices {
if idx > start {
start = idx
}
}
return start
}
func findToolSegmentStart(s string) int {
if s == "" {
return -1
}
lower := strings.ToLower(s)
offset := 0
for {
keyRel := strings.Index(lower[offset:], "tool_calls")
if keyRel < 0 {
return -1
}
keyIdx := offset + keyRel
start := strings.LastIndex(s[:keyIdx], "{")
if start < 0 {
start = keyIdx
}
if !insideCodeFence(s[:start]) {
return start
}
offset = keyIdx + len("tool_calls")
}
}
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
captured := state.capture.String()
if captured == "" {
return "", nil, "", false
}
lower := strings.ToLower(captured)
keyIdx := strings.Index(lower, "tool_calls")
if keyIdx < 0 {
return "", nil, "", false
}
start := strings.LastIndex(captured[:keyIdx], "{")
if start < 0 {
return "", nil, "", false
}
obj, end, ok := extractJSONObjectFrom(captured, start)
if !ok {
return "", nil, "", false
}
prefixPart := captured[:start]
suffixPart := captured[end:]
if insideCodeFence(state.recentTextTail + prefixPart) {
return captured, nil, "", true
}
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
if len(parsed) == 0 {
return captured, nil, "", true
}
return prefixPart, parsed, suffixPart, true
}

View File

@@ -0,0 +1,291 @@
package openai
import "strings"
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
if state.disableDeltas {
return nil
}
captured := state.capture.String()
if captured == "" {
return nil
}
lower := strings.ToLower(captured)
keyIdx := strings.Index(lower, "tool_calls")
if keyIdx < 0 {
return nil
}
start := strings.LastIndex(captured[:keyIdx], "{")
if start < 0 {
return nil
}
if insideCodeFence(state.recentTextTail + captured[:start]) {
return nil
}
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
if hasMultiple {
state.disableDeltas = true
return nil
}
if !certainSingle {
// In uncertain phases (e.g. first call arrived but array not closed yet),
// avoid speculative deltas and wait for final parsed tool_calls payload.
return nil
}
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
if !ok {
return nil
}
deltas := make([]toolCallDelta, 0, 2)
if state.toolName == "" {
name, ok := extractToolCallName(captured, callStart)
if !ok || name == "" {
return nil
}
state.toolName = name
}
if state.toolArgsStart < 0 {
argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
if ok {
state.toolArgsString = stringMode
if stringMode {
state.toolArgsStart = argsStart + 1
} else {
state.toolArgsStart = argsStart
}
state.toolArgsSent = state.toolArgsStart
}
}
if !state.toolNameSent {
if state.toolArgsStart < 0 {
return nil
}
state.toolNameSent = true
deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
}
if state.toolArgsStart < 0 || state.toolArgsDone {
return deltas
}
end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
if !ok {
return deltas
}
if end > state.toolArgsSent {
deltas = append(deltas, toolCallDelta{
Index: 0,
Arguments: captured[state.toolArgsSent:end],
})
state.toolArgsSent = end
}
if complete {
state.toolArgsDone = true
}
return deltas
}
func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
if !ok {
return false, false
}
i := skipSpaces(text, arrStart+1)
if i >= len(text) || text[i] != '{' {
return false, false
}
count := 0
depth := 0
quote := byte(0)
escaped := false
for ; i < len(text); i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
quote = ch
continue
}
if ch == '{' {
if depth == 0 {
count++
if count > 1 {
return false, true
}
}
depth++
continue
}
if ch == '}' {
if depth > 0 {
depth--
}
continue
}
if ch == ',' && depth == 0 {
// top-level separator means at least one more tool call exists
// (or is expected). Treat as multi-call and stop incremental deltas.
return false, true
}
if ch == ']' && depth == 0 {
return count == 1, false
}
}
// array not closed yet: still uncertain whether more calls will appear
return false, false
}
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
if !ok {
return -1, false
}
i := skipSpaces(text, arrStart+1)
if i >= len(text) || text[i] != '{' {
return -1, false
}
return i, true
}
func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
i := keyIdx + len("tool_calls")
for i < len(text) && text[i] != ':' {
i++
}
if i >= len(text) {
return -1, false
}
i = skipSpaces(text, i+1)
if i >= len(text) || text[i] != '[' {
return -1, false
}
return i, true
}
func extractToolCallName(text string, callStart int) (string, bool) {
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"})
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
fnStart, fnOK := findFunctionObjectStart(text, callStart)
if !fnOK {
return "", false
}
valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"})
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
return "", false
}
}
name, _, ok := parseJSONStringLiteral(text, valueStart)
if !ok {
return "", false
}
return name, true
}
func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
keys := []string{"input", "arguments", "args", "parameters", "params"}
valueStart, ok := findObjectFieldValueStart(text, callStart, keys)
if !ok {
fnStart, fnOK := findFunctionObjectStart(text, callStart)
if !fnOK {
return -1, false, false
}
valueStart, ok = findObjectFieldValueStart(text, fnStart, keys)
if !ok {
return -1, false, false
}
}
if valueStart >= len(text) {
return -1, false, false
}
ch := text[valueStart]
if ch == '{' || ch == '[' {
return valueStart, false, true
}
if ch == '"' {
return valueStart, true, true
}
return -1, false, false
}
func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
if start < 0 || start > len(text) {
return 0, false, false
}
if stringMode {
escaped := false
for i := start; i < len(text); i++ {
ch := text[i]
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == '"' {
return i, true, true
}
}
return len(text), false, true
}
if start >= len(text) {
return start, false, false
}
if text[start] != '{' && text[start] != '[' {
return 0, false, false
}
depth := 0
quote := byte(0)
escaped := false
for i := start; i < len(text); i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
quote = ch
continue
}
if ch == '{' || ch == '[' {
depth++
continue
}
if ch == '}' || ch == ']' {
depth--
if depth == 0 {
return i + 1, true, true
}
}
}
return len(text), false, true
}
func findFunctionObjectStart(text string, callStart int) (int, bool) {
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"})
if !ok || valueStart >= len(text) || text[valueStart] != '{' {
return -1, false
}
return valueStart, true
}

View File

@@ -0,0 +1,152 @@
package openai
import "strings"
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
if start < 0 || start >= len(text) || text[start] != '{' {
return "", 0, false
}
depth := 0
quote := byte(0)
escaped := false
for i := start; i < len(text); i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
quote = ch
continue
}
if ch == '{' {
depth++
continue
}
if ch == '}' {
depth--
if depth == 0 {
end := i + 1
return text[start:end], end, true
}
}
}
return "", 0, false
}
func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
return 0, false
}
depth := 0
quote := byte(0)
escaped := false
for i := objStart; i < len(text); i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
if depth == 1 {
key, end, ok := parseJSONStringLiteral(text, i)
if !ok {
return 0, false
}
j := skipSpaces(text, end)
if j >= len(text) || text[j] != ':' {
i = end - 1
continue
}
j = skipSpaces(text, j+1)
if j >= len(text) {
return 0, false
}
if containsKey(keys, key) {
return j, true
}
i = j - 1
continue
}
quote = ch
continue
}
if ch == '{' {
depth++
continue
}
if ch == '}' {
depth--
if depth == 0 {
break
}
}
}
return 0, false
}
func parseJSONStringLiteral(text string, start int) (string, int, bool) {
if start < 0 || start >= len(text) || text[start] != '"' {
return "", 0, false
}
var b strings.Builder
escaped := false
for i := start + 1; i < len(text); i++ {
ch := text[i]
if escaped {
b.WriteByte(ch)
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == '"' {
return b.String(), i + 1, true
}
b.WriteByte(ch)
}
return "", 0, false
}
func containsKey(keys []string, value string) bool {
for _, k := range keys {
if k == value {
return true
}
}
return false
}
func skipSpaces(text string, i int) int {
for i < len(text) {
switch text[i] {
case ' ', '\t', '\n', '\r':
i++
default:
return i
}
}
return i
}

View File

@@ -0,0 +1,75 @@
package openai
import (
"strings"
"ds2api/internal/util"
)
type toolStreamSieveState struct {
pending strings.Builder
capture strings.Builder
capturing bool
recentTextTail string
disableDeltas bool
toolNameSent bool
toolName string
toolArgsStart int
toolArgsSent int
toolArgsString bool
toolArgsDone bool
}
type toolStreamEvent struct {
Content string
ToolCalls []util.ParsedToolCall
ToolCallDeltas []toolCallDelta
}
type toolCallDelta struct {
Index int
Name string
Arguments string
}
const toolSieveCaptureLimit = 8 * 1024
const toolSieveContextTailLimit = 256
func (s *toolStreamSieveState) resetIncrementalToolState() {
s.disableDeltas = false
s.toolNameSent = false
s.toolName = ""
s.toolArgsStart = -1
s.toolArgsSent = -1
s.toolArgsString = false
s.toolArgsDone = false
}
func (s *toolStreamSieveState) noteText(content string) {
if strings.TrimSpace(content) == "" {
return
}
s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
}
func appendTail(prev, next string, max int) string {
if max <= 0 {
return ""
}
combined := prev + next
if len(combined) <= max {
return combined
}
return combined[len(combined)-max:]
}
func looksLikeToolExampleContext(text string) bool {
return insideCodeFence(text)
}
func insideCodeFence(text string) bool {
if text == "" {
return false
}
return strings.Count(text, "```")%2 == 1
}

View File

@@ -0,0 +1,114 @@
package admin
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
"ds2api/internal/config"
)
func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
page := intFromQuery(r, "page", 1)
pageSize := intFromQuery(r, "page_size", 10)
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 1
}
if pageSize > 100 {
pageSize = 100
}
accounts := h.Store.Snapshot().Accounts
total := len(accounts)
reverseAccounts(accounts)
totalPages := 1
if total > 0 {
totalPages = (total + pageSize - 1) / pageSize
}
start := (page - 1) * pageSize
if start > total {
start = total
}
end := start + pageSize
if end > total {
end = total
}
items := make([]map[string]any, 0, end-start)
for _, acc := range accounts[start:end] {
token := strings.TrimSpace(acc.Token)
preview := ""
if token != "" {
if len(token) > 20 {
preview = token[:20] + "..."
} else {
preview = token
}
}
items = append(items, map[string]any{
"identifier": acc.Identifier(),
"email": acc.Email,
"mobile": acc.Mobile,
"has_password": acc.Password != "",
"has_token": token != "",
"token_preview": preview,
})
}
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
}
func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
acc := toAccount(req)
if acc.Identifier() == "" {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
return
}
err := h.Store.Update(func(c *config.Config) error {
for _, a := range c.Accounts {
if acc.Email != "" && a.Email == acc.Email {
return fmt.Errorf("邮箱已存在")
}
if acc.Mobile != "" && a.Mobile == acc.Mobile {
return fmt.Errorf("手机号已存在")
}
}
c.Accounts = append(c.Accounts, acc)
return nil
})
if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
h.Pool.Reset()
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
}
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
identifier := chi.URLParam(r, "identifier")
err := h.Store.Update(func(c *config.Config) error {
idx := -1
for i, a := range c.Accounts {
if accountMatchesIdentifier(a, identifier) {
idx = i
break
}
}
if idx < 0 {
return fmt.Errorf("账号不存在")
}
c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
return nil
})
if err != nil {
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
return
}
h.Pool.Reset()
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
}

View File

@@ -0,0 +1,7 @@
package admin
import "net/http"
func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, h.Pool.Status())
}

View File

@@ -11,119 +11,11 @@ import (
"sync"
"time"
"github.com/go-chi/chi/v5"
authn "ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/sse"
)
func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
page := intFromQuery(r, "page", 1)
pageSize := intFromQuery(r, "page_size", 10)
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 1
}
if pageSize > 100 {
pageSize = 100
}
accounts := h.Store.Snapshot().Accounts
total := len(accounts)
reverseAccounts(accounts)
totalPages := 1
if total > 0 {
totalPages = (total + pageSize - 1) / pageSize
}
start := (page - 1) * pageSize
if start > total {
start = total
}
end := start + pageSize
if end > total {
end = total
}
items := make([]map[string]any, 0, end-start)
for _, acc := range accounts[start:end] {
token := strings.TrimSpace(acc.Token)
preview := ""
if token != "" {
if len(token) > 20 {
preview = token[:20] + "..."
} else {
preview = token
}
}
items = append(items, map[string]any{
"identifier": acc.Identifier(),
"email": acc.Email,
"mobile": acc.Mobile,
"has_password": acc.Password != "",
"has_token": token != "",
"token_preview": preview,
})
}
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
}
func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
acc := toAccount(req)
if acc.Identifier() == "" {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
return
}
err := h.Store.Update(func(c *config.Config) error {
for _, a := range c.Accounts {
if acc.Email != "" && a.Email == acc.Email {
return fmt.Errorf("邮箱已存在")
}
if acc.Mobile != "" && a.Mobile == acc.Mobile {
return fmt.Errorf("手机号已存在")
}
}
c.Accounts = append(c.Accounts, acc)
return nil
})
if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
h.Pool.Reset()
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
}
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
identifier := chi.URLParam(r, "identifier")
err := h.Store.Update(func(c *config.Config) error {
idx := -1
for i, a := range c.Accounts {
if accountMatchesIdentifier(a, identifier) {
idx = i
break
}
}
if idx < 0 {
return fmt.Errorf("账号不存在")
}
c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
return nil
})
if err != nil {
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
return
}
h.Pool.Reset()
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
}
func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, h.Pool.Status())
}
func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)

View File

@@ -1,393 +0,0 @@
package admin
import (
"crypto/md5"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
"ds2api/internal/config"
)
func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
snap := h.Store.Snapshot()
safe := map[string]any{
"keys": snap.Keys,
"accounts": []map[string]any{},
"claude_mapping": func() map[string]string {
if len(snap.ClaudeMapping) > 0 {
return snap.ClaudeMapping
}
return snap.ClaudeModelMap
}(),
}
accounts := make([]map[string]any, 0, len(snap.Accounts))
for _, acc := range snap.Accounts {
token := strings.TrimSpace(acc.Token)
preview := ""
if token != "" {
if len(token) > 20 {
preview = token[:20] + "..."
} else {
preview = token
}
}
accounts = append(accounts, map[string]any{
"identifier": acc.Identifier(),
"email": acc.Email,
"mobile": acc.Mobile,
"has_password": strings.TrimSpace(acc.Password) != "",
"has_token": token != "",
"token_preview": preview,
})
}
safe["accounts"] = accounts
writeJSON(w, http.StatusOK, safe)
}
func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
return
}
old := h.Store.Snapshot()
err := h.Store.Update(func(c *config.Config) error {
if keys, ok := toStringSlice(req["keys"]); ok {
c.Keys = keys
}
if accountsRaw, ok := req["accounts"].([]any); ok {
existing := map[string]config.Account{}
for _, a := range old.Accounts {
existing[a.Identifier()] = a
}
accounts := make([]config.Account, 0, len(accountsRaw))
for _, item := range accountsRaw {
m, ok := item.(map[string]any)
if !ok {
continue
}
acc := toAccount(m)
id := acc.Identifier()
if prev, ok := existing[id]; ok {
if strings.TrimSpace(acc.Password) == "" {
acc.Password = prev.Password
}
if strings.TrimSpace(acc.Token) == "" {
acc.Token = prev.Token
}
}
accounts = append(accounts, acc)
}
c.Accounts = accounts
}
if m, ok := req["claude_mapping"].(map[string]any); ok {
newMap := map[string]string{}
for k, v := range m {
newMap[k] = fmt.Sprintf("%v", v)
}
c.ClaudeMapping = newMap
}
return nil
})
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
h.Pool.Reset()
writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "配置已更新"})
}
func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
key, _ := req["key"].(string)
key = strings.TrimSpace(key)
if key == "" {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"})
return
}
err := h.Store.Update(func(c *config.Config) error {
for _, k := range c.Keys {
if k == key {
return fmt.Errorf("Key 已存在")
}
}
c.Keys = append(c.Keys, key)
return nil
})
if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
}
func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
key := chi.URLParam(r, "key")
err := h.Store.Update(func(c *config.Config) error {
idx := -1
for i, k := range c.Keys {
if k == key {
idx = i
break
}
}
if idx < 0 {
return fmt.Errorf("Key 不存在")
}
c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...)
return nil
})
if err != nil {
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
return
}
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
}
func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "无效的 JSON 格式"})
return
}
importedKeys, importedAccounts := 0, 0
err := h.Store.Update(func(c *config.Config) error {
if keys, ok := req["keys"].([]any); ok {
existing := map[string]bool{}
for _, k := range c.Keys {
existing[k] = true
}
for _, k := range keys {
key := strings.TrimSpace(fmt.Sprintf("%v", k))
if key == "" || existing[key] {
continue
}
c.Keys = append(c.Keys, key)
existing[key] = true
importedKeys++
}
}
if accounts, ok := req["accounts"].([]any); ok {
existing := map[string]bool{}
for _, a := range c.Accounts {
existing[a.Identifier()] = true
}
for _, item := range accounts {
m, ok := item.(map[string]any)
if !ok {
continue
}
acc := toAccount(m)
id := acc.Identifier()
if id == "" || existing[id] {
continue
}
c.Accounts = append(c.Accounts, acc)
existing[id] = true
importedAccounts++
}
}
return nil
})
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
h.Pool.Reset()
writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts})
}
func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) {
h.configExport(w, nil)
}
func (h *Handler) configExport(w http.ResponseWriter, _ *http.Request) {
snap := h.Store.Snapshot()
jsonStr, b64, err := h.Store.ExportJSONAndBase64()
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
writeJSON(w, http.StatusOK, map[string]any{
"success": true,
"config": snap,
"json": jsonStr,
"base64": b64,
})
}
func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
return
}
mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode")))
if mode == "" {
mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode")))
}
if mode == "" {
mode = "merge"
}
if mode != "merge" && mode != "replace" {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"})
return
}
payload := req
if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 {
payload = raw
}
rawJSON, err := json.Marshal(payload)
if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"})
return
}
var incoming config.Config
if err := json.Unmarshal(rawJSON, &incoming); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
importedKeys, importedAccounts := 0, 0
err = h.Store.Update(func(c *config.Config) error {
next := c.Clone()
if mode == "replace" {
next = incoming.Clone()
next.VercelSyncHash = c.VercelSyncHash
next.VercelSyncTime = c.VercelSyncTime
importedKeys = len(next.Keys)
importedAccounts = len(next.Accounts)
} else {
existingKeys := map[string]struct{}{}
for _, k := range next.Keys {
existingKeys[k] = struct{}{}
}
for _, k := range incoming.Keys {
key := strings.TrimSpace(k)
if key == "" {
continue
}
if _, ok := existingKeys[key]; ok {
continue
}
existingKeys[key] = struct{}{}
next.Keys = append(next.Keys, key)
importedKeys++
}
existingAccounts := map[string]struct{}{}
for _, acc := range next.Accounts {
existingAccounts[acc.Identifier()] = struct{}{}
}
for _, acc := range incoming.Accounts {
id := acc.Identifier()
if id == "" {
continue
}
if _, ok := existingAccounts[id]; ok {
continue
}
existingAccounts[id] = struct{}{}
next.Accounts = append(next.Accounts, acc)
importedAccounts++
}
if len(incoming.ClaudeMapping) > 0 {
if next.ClaudeMapping == nil {
next.ClaudeMapping = map[string]string{}
}
for k, v := range incoming.ClaudeMapping {
next.ClaudeMapping[k] = v
}
}
if len(incoming.ClaudeModelMap) > 0 {
if next.ClaudeModelMap == nil {
next.ClaudeModelMap = map[string]string{}
}
for k, v := range incoming.ClaudeModelMap {
next.ClaudeModelMap[k] = v
}
}
if len(incoming.ModelAliases) > 0 {
if next.ModelAliases == nil {
next.ModelAliases = map[string]string{}
}
for k, v := range incoming.ModelAliases {
next.ModelAliases[k] = v
}
}
if strings.TrimSpace(incoming.Toolcall.Mode) != "" {
next.Toolcall.Mode = incoming.Toolcall.Mode
}
if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" {
next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence
}
if incoming.Responses.StoreTTLSeconds > 0 {
next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds
}
if strings.TrimSpace(incoming.Embeddings.Provider) != "" {
next.Embeddings.Provider = incoming.Embeddings.Provider
}
if strings.TrimSpace(incoming.Admin.PasswordHash) != "" {
next.Admin.PasswordHash = incoming.Admin.PasswordHash
}
if incoming.Admin.JWTExpireHours > 0 {
next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours
}
if incoming.Admin.JWTValidAfterUnix > 0 {
next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix
}
if incoming.Runtime.AccountMaxInflight > 0 {
next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight
}
if incoming.Runtime.AccountMaxQueue > 0 {
next.Runtime.AccountMaxQueue = incoming.Runtime.AccountMaxQueue
}
if incoming.Runtime.GlobalMaxInflight > 0 {
next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight
}
}
normalizeSettingsConfig(&next)
if err := validateSettingsConfig(next); err != nil {
return newRequestError(err.Error())
}
*c = next
return nil
})
if err != nil {
if detail, ok := requestErrorDetail(err); ok {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail})
return
}
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
h.Pool.Reset()
writeJSON(w, http.StatusOK, map[string]any{
"success": true,
"mode": mode,
"imported_keys": importedKeys,
"imported_accounts": importedAccounts,
"message": "config imported",
})
}
func (h *Handler) computeSyncHash() string {
snap := h.Store.Snapshot().Clone()
snap.VercelSyncHash = ""
snap.VercelSyncTime = 0
b, _ := json.Marshal(snap)
sum := md5.Sum(b)
return fmt.Sprintf("%x", sum)
}

View File

@@ -0,0 +1,182 @@
package admin
import (
"crypto/md5"
"encoding/json"
"fmt"
"net/http"
"strings"
"ds2api/internal/config"
)
func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
return
}
mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode")))
if mode == "" {
mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode")))
}
if mode == "" {
mode = "merge"
}
if mode != "merge" && mode != "replace" {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"})
return
}
payload := req
if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 {
payload = raw
}
rawJSON, err := json.Marshal(payload)
if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"})
return
}
var incoming config.Config
if err := json.Unmarshal(rawJSON, &incoming); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
importedKeys, importedAccounts := 0, 0
err = h.Store.Update(func(c *config.Config) error {
next := c.Clone()
if mode == "replace" {
next = incoming.Clone()
next.VercelSyncHash = c.VercelSyncHash
next.VercelSyncTime = c.VercelSyncTime
importedKeys = len(next.Keys)
importedAccounts = len(next.Accounts)
} else {
existingKeys := map[string]struct{}{}
for _, k := range next.Keys {
existingKeys[k] = struct{}{}
}
for _, k := range incoming.Keys {
key := strings.TrimSpace(k)
if key == "" {
continue
}
if _, ok := existingKeys[key]; ok {
continue
}
existingKeys[key] = struct{}{}
next.Keys = append(next.Keys, key)
importedKeys++
}
existingAccounts := map[string]struct{}{}
for _, acc := range next.Accounts {
existingAccounts[acc.Identifier()] = struct{}{}
}
for _, acc := range incoming.Accounts {
id := acc.Identifier()
if id == "" {
continue
}
if _, ok := existingAccounts[id]; ok {
continue
}
existingAccounts[id] = struct{}{}
next.Accounts = append(next.Accounts, acc)
importedAccounts++
}
if len(incoming.ClaudeMapping) > 0 {
if next.ClaudeMapping == nil {
next.ClaudeMapping = map[string]string{}
}
for k, v := range incoming.ClaudeMapping {
next.ClaudeMapping[k] = v
}
}
if len(incoming.ClaudeModelMap) > 0 {
if next.ClaudeModelMap == nil {
next.ClaudeModelMap = map[string]string{}
}
for k, v := range incoming.ClaudeModelMap {
next.ClaudeModelMap[k] = v
}
}
if len(incoming.ModelAliases) > 0 {
if next.ModelAliases == nil {
next.ModelAliases = map[string]string{}
}
for k, v := range incoming.ModelAliases {
next.ModelAliases[k] = v
}
}
if strings.TrimSpace(incoming.Toolcall.Mode) != "" {
next.Toolcall.Mode = incoming.Toolcall.Mode
}
if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" {
next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence
}
if incoming.Responses.StoreTTLSeconds > 0 {
next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds
}
if strings.TrimSpace(incoming.Embeddings.Provider) != "" {
next.Embeddings.Provider = incoming.Embeddings.Provider
}
if strings.TrimSpace(incoming.Admin.PasswordHash) != "" {
next.Admin.PasswordHash = incoming.Admin.PasswordHash
}
if incoming.Admin.JWTExpireHours > 0 {
next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours
}
if incoming.Admin.JWTValidAfterUnix > 0 {
next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix
}
if incoming.Runtime.AccountMaxInflight > 0 {
next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight
}
if incoming.Runtime.AccountMaxQueue > 0 {
next.Runtime.AccountMaxQueue = incoming.Runtime.AccountMaxQueue
}
if incoming.Runtime.GlobalMaxInflight > 0 {
next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight
}
}
normalizeSettingsConfig(&next)
if err := validateSettingsConfig(next); err != nil {
return newRequestError(err.Error())
}
*c = next
return nil
})
if err != nil {
if detail, ok := requestErrorDetail(err); ok {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail})
return
}
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
h.Pool.Reset()
writeJSON(w, http.StatusOK, map[string]any{
"success": true,
"mode": mode,
"imported_keys": importedKeys,
"imported_accounts": importedAccounts,
"message": "config imported",
})
}
func (h *Handler) computeSyncHash() string {
snap := h.Store.Snapshot().Clone()
snap.VercelSyncHash = ""
snap.VercelSyncTime = 0
b, _ := json.Marshal(snap)
sum := md5.Sum(b)
return fmt.Sprintf("%x", sum)
}

View File

@@ -0,0 +1,61 @@
package admin
import (
"net/http"
"strings"
)
func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
snap := h.Store.Snapshot()
safe := map[string]any{
"keys": snap.Keys,
"accounts": []map[string]any{},
"claude_mapping": func() map[string]string {
if len(snap.ClaudeMapping) > 0 {
return snap.ClaudeMapping
}
return snap.ClaudeModelMap
}(),
}
accounts := make([]map[string]any, 0, len(snap.Accounts))
for _, acc := range snap.Accounts {
token := strings.TrimSpace(acc.Token)
preview := ""
if token != "" {
if len(token) > 20 {
preview = token[:20] + "..."
} else {
preview = token
}
}
accounts = append(accounts, map[string]any{
"identifier": acc.Identifier(),
"email": acc.Email,
"mobile": acc.Mobile,
"has_password": strings.TrimSpace(acc.Password) != "",
"has_token": token != "",
"token_preview": preview,
})
}
safe["accounts"] = accounts
writeJSON(w, http.StatusOK, safe)
}
func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) {
h.configExport(w, nil)
}
func (h *Handler) configExport(w http.ResponseWriter, _ *http.Request) {
snap := h.Store.Snapshot()
jsonStr, b64, err := h.Store.ExportJSONAndBase64()
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
writeJSON(w, http.StatusOK, map[string]any{
"success": true,
"config": snap,
"json": jsonStr,
"base64": b64,
})
}

View File

@@ -0,0 +1,166 @@
package admin
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
"ds2api/internal/config"
)
func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
return
}
old := h.Store.Snapshot()
err := h.Store.Update(func(c *config.Config) error {
if keys, ok := toStringSlice(req["keys"]); ok {
c.Keys = keys
}
if accountsRaw, ok := req["accounts"].([]any); ok {
existing := map[string]config.Account{}
for _, a := range old.Accounts {
existing[a.Identifier()] = a
}
accounts := make([]config.Account, 0, len(accountsRaw))
for _, item := range accountsRaw {
m, ok := item.(map[string]any)
if !ok {
continue
}
acc := toAccount(m)
id := acc.Identifier()
if prev, ok := existing[id]; ok {
if strings.TrimSpace(acc.Password) == "" {
acc.Password = prev.Password
}
if strings.TrimSpace(acc.Token) == "" {
acc.Token = prev.Token
}
}
accounts = append(accounts, acc)
}
c.Accounts = accounts
}
if m, ok := req["claude_mapping"].(map[string]any); ok {
newMap := map[string]string{}
for k, v := range m {
newMap[k] = fmt.Sprintf("%v", v)
}
c.ClaudeMapping = newMap
}
return nil
})
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
h.Pool.Reset()
writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "配置已更新"})
}
func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
key, _ := req["key"].(string)
key = strings.TrimSpace(key)
if key == "" {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"})
return
}
err := h.Store.Update(func(c *config.Config) error {
for _, k := range c.Keys {
if k == key {
return fmt.Errorf("Key 已存在")
}
}
c.Keys = append(c.Keys, key)
return nil
})
if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
}
func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
key := chi.URLParam(r, "key")
err := h.Store.Update(func(c *config.Config) error {
idx := -1
for i, k := range c.Keys {
if k == key {
idx = i
break
}
}
if idx < 0 {
return fmt.Errorf("Key 不存在")
}
c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...)
return nil
})
if err != nil {
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
return
}
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
}
func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "无效的 JSON 格式"})
return
}
importedKeys, importedAccounts := 0, 0
err := h.Store.Update(func(c *config.Config) error {
if keys, ok := req["keys"].([]any); ok {
existing := map[string]bool{}
for _, k := range c.Keys {
existing[k] = true
}
for _, k := range keys {
key := strings.TrimSpace(fmt.Sprintf("%v", k))
if key == "" || existing[key] {
continue
}
c.Keys = append(c.Keys, key)
existing[key] = true
importedKeys++
}
}
if accounts, ok := req["accounts"].([]any); ok {
existing := map[string]bool{}
for _, a := range c.Accounts {
existing[a.Identifier()] = true
}
for _, item := range accounts {
m, ok := item.(map[string]any)
if !ok {
continue
}
acc := toAccount(m)
id := acc.Identifier()
if id == "" || existing[id] {
continue
}
c.Accounts = append(c.Accounts, acc)
existing[id] = true
importedAccounts++
}
}
return nil
})
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
h.Pool.Reset()
writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts})
}

View File

@@ -1,321 +0,0 @@
package admin
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
authn "ds2api/internal/auth"
"ds2api/internal/config"
)
func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
snap := h.Store.Snapshot()
recommended := defaultRuntimeRecommended(len(snap.Accounts), h.Store.RuntimeAccountMaxInflight())
needsSync := config.IsVercel() && snap.VercelSyncHash != "" && snap.VercelSyncHash != h.computeSyncHash()
writeJSON(w, http.StatusOK, map[string]any{
"success": true,
"admin": map[string]any{
"has_password_hash": strings.TrimSpace(snap.Admin.PasswordHash) != "",
"jwt_expire_hours": h.Store.AdminJWTExpireHours(),
"jwt_valid_after_unix": snap.Admin.JWTValidAfterUnix,
"default_password_warning": authn.UsingDefaultAdminKey(h.Store),
},
"runtime": map[string]any{
"account_max_inflight": h.Store.RuntimeAccountMaxInflight(),
"account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended),
"global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended),
},
"toolcall": snap.Toolcall,
"responses": snap.Responses,
"embeddings": snap.Embeddings,
"claude_mapping": settingsClaudeMapping(snap),
"model_aliases": snap.ModelAliases,
"env_backed": h.Store.IsEnvBacked(),
"needs_vercel_sync": needsSync,
})
}
func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
return
}
adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
if runtimeCfg != nil {
if err := validateMergedRuntimeSettings(h.Store.Snapshot().Runtime, runtimeCfg); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
}
if err := h.Store.Update(func(c *config.Config) error {
if adminCfg != nil {
if adminCfg.JWTExpireHours > 0 {
c.Admin.JWTExpireHours = adminCfg.JWTExpireHours
}
}
if runtimeCfg != nil {
if runtimeCfg.AccountMaxInflight > 0 {
c.Runtime.AccountMaxInflight = runtimeCfg.AccountMaxInflight
}
if runtimeCfg.AccountMaxQueue > 0 {
c.Runtime.AccountMaxQueue = runtimeCfg.AccountMaxQueue
}
if runtimeCfg.GlobalMaxInflight > 0 {
c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight
}
}
if toolcallCfg != nil {
if strings.TrimSpace(toolcallCfg.Mode) != "" {
c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode)
}
if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" {
c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence)
}
}
if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 {
c.Responses.StoreTTLSeconds = responsesCfg.StoreTTLSeconds
}
if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" {
c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider)
}
if claudeMap != nil {
c.ClaudeMapping = claudeMap
c.ClaudeModelMap = nil
}
if aliasMap != nil {
c.ModelAliases = aliasMap
}
return nil
}); err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
h.applyRuntimeSettings()
needsSync := config.IsVercel() || h.Store.IsEnvBacked()
writeJSON(w, http.StatusOK, map[string]any{
"success": true,
"message": "settings updated and hot reloaded",
"env_backed": h.Store.IsEnvBacked(),
"needs_vercel_sync": needsSync,
"manual_sync_message": "配置已保存。Vercel 部署请在 Vercel Sync 页面手动同步。",
})
}
func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *config.RuntimeConfig) error {
merged := current
if incoming != nil {
if incoming.AccountMaxInflight > 0 {
merged.AccountMaxInflight = incoming.AccountMaxInflight
}
if incoming.AccountMaxQueue > 0 {
merged.AccountMaxQueue = incoming.AccountMaxQueue
}
if incoming.GlobalMaxInflight > 0 {
merged.GlobalMaxInflight = incoming.GlobalMaxInflight
}
}
return validateRuntimeSettings(merged)
}
func (h *Handler) updateSettingsPassword(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
return
}
newPassword := strings.TrimSpace(fieldString(req, "new_password"))
if newPassword == "" {
newPassword = strings.TrimSpace(fieldString(req, "password"))
}
if len(newPassword) < 4 {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "new password must be at least 4 characters"})
return
}
now := time.Now().Unix()
hash := authn.HashAdminPassword(newPassword)
if err := h.Store.Update(func(c *config.Config) error {
c.Admin.PasswordHash = hash
c.Admin.JWTValidAfterUnix = now
return nil
}); err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
writeJSON(w, http.StatusOK, map[string]any{
"success": true,
"message": "password updated",
"force_relogin": true,
"jwt_valid_after_unix": now,
})
}
func (h *Handler) applyRuntimeSettings() {
if h == nil || h.Store == nil || h.Pool == nil {
return
}
accountCount := len(h.Store.Accounts())
maxPer := h.Store.RuntimeAccountMaxInflight()
recommended := defaultRuntimeRecommended(accountCount, maxPer)
maxQueue := h.Store.RuntimeAccountMaxQueue(recommended)
global := h.Store.RuntimeGlobalMaxInflight(recommended)
h.Pool.ApplyRuntimeLimits(maxPer, maxQueue, global)
}
func defaultRuntimeRecommended(accountCount, maxPer int) int {
if maxPer <= 0 {
maxPer = 1
}
if accountCount <= 0 {
return maxPer
}
return accountCount * maxPer
}
func settingsClaudeMapping(c config.Config) map[string]string {
if len(c.ClaudeMapping) > 0 {
return c.ClaudeMapping
}
if len(c.ClaudeModelMap) > 0 {
return c.ClaudeModelMap
}
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
}
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) {
var (
adminCfg *config.AdminConfig
runtimeCfg *config.RuntimeConfig
toolcallCfg *config.ToolcallConfig
respCfg *config.ResponsesConfig
embCfg *config.EmbeddingsConfig
claudeMap map[string]string
aliasMap map[string]string
)
if raw, ok := req["admin"].(map[string]any); ok {
cfg := &config.AdminConfig{}
if v, exists := raw["jwt_expire_hours"]; exists {
n := intFrom(v)
if n < 1 || n > 720 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
}
cfg.JWTExpireHours = n
}
adminCfg = cfg
}
if raw, ok := req["runtime"].(map[string]any); ok {
cfg := &config.RuntimeConfig{}
if v, exists := raw["account_max_inflight"]; exists {
n := intFrom(v)
if n < 1 || n > 256 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
}
cfg.AccountMaxInflight = n
}
if v, exists := raw["account_max_queue"]; exists {
n := intFrom(v)
if n < 1 || n > 200000 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
}
cfg.AccountMaxQueue = n
}
if v, exists := raw["global_max_inflight"]; exists {
n := intFrom(v)
if n < 1 || n > 200000 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
}
cfg.GlobalMaxInflight = n
}
if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
}
runtimeCfg = cfg
}
if raw, ok := req["toolcall"].(map[string]any); ok {
cfg := &config.ToolcallConfig{}
if v, exists := raw["mode"]; exists {
mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
switch mode {
case "feature_match", "off":
cfg.Mode = mode
default:
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
}
}
if v, exists := raw["early_emit_confidence"]; exists {
level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
switch level {
case "high", "low", "off":
cfg.EarlyEmitConfidence = level
default:
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
}
}
toolcallCfg = cfg
}
if raw, ok := req["responses"].(map[string]any); ok {
cfg := &config.ResponsesConfig{}
if v, exists := raw["store_ttl_seconds"]; exists {
n := intFrom(v)
if n < 30 || n > 86400 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
}
cfg.StoreTTLSeconds = n
}
respCfg = cfg
}
if raw, ok := req["embeddings"].(map[string]any); ok {
cfg := &config.EmbeddingsConfig{}
if v, exists := raw["provider"]; exists {
p := strings.TrimSpace(fmt.Sprintf("%v", v))
if p == "" {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty")
}
cfg.Provider = p
}
embCfg = cfg
}
if raw, ok := req["claude_mapping"].(map[string]any); ok {
claudeMap = map[string]string{}
for k, v := range raw {
key := strings.TrimSpace(k)
val := strings.TrimSpace(fmt.Sprintf("%v", v))
if key == "" || val == "" {
continue
}
claudeMap[key] = val
}
}
if raw, ok := req["model_aliases"].(map[string]any); ok {
aliasMap = map[string]string{}
for k, v := range raw {
key := strings.TrimSpace(k)
val := strings.TrimSpace(fmt.Sprintf("%v", v))
if key == "" || val == "" {
continue
}
aliasMap[key] = val
}
}
return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil
}

View File

@@ -0,0 +1,134 @@
package admin
import (
"fmt"
"strings"
"ds2api/internal/config"
)
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) {
var (
adminCfg *config.AdminConfig
runtimeCfg *config.RuntimeConfig
toolcallCfg *config.ToolcallConfig
respCfg *config.ResponsesConfig
embCfg *config.EmbeddingsConfig
claudeMap map[string]string
aliasMap map[string]string
)
if raw, ok := req["admin"].(map[string]any); ok {
cfg := &config.AdminConfig{}
if v, exists := raw["jwt_expire_hours"]; exists {
n := intFrom(v)
if n < 1 || n > 720 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
}
cfg.JWTExpireHours = n
}
adminCfg = cfg
}
if raw, ok := req["runtime"].(map[string]any); ok {
cfg := &config.RuntimeConfig{}
if v, exists := raw["account_max_inflight"]; exists {
n := intFrom(v)
if n < 1 || n > 256 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
}
cfg.AccountMaxInflight = n
}
if v, exists := raw["account_max_queue"]; exists {
n := intFrom(v)
if n < 1 || n > 200000 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
}
cfg.AccountMaxQueue = n
}
if v, exists := raw["global_max_inflight"]; exists {
n := intFrom(v)
if n < 1 || n > 200000 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
}
cfg.GlobalMaxInflight = n
}
if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
}
runtimeCfg = cfg
}
if raw, ok := req["toolcall"].(map[string]any); ok {
cfg := &config.ToolcallConfig{}
if v, exists := raw["mode"]; exists {
mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
switch mode {
case "feature_match", "off":
cfg.Mode = mode
default:
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
}
}
if v, exists := raw["early_emit_confidence"]; exists {
level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
switch level {
case "high", "low", "off":
cfg.EarlyEmitConfidence = level
default:
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
}
}
toolcallCfg = cfg
}
if raw, ok := req["responses"].(map[string]any); ok {
cfg := &config.ResponsesConfig{}
if v, exists := raw["store_ttl_seconds"]; exists {
n := intFrom(v)
if n < 30 || n > 86400 {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
}
cfg.StoreTTLSeconds = n
}
respCfg = cfg
}
if raw, ok := req["embeddings"].(map[string]any); ok {
cfg := &config.EmbeddingsConfig{}
if v, exists := raw["provider"]; exists {
p := strings.TrimSpace(fmt.Sprintf("%v", v))
if p == "" {
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty")
}
cfg.Provider = p
}
embCfg = cfg
}
if raw, ok := req["claude_mapping"].(map[string]any); ok {
claudeMap = map[string]string{}
for k, v := range raw {
key := strings.TrimSpace(k)
val := strings.TrimSpace(fmt.Sprintf("%v", v))
if key == "" || val == "" {
continue
}
claudeMap[key] = val
}
}
if raw, ok := req["model_aliases"].(map[string]any); ok {
aliasMap = map[string]string{}
for k, v := range raw {
key := strings.TrimSpace(k)
val := strings.TrimSpace(fmt.Sprintf("%v", v))
if key == "" || val == "" {
continue
}
aliasMap[key] = val
}
}
return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil
}

View File

@@ -0,0 +1,36 @@
package admin
import (
"net/http"
"strings"
authn "ds2api/internal/auth"
"ds2api/internal/config"
)
func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
snap := h.Store.Snapshot()
recommended := defaultRuntimeRecommended(len(snap.Accounts), h.Store.RuntimeAccountMaxInflight())
needsSync := config.IsVercel() && snap.VercelSyncHash != "" && snap.VercelSyncHash != h.computeSyncHash()
writeJSON(w, http.StatusOK, map[string]any{
"success": true,
"admin": map[string]any{
"has_password_hash": strings.TrimSpace(snap.Admin.PasswordHash) != "",
"jwt_expire_hours": h.Store.AdminJWTExpireHours(),
"jwt_valid_after_unix": snap.Admin.JWTValidAfterUnix,
"default_password_warning": authn.UsingDefaultAdminKey(h.Store),
},
"runtime": map[string]any{
"account_max_inflight": h.Store.RuntimeAccountMaxInflight(),
"account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended),
"global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended),
},
"toolcall": snap.Toolcall,
"responses": snap.Responses,
"embeddings": snap.Embeddings,
"claude_mapping": settingsClaudeMapping(snap),
"model_aliases": snap.ModelAliases,
"env_backed": h.Store.IsEnvBacked(),
"needs_vercel_sync": needsSync,
})
}

View File

@@ -0,0 +1,51 @@
package admin
import "ds2api/internal/config"
func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *config.RuntimeConfig) error {
merged := current
if incoming != nil {
if incoming.AccountMaxInflight > 0 {
merged.AccountMaxInflight = incoming.AccountMaxInflight
}
if incoming.AccountMaxQueue > 0 {
merged.AccountMaxQueue = incoming.AccountMaxQueue
}
if incoming.GlobalMaxInflight > 0 {
merged.GlobalMaxInflight = incoming.GlobalMaxInflight
}
}
return validateRuntimeSettings(merged)
}
func (h *Handler) applyRuntimeSettings() {
if h == nil || h.Store == nil || h.Pool == nil {
return
}
accountCount := len(h.Store.Accounts())
maxPer := h.Store.RuntimeAccountMaxInflight()
recommended := defaultRuntimeRecommended(accountCount, maxPer)
maxQueue := h.Store.RuntimeAccountMaxQueue(recommended)
global := h.Store.RuntimeGlobalMaxInflight(recommended)
h.Pool.ApplyRuntimeLimits(maxPer, maxQueue, global)
}
func defaultRuntimeRecommended(accountCount, maxPer int) int {
if maxPer <= 0 {
maxPer = 1
}
if accountCount <= 0 {
return maxPer
}
return accountCount * maxPer
}
func settingsClaudeMapping(c config.Config) map[string]string {
if len(c.ClaudeMapping) > 0 {
return c.ClaudeMapping
}
if len(c.ClaudeModelMap) > 0 {
return c.ClaudeModelMap
}
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
}

View File

@@ -0,0 +1,119 @@
package admin
import (
"encoding/json"
"net/http"
"strings"
"time"
authn "ds2api/internal/auth"
"ds2api/internal/config"
)
func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
return
}
adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
if runtimeCfg != nil {
if err := validateMergedRuntimeSettings(h.Store.Snapshot().Runtime, runtimeCfg); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
}
if err := h.Store.Update(func(c *config.Config) error {
if adminCfg != nil {
if adminCfg.JWTExpireHours > 0 {
c.Admin.JWTExpireHours = adminCfg.JWTExpireHours
}
}
if runtimeCfg != nil {
if runtimeCfg.AccountMaxInflight > 0 {
c.Runtime.AccountMaxInflight = runtimeCfg.AccountMaxInflight
}
if runtimeCfg.AccountMaxQueue > 0 {
c.Runtime.AccountMaxQueue = runtimeCfg.AccountMaxQueue
}
if runtimeCfg.GlobalMaxInflight > 0 {
c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight
}
}
if toolcallCfg != nil {
if strings.TrimSpace(toolcallCfg.Mode) != "" {
c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode)
}
if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" {
c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence)
}
}
if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 {
c.Responses.StoreTTLSeconds = responsesCfg.StoreTTLSeconds
}
if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" {
c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider)
}
if claudeMap != nil {
c.ClaudeMapping = claudeMap
c.ClaudeModelMap = nil
}
if aliasMap != nil {
c.ModelAliases = aliasMap
}
return nil
}); err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
h.applyRuntimeSettings()
needsSync := config.IsVercel() || h.Store.IsEnvBacked()
writeJSON(w, http.StatusOK, map[string]any{
"success": true,
"message": "settings updated and hot reloaded",
"env_backed": h.Store.IsEnvBacked(),
"needs_vercel_sync": needsSync,
"manual_sync_message": "配置已保存。Vercel 部署请在 Vercel Sync 页面手动同步。",
})
}
func (h *Handler) updateSettingsPassword(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
return
}
newPassword := strings.TrimSpace(fieldString(req, "new_password"))
if newPassword == "" {
newPassword = strings.TrimSpace(fieldString(req, "password"))
}
if len(newPassword) < 4 {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "new password must be at least 4 characters"})
return
}
now := time.Now().Unix()
hash := authn.HashAdminPassword(newPassword)
if err := h.Store.Update(func(c *config.Config) error {
c.Admin.PasswordHash = hash
c.Admin.JWTValidAfterUnix = now
return nil
}); err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
return
}
writeJSON(w, http.StatusOK, map[string]any{
"success": true,
"message": "password updated",
"force_relogin": true,
"jwt_valid_after_unix": now,
})
}

View File

@@ -0,0 +1,24 @@
package config
import (
"crypto/sha256"
"encoding/hex"
"strings"
)
func (a Account) Identifier() string {
if strings.TrimSpace(a.Email) != "" {
return strings.TrimSpace(a.Email)
}
if strings.TrimSpace(a.Mobile) != "" {
return strings.TrimSpace(a.Mobile)
}
// Backward compatibility: old configs may contain token-only accounts.
// Use a stable non-sensitive synthetic id so they can still join the pool.
token := strings.TrimSpace(a.Token)
if token == "" {
return ""
}
sum := sha256.Sum256([]byte(token))
return "token:" + hex.EncodeToString(sum[:8])
}

241
internal/config/codec.go Normal file
View File

@@ -0,0 +1,241 @@
package config
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"slices"
"strings"
)
func (c Config) MarshalJSON() ([]byte, error) {
m := map[string]any{}
for k, v := range c.AdditionalFields {
m[k] = v
}
if len(c.Keys) > 0 {
m["keys"] = c.Keys
}
if len(c.Accounts) > 0 {
m["accounts"] = c.Accounts
}
if len(c.ClaudeMapping) > 0 {
m["claude_mapping"] = c.ClaudeMapping
}
if len(c.ClaudeModelMap) > 0 {
m["claude_model_mapping"] = c.ClaudeModelMap
}
if len(c.ModelAliases) > 0 {
m["model_aliases"] = c.ModelAliases
}
if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 {
m["admin"] = c.Admin
}
if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 {
m["runtime"] = c.Runtime
}
if c.Compat.WideInputStrictOutput != nil {
m["compat"] = c.Compat
}
if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" {
m["toolcall"] = c.Toolcall
}
if c.Responses.StoreTTLSeconds > 0 {
m["responses"] = c.Responses
}
if strings.TrimSpace(c.Embeddings.Provider) != "" {
m["embeddings"] = c.Embeddings
}
if c.VercelSyncHash != "" {
m["_vercel_sync_hash"] = c.VercelSyncHash
}
if c.VercelSyncTime != 0 {
m["_vercel_sync_time"] = c.VercelSyncTime
}
return json.Marshal(m)
}
func (c *Config) UnmarshalJSON(b []byte) error {
raw := map[string]json.RawMessage{}
if err := json.Unmarshal(b, &raw); err != nil {
return err
}
c.AdditionalFields = map[string]any{}
for k, v := range raw {
switch k {
case "keys":
if err := json.Unmarshal(v, &c.Keys); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "accounts":
if err := json.Unmarshal(v, &c.Accounts); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "claude_mapping":
if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "claude_model_mapping":
if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "model_aliases":
if err := json.Unmarshal(v, &c.ModelAliases); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "admin":
if err := json.Unmarshal(v, &c.Admin); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "runtime":
if err := json.Unmarshal(v, &c.Runtime); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "compat":
if err := json.Unmarshal(v, &c.Compat); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "toolcall":
if err := json.Unmarshal(v, &c.Toolcall); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "responses":
if err := json.Unmarshal(v, &c.Responses); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "embeddings":
if err := json.Unmarshal(v, &c.Embeddings); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "_vercel_sync_hash":
if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "_vercel_sync_time":
if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
default:
var anyVal any
if err := json.Unmarshal(v, &anyVal); err == nil {
c.AdditionalFields[k] = anyVal
}
}
}
return nil
}
func (c Config) Clone() Config {
clone := Config{
Keys: slices.Clone(c.Keys),
Accounts: slices.Clone(c.Accounts),
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
ModelAliases: cloneStringMap(c.ModelAliases),
Admin: c.Admin,
Runtime: c.Runtime,
Compat: CompatConfig{
WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
},
Toolcall: c.Toolcall,
Responses: c.Responses,
Embeddings: c.Embeddings,
VercelSyncHash: c.VercelSyncHash,
VercelSyncTime: c.VercelSyncTime,
AdditionalFields: map[string]any{},
}
for k, v := range c.AdditionalFields {
clone.AdditionalFields[k] = v
}
return clone
}
func cloneStringMap(in map[string]string) map[string]string {
if len(in) == 0 {
return nil
}
out := make(map[string]string, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func cloneBoolPtr(in *bool) *bool {
if in == nil {
return nil
}
v := *in
return &v
}
func parseConfigString(raw string) (Config, error) {
var cfg Config
candidates := []string{raw}
if normalized := normalizeConfigInput(raw); normalized != raw {
candidates = append(candidates, normalized)
}
for _, candidate := range candidates {
if err := json.Unmarshal([]byte(candidate), &cfg); err == nil {
return cfg, nil
}
}
base64Input := candidates[len(candidates)-1]
decoded, err := decodeConfigBase64(base64Input)
if err != nil {
return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON: %w", err)
}
if err := json.Unmarshal(decoded, &cfg); err != nil {
return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON decoded JSON: %w", err)
}
return cfg, nil
}
func normalizeConfigInput(raw string) string {
normalized := strings.TrimSpace(raw)
if normalized == "" {
return normalized
}
for {
changed := false
if len(normalized) >= 2 {
first := normalized[0]
last := normalized[len(normalized)-1]
if (first == '"' && last == '"') || (first == '\'' && last == '\'') {
normalized = strings.TrimSpace(normalized[1 : len(normalized)-1])
changed = true
}
}
if strings.HasPrefix(strings.ToLower(normalized), "base64:") {
normalized = strings.TrimSpace(normalized[len("base64:"):])
changed = true
}
if !changed {
break
}
}
return strings.TrimSpace(normalized)
}
func decodeConfigBase64(raw string) ([]byte, error) {
encodings := []*base64.Encoding{
base64.StdEncoding,
base64.RawStdEncoding,
base64.URLEncoding,
base64.RawURLEncoding,
}
var lastErr error
for _, enc := range encodings {
decoded, err := enc.DecodeString(raw)
if err == nil {
return decoded, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, errors.New("base64 decode failed")
}

View File

@@ -1,63 +1,5 @@
package config
import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
)
var Logger = newLogger()
func newLogger() *slog.Logger {
level := new(slog.LevelVar)
switch strings.ToUpper(strings.TrimSpace(os.Getenv("LOG_LEVEL"))) {
case "DEBUG":
level.Set(slog.LevelDebug)
case "WARN":
level.Set(slog.LevelWarn)
case "ERROR":
level.Set(slog.LevelError)
default:
level.Set(slog.LevelInfo)
}
h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})
return slog.New(h)
}
type Account struct {
Email string `json:"email,omitempty"`
Mobile string `json:"mobile,omitempty"`
Password string `json:"password,omitempty"`
Token string `json:"token,omitempty"`
}
func (a Account) Identifier() string {
if strings.TrimSpace(a.Email) != "" {
return strings.TrimSpace(a.Email)
}
if strings.TrimSpace(a.Mobile) != "" {
return strings.TrimSpace(a.Mobile)
}
// Backward compatibility: old configs may contain token-only accounts.
// Use a stable non-sensitive synthetic id so they can still join the pool.
token := strings.TrimSpace(a.Token)
if token == "" {
return ""
}
sum := sha256.Sum256([]byte(token))
return "token:" + hex.EncodeToString(sum[:8])
}
type Config struct {
Keys []string `json:"keys,omitempty"`
Accounts []Account `json:"accounts,omitempty"`
@@ -75,6 +17,13 @@ type Config struct {
AdditionalFields map[string]any `json:"-"`
}
type Account struct {
Email string `json:"email,omitempty"`
Mobile string `json:"mobile,omitempty"`
Password string `json:"password,omitempty"`
Token string `json:"token,omitempty"`
}
type CompatConfig struct {
WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"`
}
@@ -103,641 +52,3 @@ type ResponsesConfig struct {
type EmbeddingsConfig struct {
Provider string `json:"provider,omitempty"`
}
func (c Config) MarshalJSON() ([]byte, error) {
m := map[string]any{}
for k, v := range c.AdditionalFields {
m[k] = v
}
if len(c.Keys) > 0 {
m["keys"] = c.Keys
}
if len(c.Accounts) > 0 {
m["accounts"] = c.Accounts
}
if len(c.ClaudeMapping) > 0 {
m["claude_mapping"] = c.ClaudeMapping
}
if len(c.ClaudeModelMap) > 0 {
m["claude_model_mapping"] = c.ClaudeModelMap
}
if len(c.ModelAliases) > 0 {
m["model_aliases"] = c.ModelAliases
}
if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 {
m["admin"] = c.Admin
}
if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 {
m["runtime"] = c.Runtime
}
if c.Compat.WideInputStrictOutput != nil {
m["compat"] = c.Compat
}
if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" {
m["toolcall"] = c.Toolcall
}
if c.Responses.StoreTTLSeconds > 0 {
m["responses"] = c.Responses
}
if strings.TrimSpace(c.Embeddings.Provider) != "" {
m["embeddings"] = c.Embeddings
}
if c.VercelSyncHash != "" {
m["_vercel_sync_hash"] = c.VercelSyncHash
}
if c.VercelSyncTime != 0 {
m["_vercel_sync_time"] = c.VercelSyncTime
}
return json.Marshal(m)
}
func (c *Config) UnmarshalJSON(b []byte) error {
raw := map[string]json.RawMessage{}
if err := json.Unmarshal(b, &raw); err != nil {
return err
}
c.AdditionalFields = map[string]any{}
for k, v := range raw {
switch k {
case "keys":
if err := json.Unmarshal(v, &c.Keys); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "accounts":
if err := json.Unmarshal(v, &c.Accounts); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "claude_mapping":
if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "claude_model_mapping":
if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "model_aliases":
if err := json.Unmarshal(v, &c.ModelAliases); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "admin":
if err := json.Unmarshal(v, &c.Admin); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "runtime":
if err := json.Unmarshal(v, &c.Runtime); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "compat":
if err := json.Unmarshal(v, &c.Compat); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "toolcall":
if err := json.Unmarshal(v, &c.Toolcall); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "responses":
if err := json.Unmarshal(v, &c.Responses); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "embeddings":
if err := json.Unmarshal(v, &c.Embeddings); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "_vercel_sync_hash":
if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "_vercel_sync_time":
if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
default:
var anyVal any
if err := json.Unmarshal(v, &anyVal); err == nil {
c.AdditionalFields[k] = anyVal
}
}
}
return nil
}
func (c Config) Clone() Config {
clone := Config{
Keys: slices.Clone(c.Keys),
Accounts: slices.Clone(c.Accounts),
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
ModelAliases: cloneStringMap(c.ModelAliases),
Admin: c.Admin,
Runtime: c.Runtime,
Compat: CompatConfig{
WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
},
Toolcall: c.Toolcall,
Responses: c.Responses,
Embeddings: c.Embeddings,
VercelSyncHash: c.VercelSyncHash,
VercelSyncTime: c.VercelSyncTime,
AdditionalFields: map[string]any{},
}
for k, v := range c.AdditionalFields {
clone.AdditionalFields[k] = v
}
return clone
}
func cloneStringMap(in map[string]string) map[string]string {
if len(in) == 0 {
return nil
}
out := make(map[string]string, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func cloneBoolPtr(in *bool) *bool {
if in == nil {
return nil
}
v := *in
return &v
}
type Store struct {
mu sync.RWMutex
cfg Config
path string
fromEnv bool
keyMap map[string]struct{} // O(1) API key lookup index
accMap map[string]int // O(1) account lookup: identifier -> slice index
}
func BaseDir() string {
cwd, err := os.Getwd()
if err != nil {
return "."
}
return cwd
}
func IsVercel() bool {
return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != ""
}
func ResolvePath(envKey, defaultRel string) string {
raw := strings.TrimSpace(os.Getenv(envKey))
if raw != "" {
if filepath.IsAbs(raw) {
return raw
}
return filepath.Join(BaseDir(), raw)
}
return filepath.Join(BaseDir(), defaultRel)
}
func ConfigPath() string {
return ResolvePath("DS2API_CONFIG_PATH", "config.json")
}
func WASMPath() string {
return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm")
}
func StaticAdminDir() string {
return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin")
}
func LoadStore() *Store {
cfg, fromEnv, err := loadConfig()
if err != nil {
Logger.Warn("[config] load failed", "error", err)
}
if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 {
Logger.Warn("[config] empty config loaded")
}
s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv}
s.rebuildIndexes()
return s
}
// rebuildIndexes must be called with the lock already held (or during init).
func (s *Store) rebuildIndexes() {
s.keyMap = make(map[string]struct{}, len(s.cfg.Keys))
for _, k := range s.cfg.Keys {
s.keyMap[k] = struct{}{}
}
s.accMap = make(map[string]int, len(s.cfg.Accounts))
for i, acc := range s.cfg.Accounts {
id := acc.Identifier()
if id != "" {
s.accMap[id] = i
}
}
}
func loadConfig() (Config, bool, error) {
rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON"))
if rawCfg == "" {
rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON"))
}
if rawCfg != "" {
cfg, err := parseConfigString(rawCfg)
return cfg, true, err
}
content, err := os.ReadFile(ConfigPath())
if err != nil {
if IsVercel() {
// Vercel one-click deploy may start without a writable/present config file.
// Keep an in-memory config so users can bootstrap via WebUI then sync env.
return Config{}, true, nil
}
return Config{}, false, err
}
var cfg Config
if err := json.Unmarshal(content, &cfg); err != nil {
return Config{}, false, err
}
if IsVercel() {
// Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors.
return cfg, true, nil
}
return cfg, false, nil
}
func parseConfigString(raw string) (Config, error) {
var cfg Config
candidates := []string{raw}
if normalized := normalizeConfigInput(raw); normalized != raw {
candidates = append(candidates, normalized)
}
for _, candidate := range candidates {
if err := json.Unmarshal([]byte(candidate), &cfg); err == nil {
return cfg, nil
}
}
base64Input := candidates[len(candidates)-1]
decoded, err := decodeConfigBase64(base64Input)
if err != nil {
return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON: %w", err)
}
if err := json.Unmarshal(decoded, &cfg); err != nil {
return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON decoded JSON: %w", err)
}
return cfg, nil
}
func normalizeConfigInput(raw string) string {
normalized := strings.TrimSpace(raw)
if normalized == "" {
return normalized
}
for {
changed := false
if len(normalized) >= 2 {
first := normalized[0]
last := normalized[len(normalized)-1]
if (first == '"' && last == '"') || (first == '\'' && last == '\'') {
normalized = strings.TrimSpace(normalized[1 : len(normalized)-1])
changed = true
}
}
if strings.HasPrefix(strings.ToLower(normalized), "base64:") {
normalized = strings.TrimSpace(normalized[len("base64:"):])
changed = true
}
if !changed {
break
}
}
return strings.TrimSpace(normalized)
}
func decodeConfigBase64(raw string) ([]byte, error) {
encodings := []*base64.Encoding{
base64.StdEncoding,
base64.RawStdEncoding,
base64.URLEncoding,
base64.RawURLEncoding,
}
var lastErr error
for _, enc := range encodings {
decoded, err := enc.DecodeString(raw)
if err == nil {
return decoded, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, errors.New("base64 decode failed")
}
func (s *Store) Snapshot() Config {
s.mu.RLock()
defer s.mu.RUnlock()
return s.cfg.Clone()
}
func (s *Store) HasAPIKey(k string) bool {
s.mu.RLock()
defer s.mu.RUnlock()
_, ok := s.keyMap[k]
return ok
}
func (s *Store) Keys() []string {
s.mu.RLock()
defer s.mu.RUnlock()
return slices.Clone(s.cfg.Keys)
}
func (s *Store) Accounts() []Account {
s.mu.RLock()
defer s.mu.RUnlock()
return slices.Clone(s.cfg.Accounts)
}
func (s *Store) FindAccount(identifier string) (Account, bool) {
identifier = strings.TrimSpace(identifier)
s.mu.RLock()
defer s.mu.RUnlock()
if idx, ok := s.findAccountIndexLocked(identifier); ok {
return s.cfg.Accounts[idx], true
}
return Account{}, false
}
func (s *Store) UpdateAccountToken(identifier, token string) error {
identifier = strings.TrimSpace(identifier)
s.mu.Lock()
defer s.mu.Unlock()
idx, ok := s.findAccountIndexLocked(identifier)
if !ok {
return errors.New("account not found")
}
oldID := s.cfg.Accounts[idx].Identifier()
s.cfg.Accounts[idx].Token = token
newID := s.cfg.Accounts[idx].Identifier()
// Keep historical aliases usable for long-lived queues while also adding
// the latest identifier after token refresh.
if identifier != "" {
s.accMap[identifier] = idx
}
if oldID != "" {
s.accMap[oldID] = idx
}
if newID != "" {
s.accMap[newID] = idx
}
return s.saveLocked()
}
func (s *Store) Replace(cfg Config) error {
s.mu.Lock()
defer s.mu.Unlock()
s.cfg = cfg.Clone()
s.rebuildIndexes()
return s.saveLocked()
}
func (s *Store) Update(mutator func(*Config) error) error {
s.mu.Lock()
defer s.mu.Unlock()
cfg := s.cfg.Clone()
if err := mutator(&cfg); err != nil {
return err
}
s.cfg = cfg
s.rebuildIndexes()
return s.saveLocked()
}
func (s *Store) Save() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.fromEnv {
Logger.Info("[save_config] source from env, skip write")
return nil
}
b, err := json.MarshalIndent(s.cfg, "", " ")
if err != nil {
return err
}
return os.WriteFile(s.path, b, 0o644)
}
func (s *Store) saveLocked() error {
if s.fromEnv {
Logger.Info("[save_config] source from env, skip write")
return nil
}
b, err := json.MarshalIndent(s.cfg, "", " ")
if err != nil {
return err
}
return os.WriteFile(s.path, b, 0o644)
}
// findAccountIndexLocked expects the store lock to already be held.
func (s *Store) findAccountIndexLocked(identifier string) (int, bool) {
if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) {
return idx, true
}
// Fallback for token-only accounts whose derived identifier changed after
// a token refresh; this preserves correctness on map misses.
for i, acc := range s.cfg.Accounts {
if acc.Identifier() == identifier {
return i, true
}
}
return -1, false
}
func (s *Store) IsEnvBacked() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.fromEnv
}
func (s *Store) SetVercelSync(hash string, ts int64) error {
return s.Update(func(c *Config) error {
c.VercelSyncHash = hash
c.VercelSyncTime = ts
return nil
})
}
func (s *Store) ExportJSONAndBase64() (string, string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
b, err := json.Marshal(s.cfg)
if err != nil {
return "", "", err
}
return string(b), base64.StdEncoding.EncodeToString(b), nil
}
func (s *Store) ClaudeMapping() map[string]string {
s.mu.RLock()
defer s.mu.RUnlock()
if len(s.cfg.ClaudeModelMap) > 0 {
return cloneStringMap(s.cfg.ClaudeModelMap)
}
if len(s.cfg.ClaudeMapping) > 0 {
return cloneStringMap(s.cfg.ClaudeMapping)
}
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
}
func (s *Store) ModelAliases() map[string]string {
s.mu.RLock()
defer s.mu.RUnlock()
out := DefaultModelAliases()
for k, v := range s.cfg.ModelAliases {
key := strings.TrimSpace(lower(k))
val := strings.TrimSpace(lower(v))
if key == "" || val == "" {
continue
}
out[key] = val
}
return out
}
func (s *Store) CompatWideInputStrictOutput() bool {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Compat.WideInputStrictOutput == nil {
return true
}
return *s.cfg.Compat.WideInputStrictOutput
}
func (s *Store) ToolcallMode() string {
s.mu.RLock()
defer s.mu.RUnlock()
mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode))
if mode == "" {
return "feature_match"
}
return mode
}
func (s *Store) ToolcallEarlyEmitConfidence() string {
s.mu.RLock()
defer s.mu.RUnlock()
level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence))
if level == "" {
return "high"
}
return level
}
func (s *Store) ResponsesStoreTTLSeconds() int {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Responses.StoreTTLSeconds > 0 {
return s.cfg.Responses.StoreTTLSeconds
}
return 900
}
func (s *Store) EmbeddingsProvider() string {
s.mu.RLock()
defer s.mu.RUnlock()
return strings.TrimSpace(s.cfg.Embeddings.Provider)
}
func (s *Store) AdminPasswordHash() string {
s.mu.RLock()
defer s.mu.RUnlock()
return strings.TrimSpace(s.cfg.Admin.PasswordHash)
}
func (s *Store) AdminJWTExpireHours() int {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Admin.JWTExpireHours > 0 {
return s.cfg.Admin.JWTExpireHours
}
if raw := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); raw != "" {
if n, err := strconv.Atoi(raw); err == nil && n > 0 {
return n
}
}
return 24
}
func (s *Store) AdminJWTValidAfterUnix() int64 {
s.mu.RLock()
defer s.mu.RUnlock()
return s.cfg.Admin.JWTValidAfterUnix
}
func (s *Store) RuntimeAccountMaxInflight() int {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Runtime.AccountMaxInflight > 0 {
return s.cfg.Runtime.AccountMaxInflight
}
for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
continue
}
n, err := strconv.Atoi(raw)
if err == nil && n > 0 {
return n
}
}
return 2
}
func (s *Store) RuntimeAccountMaxQueue(defaultSize int) int {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Runtime.AccountMaxQueue > 0 {
return s.cfg.Runtime.AccountMaxQueue
}
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
continue
}
n, err := strconv.Atoi(raw)
if err == nil && n >= 0 {
return n
}
}
if defaultSize < 0 {
return 0
}
return defaultSize
}
func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Runtime.GlobalMaxInflight > 0 {
return s.cfg.Runtime.GlobalMaxInflight
}
for _, key := range []string{"DS2API_GLOBAL_MAX_INFLIGHT", "DS2API_MAX_INFLIGHT"} {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
continue
}
n, err := strconv.Atoi(raw)
if err == nil && n > 0 {
return n
}
}
if defaultSize < 0 {
return 0
}
return defaultSize
}

25
internal/config/logger.go Normal file
View File

@@ -0,0 +1,25 @@
package config
import (
"log/slog"
"os"
"strings"
)
var Logger = newLogger()
func newLogger() *slog.Logger {
level := new(slog.LevelVar)
switch strings.ToUpper(strings.TrimSpace(os.Getenv("LOG_LEVEL"))) {
case "DEBUG":
level.Set(slog.LevelDebug)
case "WARN":
level.Set(slog.LevelWarn)
case "ERROR":
level.Set(slog.LevelError)
default:
level.Set(slog.LevelInfo)
}
h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})
return slog.New(h)
}

42
internal/config/paths.go Normal file
View File

@@ -0,0 +1,42 @@
package config
import (
"os"
"path/filepath"
"strings"
)
func BaseDir() string {
cwd, err := os.Getwd()
if err != nil {
return "."
}
return cwd
}
func IsVercel() bool {
return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != ""
}
func ResolvePath(envKey, defaultRel string) string {
raw := strings.TrimSpace(os.Getenv(envKey))
if raw != "" {
if filepath.IsAbs(raw) {
return raw
}
return filepath.Join(BaseDir(), raw)
}
return filepath.Join(BaseDir(), defaultRel)
}
func ConfigPath() string {
return ResolvePath("DS2API_CONFIG_PATH", "config.json")
}
func WASMPath() string {
return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm")
}
func StaticAdminDir() string {
return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin")
}

193
internal/config/store.go Normal file
View File

@@ -0,0 +1,193 @@
package config
import (
"encoding/base64"
"encoding/json"
"errors"
"os"
"slices"
"strings"
"sync"
)
type Store struct {
mu sync.RWMutex
cfg Config
path string
fromEnv bool
keyMap map[string]struct{} // O(1) API key lookup index
accMap map[string]int // O(1) account lookup: identifier -> slice index
}
func LoadStore() *Store {
cfg, fromEnv, err := loadConfig()
if err != nil {
Logger.Warn("[config] load failed", "error", err)
}
if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 {
Logger.Warn("[config] empty config loaded")
}
s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv}
s.rebuildIndexes()
return s
}
func loadConfig() (Config, bool, error) {
rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON"))
if rawCfg == "" {
rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON"))
}
if rawCfg != "" {
cfg, err := parseConfigString(rawCfg)
return cfg, true, err
}
content, err := os.ReadFile(ConfigPath())
if err != nil {
if IsVercel() {
// Vercel one-click deploy may start without a writable/present config file.
// Keep an in-memory config so users can bootstrap via WebUI then sync env.
return Config{}, true, nil
}
return Config{}, false, err
}
var cfg Config
if err := json.Unmarshal(content, &cfg); err != nil {
return Config{}, false, err
}
if IsVercel() {
// Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors.
return cfg, true, nil
}
return cfg, false, nil
}
func (s *Store) Snapshot() Config {
s.mu.RLock()
defer s.mu.RUnlock()
return s.cfg.Clone()
}
func (s *Store) HasAPIKey(k string) bool {
s.mu.RLock()
defer s.mu.RUnlock()
_, ok := s.keyMap[k]
return ok
}
func (s *Store) Keys() []string {
s.mu.RLock()
defer s.mu.RUnlock()
return slices.Clone(s.cfg.Keys)
}
func (s *Store) Accounts() []Account {
s.mu.RLock()
defer s.mu.RUnlock()
return slices.Clone(s.cfg.Accounts)
}
func (s *Store) FindAccount(identifier string) (Account, bool) {
identifier = strings.TrimSpace(identifier)
s.mu.RLock()
defer s.mu.RUnlock()
if idx, ok := s.findAccountIndexLocked(identifier); ok {
return s.cfg.Accounts[idx], true
}
return Account{}, false
}
func (s *Store) UpdateAccountToken(identifier, token string) error {
identifier = strings.TrimSpace(identifier)
s.mu.Lock()
defer s.mu.Unlock()
idx, ok := s.findAccountIndexLocked(identifier)
if !ok {
return errors.New("account not found")
}
oldID := s.cfg.Accounts[idx].Identifier()
s.cfg.Accounts[idx].Token = token
newID := s.cfg.Accounts[idx].Identifier()
// Keep historical aliases usable for long-lived queues while also adding
// the latest identifier after token refresh.
if identifier != "" {
s.accMap[identifier] = idx
}
if oldID != "" {
s.accMap[oldID] = idx
}
if newID != "" {
s.accMap[newID] = idx
}
return s.saveLocked()
}
func (s *Store) Replace(cfg Config) error {
s.mu.Lock()
defer s.mu.Unlock()
s.cfg = cfg.Clone()
s.rebuildIndexes()
return s.saveLocked()
}
func (s *Store) Update(mutator func(*Config) error) error {
s.mu.Lock()
defer s.mu.Unlock()
cfg := s.cfg.Clone()
if err := mutator(&cfg); err != nil {
return err
}
s.cfg = cfg
s.rebuildIndexes()
return s.saveLocked()
}
func (s *Store) Save() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.fromEnv {
Logger.Info("[save_config] source from env, skip write")
return nil
}
b, err := json.MarshalIndent(s.cfg, "", " ")
if err != nil {
return err
}
return os.WriteFile(s.path, b, 0o644)
}
func (s *Store) saveLocked() error {
if s.fromEnv {
Logger.Info("[save_config] source from env, skip write")
return nil
}
b, err := json.MarshalIndent(s.cfg, "", " ")
if err != nil {
return err
}
return os.WriteFile(s.path, b, 0o644)
}
func (s *Store) IsEnvBacked() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.fromEnv
}
func (s *Store) SetVercelSync(hash string, ts int64) error {
return s.Update(func(c *Config) error {
c.VercelSyncHash = hash
c.VercelSyncTime = ts
return nil
})
}
func (s *Store) ExportJSONAndBase64() (string, string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
b, err := json.Marshal(s.cfg)
if err != nil {
return "", "", err
}
return string(b), base64.StdEncoding.EncodeToString(b), nil
}

View File

@@ -0,0 +1,167 @@
package config
import (
"os"
"strconv"
"strings"
)
func (s *Store) ClaudeMapping() map[string]string {
s.mu.RLock()
defer s.mu.RUnlock()
if len(s.cfg.ClaudeModelMap) > 0 {
return cloneStringMap(s.cfg.ClaudeModelMap)
}
if len(s.cfg.ClaudeMapping) > 0 {
return cloneStringMap(s.cfg.ClaudeMapping)
}
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
}
func (s *Store) ModelAliases() map[string]string {
s.mu.RLock()
defer s.mu.RUnlock()
out := DefaultModelAliases()
for k, v := range s.cfg.ModelAliases {
key := strings.TrimSpace(lower(k))
val := strings.TrimSpace(lower(v))
if key == "" || val == "" {
continue
}
out[key] = val
}
return out
}
func (s *Store) CompatWideInputStrictOutput() bool {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Compat.WideInputStrictOutput == nil {
return true
}
return *s.cfg.Compat.WideInputStrictOutput
}
func (s *Store) ToolcallMode() string {
s.mu.RLock()
defer s.mu.RUnlock()
mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode))
if mode == "" {
return "feature_match"
}
return mode
}
func (s *Store) ToolcallEarlyEmitConfidence() string {
s.mu.RLock()
defer s.mu.RUnlock()
level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence))
if level == "" {
return "high"
}
return level
}
func (s *Store) ResponsesStoreTTLSeconds() int {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Responses.StoreTTLSeconds > 0 {
return s.cfg.Responses.StoreTTLSeconds
}
return 900
}
func (s *Store) EmbeddingsProvider() string {
s.mu.RLock()
defer s.mu.RUnlock()
return strings.TrimSpace(s.cfg.Embeddings.Provider)
}
func (s *Store) AdminPasswordHash() string {
s.mu.RLock()
defer s.mu.RUnlock()
return strings.TrimSpace(s.cfg.Admin.PasswordHash)
}
func (s *Store) AdminJWTExpireHours() int {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Admin.JWTExpireHours > 0 {
return s.cfg.Admin.JWTExpireHours
}
if raw := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); raw != "" {
if n, err := strconv.Atoi(raw); err == nil && n > 0 {
return n
}
}
return 24
}
func (s *Store) AdminJWTValidAfterUnix() int64 {
s.mu.RLock()
defer s.mu.RUnlock()
return s.cfg.Admin.JWTValidAfterUnix
}
func (s *Store) RuntimeAccountMaxInflight() int {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Runtime.AccountMaxInflight > 0 {
return s.cfg.Runtime.AccountMaxInflight
}
for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
continue
}
n, err := strconv.Atoi(raw)
if err == nil && n > 0 {
return n
}
}
return 2
}
func (s *Store) RuntimeAccountMaxQueue(defaultSize int) int {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Runtime.AccountMaxQueue > 0 {
return s.cfg.Runtime.AccountMaxQueue
}
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
continue
}
n, err := strconv.Atoi(raw)
if err == nil && n >= 0 {
return n
}
}
if defaultSize < 0 {
return 0
}
return defaultSize
}
func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Runtime.GlobalMaxInflight > 0 {
return s.cfg.Runtime.GlobalMaxInflight
}
for _, key := range []string{"DS2API_GLOBAL_MAX_INFLIGHT", "DS2API_MAX_INFLIGHT"} {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
continue
}
n, err := strconv.Atoi(raw)
if err == nil && n > 0 {
return n
}
}
if defaultSize < 0 {
return 0
}
return defaultSize
}

View File

@@ -0,0 +1,31 @@
package config
// rebuildIndexes must be called with the lock already held (or during init).
func (s *Store) rebuildIndexes() {
s.keyMap = make(map[string]struct{}, len(s.cfg.Keys))
for _, k := range s.cfg.Keys {
s.keyMap[k] = struct{}{}
}
s.accMap = make(map[string]int, len(s.cfg.Accounts))
for i, acc := range s.cfg.Accounts {
id := acc.Identifier()
if id != "" {
s.accMap[id] = i
}
}
}
// findAccountIndexLocked expects the store lock to already be held.
func (s *Store) findAccountIndexLocked(identifier string) (int, bool) {
if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) {
return idx, true
}
// Fallback for token-only accounts whose derived identifier changed after
// a token refresh; this preserves correctness on map misses.
for i, acc := range s.cfg.Accounts {
if acc.Identifier() == identifier {
return i, true
}
}
return -1, false
}

View File

@@ -1,347 +0,0 @@
package deepseek
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"ds2api/internal/auth"
"ds2api/internal/config"
trans "ds2api/internal/deepseek/transport"
"ds2api/internal/devcapture"
"ds2api/internal/util"
"github.com/andybalholm/brotli"
)
// intFrom is a package-internal alias for the shared util version.
var intFrom = util.IntFrom
type Client struct {
Store *config.Store
Auth *auth.Resolver
capture *devcapture.Store
regular trans.Doer
stream trans.Doer
fallback *http.Client
fallbackS *http.Client
powSolver *PowSolver
maxRetries int
}
func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
return &Client{
Store: store,
Auth: resolver,
capture: devcapture.Global(),
regular: trans.New(60 * time.Second),
stream: trans.New(0),
fallback: &http.Client{Timeout: 60 * time.Second},
fallbackS: &http.Client{Timeout: 0},
powSolver: NewPowSolver(config.WASMPath()),
maxRetries: 3,
}
}
func (c *Client) PreloadPow(ctx context.Context) error {
return c.powSolver.init(ctx)
}
func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) {
payload := map[string]any{
"password": strings.TrimSpace(acc.Password),
"device_id": "deepseek_to_api",
"os": "android",
}
if email := strings.TrimSpace(acc.Email); email != "" {
payload["email"] = email
} else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" {
payload["mobile"] = mobile
payload["area_code"] = nil
} else {
return "", errors.New("missing email/mobile")
}
resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload)
if err != nil {
return "", err
}
code := intFrom(resp["code"])
if code != 0 {
return "", fmt.Errorf("login failed: %v", resp["msg"])
}
data, _ := resp["data"].(map[string]any)
if intFrom(data["biz_code"]) != 0 {
return "", fmt.Errorf("login failed: %v", data["biz_msg"])
}
bizData, _ := data["biz_data"].(map[string]any)
user, _ := bizData["user"].(map[string]any)
token, _ := user["token"].(string)
if strings.TrimSpace(token) == "" {
return "", errors.New("missing login token")
}
return token, nil
}
func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
attempts := 0
refreshed := false
for attempts < maxAttempts {
headers := c.authHeaders(a.DeepSeekToken)
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"})
if err != nil {
config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID)
attempts++
continue
}
code := intFrom(resp["code"])
if status == http.StatusOK && code == 0 {
data, _ := resp["data"].(map[string]any)
bizData, _ := data["biz_data"].(map[string]any)
sessionID, _ := bizData["id"].(string)
if sessionID != "" {
return sessionID, nil
}
}
msg, _ := resp["msg"].(string)
config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
if a.UseConfigToken {
if isTokenInvalid(status, code, msg) && !refreshed {
if c.Auth.RefreshToken(ctx, a) {
refreshed = true
continue
}
}
if c.Auth.SwitchAccount(ctx, a) {
refreshed = false
attempts++
continue
}
}
attempts++
}
return "", errors.New("create session failed")
}
func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
attempts := 0
for attempts < maxAttempts {
headers := c.authHeaders(a.DeepSeekToken)
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
if err != nil {
config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID)
attempts++
continue
}
code := intFrom(resp["code"])
if status == http.StatusOK && code == 0 {
data, _ := resp["data"].(map[string]any)
bizData, _ := data["biz_data"].(map[string]any)
challenge, _ := bizData["challenge"].(map[string]any)
answer, err := c.powSolver.Compute(ctx, challenge)
if err != nil {
attempts++
continue
}
return BuildPowHeader(challenge, answer)
}
msg, _ := resp["msg"].(string)
config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
if a.UseConfigToken {
if isTokenInvalid(status, code, msg) {
if c.Auth.RefreshToken(ctx, a) {
continue
}
}
if c.Auth.SwitchAccount(ctx, a) {
attempts++
continue
}
}
attempts++
}
return "", errors.New("get pow failed")
}
func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) {
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
headers := c.authHeaders(a.DeepSeekToken)
headers["x-ds-pow-response"] = powResp
captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload)
attempts := 0
for attempts < maxAttempts {
resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload)
if err != nil {
attempts++
time.Sleep(time.Second)
continue
}
if resp.StatusCode == http.StatusOK {
if captureSession != nil {
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
}
return resp, nil
}
if captureSession != nil {
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
}
_ = resp.Body.Close()
attempts++
time.Sleep(time.Second)
}
return nil, errors.New("completion failed")
}
func (c *Client) postJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, error) {
body, status, err := c.postJSONWithStatus(ctx, doer, url, headers, payload)
if err != nil {
return nil, err
}
if status == 0 {
return nil, errors.New("request failed")
}
return body, nil
}
func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, int, error) {
b, err := json.Marshal(payload)
if err != nil {
return nil, 0, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
if err != nil {
return nil, 0, err
}
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := doer.Do(req)
if err != nil {
config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err)
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
if reqErr != nil {
return nil, 0, err
}
for k, v := range headers {
req2.Header.Set(k, v)
}
resp, err = c.fallback.Do(req2)
if err != nil {
return nil, 0, err
}
}
defer resp.Body.Close()
payloadBytes, err := readResponseBody(resp)
if err != nil {
return nil, resp.StatusCode, err
}
out := map[string]any{}
if len(payloadBytes) > 0 {
if err := json.Unmarshal(payloadBytes, &out); err != nil {
config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes))
}
}
return out, resp.StatusCode, nil
}
func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) {
b, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
if err != nil {
return nil, err
}
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := c.stream.Do(req)
if err != nil {
config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err)
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
if reqErr != nil {
return nil, err
}
for k, v := range headers {
req2.Header.Set(k, v)
}
return c.fallbackS.Do(req2)
}
return resp, nil
}
func (c *Client) authHeaders(token string) map[string]string {
headers := make(map[string]string, len(BaseHeaders)+1)
for k, v := range BaseHeaders {
headers[k] = v
}
headers["authorization"] = "Bearer " + token
return headers
}
func isTokenInvalid(status int, code int, msg string) bool {
msg = strings.ToLower(msg)
if status == http.StatusUnauthorized || status == http.StatusForbidden {
return true
}
if code == 40001 || code == 40002 || code == 40003 {
return true
}
return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
}
func readResponseBody(resp *http.Response) ([]byte, error) {
encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding")))
var reader io.Reader = resp.Body
switch encoding {
case "gzip":
gz, err := gzip.NewReader(resp.Body)
if err != nil {
return nil, err
}
defer gz.Close()
reader = gz
case "br":
reader = brotli.NewReader(resp.Body)
}
return io.ReadAll(reader)
}
func preview(b []byte) string {
s := strings.TrimSpace(string(b))
if len(s) > 160 {
return s[:160]
}
return s
}
func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error {
scanner := bufio.NewScanner(resp.Body)
buf := make([]byte, 0, 64*1024)
scanner.Buffer(buf, 2*1024*1024)
for scanner.Scan() {
if !onLine(scanner.Bytes()) {
break
}
}
if err := scanner.Err(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,153 @@
package deepseek
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"ds2api/internal/auth"
"ds2api/internal/config"
)
func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) {
payload := map[string]any{
"password": strings.TrimSpace(acc.Password),
"device_id": "deepseek_to_api",
"os": "android",
}
if email := strings.TrimSpace(acc.Email); email != "" {
payload["email"] = email
} else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" {
payload["mobile"] = mobile
payload["area_code"] = nil
} else {
return "", errors.New("missing email/mobile")
}
resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload)
if err != nil {
return "", err
}
code := intFrom(resp["code"])
if code != 0 {
return "", fmt.Errorf("login failed: %v", resp["msg"])
}
data, _ := resp["data"].(map[string]any)
if intFrom(data["biz_code"]) != 0 {
return "", fmt.Errorf("login failed: %v", data["biz_msg"])
}
bizData, _ := data["biz_data"].(map[string]any)
user, _ := bizData["user"].(map[string]any)
token, _ := user["token"].(string)
if strings.TrimSpace(token) == "" {
return "", errors.New("missing login token")
}
return token, nil
}
func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
attempts := 0
refreshed := false
for attempts < maxAttempts {
headers := c.authHeaders(a.DeepSeekToken)
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"})
if err != nil {
config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID)
attempts++
continue
}
code := intFrom(resp["code"])
if status == http.StatusOK && code == 0 {
data, _ := resp["data"].(map[string]any)
bizData, _ := data["biz_data"].(map[string]any)
sessionID, _ := bizData["id"].(string)
if sessionID != "" {
return sessionID, nil
}
}
msg, _ := resp["msg"].(string)
config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
if a.UseConfigToken {
if isTokenInvalid(status, code, msg) && !refreshed {
if c.Auth.RefreshToken(ctx, a) {
refreshed = true
continue
}
}
if c.Auth.SwitchAccount(ctx, a) {
refreshed = false
attempts++
continue
}
}
attempts++
}
return "", errors.New("create session failed")
}
func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
attempts := 0
for attempts < maxAttempts {
headers := c.authHeaders(a.DeepSeekToken)
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
if err != nil {
config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID)
attempts++
continue
}
code := intFrom(resp["code"])
if status == http.StatusOK && code == 0 {
data, _ := resp["data"].(map[string]any)
bizData, _ := data["biz_data"].(map[string]any)
challenge, _ := bizData["challenge"].(map[string]any)
answer, err := c.powSolver.Compute(ctx, challenge)
if err != nil {
attempts++
continue
}
return BuildPowHeader(challenge, answer)
}
msg, _ := resp["msg"].(string)
config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
if a.UseConfigToken {
if isTokenInvalid(status, code, msg) {
if c.Auth.RefreshToken(ctx, a) {
continue
}
}
if c.Auth.SwitchAccount(ctx, a) {
attempts++
continue
}
}
attempts++
}
return "", errors.New("get pow failed")
}
func (c *Client) authHeaders(token string) map[string]string {
headers := make(map[string]string, len(BaseHeaders)+1)
for k, v := range BaseHeaders {
headers[k] = v
}
headers["authorization"] = "Bearer " + token
return headers
}
func isTokenInvalid(status int, code int, msg string) bool {
msg = strings.ToLower(msg)
if status == http.StatusUnauthorized || status == http.StatusForbidden {
return true
}
if code == 40001 || code == 40002 || code == 40003 {
return true
}
return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
}

View File

@@ -0,0 +1,71 @@
package deepseek
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"time"
"ds2api/internal/auth"
"ds2api/internal/config"
)
func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) {
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
headers := c.authHeaders(a.DeepSeekToken)
headers["x-ds-pow-response"] = powResp
captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload)
attempts := 0
for attempts < maxAttempts {
resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload)
if err != nil {
attempts++
time.Sleep(time.Second)
continue
}
if resp.StatusCode == http.StatusOK {
if captureSession != nil {
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
}
return resp, nil
}
if captureSession != nil {
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
}
_ = resp.Body.Close()
attempts++
time.Sleep(time.Second)
}
return nil, errors.New("completion failed")
}
func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) {
b, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
if err != nil {
return nil, err
}
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := c.stream.Do(req)
if err != nil {
config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err)
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
if reqErr != nil {
return nil, err
}
for k, v := range headers {
req2.Header.Set(k, v)
}
return c.fallbackS.Do(req2)
}
return resp, nil
}

View File

@@ -0,0 +1,46 @@
package deepseek
import (
"context"
"net/http"
"time"
"ds2api/internal/auth"
"ds2api/internal/config"
trans "ds2api/internal/deepseek/transport"
"ds2api/internal/devcapture"
"ds2api/internal/util"
)
// intFrom is a package-internal alias for the shared util version.
var intFrom = util.IntFrom
type Client struct {
Store *config.Store
Auth *auth.Resolver
capture *devcapture.Store
regular trans.Doer
stream trans.Doer
fallback *http.Client
fallbackS *http.Client
powSolver *PowSolver
maxRetries int
}
func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
return &Client{
Store: store,
Auth: resolver,
capture: devcapture.Global(),
regular: trans.New(60 * time.Second),
stream: trans.New(0),
fallback: &http.Client{Timeout: 60 * time.Second},
fallbackS: &http.Client{Timeout: 0},
powSolver: NewPowSolver(config.WASMPath()),
maxRetries: 3,
}
}
func (c *Client) PreloadPow(ctx context.Context) error {
return c.powSolver.init(ctx)
}

View File

@@ -0,0 +1,51 @@
package deepseek
import (
"bufio"
"compress/gzip"
"io"
"net/http"
"strings"
"github.com/andybalholm/brotli"
)
func readResponseBody(resp *http.Response) ([]byte, error) {
encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding")))
var reader io.Reader = resp.Body
switch encoding {
case "gzip":
gz, err := gzip.NewReader(resp.Body)
if err != nil {
return nil, err
}
defer gz.Close()
reader = gz
case "br":
reader = brotli.NewReader(resp.Body)
}
return io.ReadAll(reader)
}
func preview(b []byte) string {
s := strings.TrimSpace(string(b))
if len(s) > 160 {
return s[:160]
}
return s
}
func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error {
scanner := bufio.NewScanner(resp.Body)
buf := make([]byte, 0, 64*1024)
scanner.Buffer(buf, 2*1024*1024)
for scanner.Scan() {
if !onLine(scanner.Bytes()) {
break
}
}
if err := scanner.Err(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,64 @@
package deepseek
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"ds2api/internal/config"
trans "ds2api/internal/deepseek/transport"
)
func (c *Client) postJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, error) {
body, status, err := c.postJSONWithStatus(ctx, doer, url, headers, payload)
if err != nil {
return nil, err
}
if status == 0 {
return nil, errors.New("request failed")
}
return body, nil
}
func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, int, error) {
b, err := json.Marshal(payload)
if err != nil {
return nil, 0, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
if err != nil {
return nil, 0, err
}
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := doer.Do(req)
if err != nil {
config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err)
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
if reqErr != nil {
return nil, 0, err
}
for k, v := range headers {
req2.Header.Set(k, v)
}
resp, err = c.fallback.Do(req2)
if err != nil {
return nil, 0, err
}
}
defer resp.Body.Close()
payloadBytes, err := readResponseBody(resp)
if err != nil {
return nil, resp.StatusCode, err
}
out := map[string]any{}
if len(payloadBytes) > 0 {
if err := json.Unmarshal(payloadBytes, &out); err != nil {
config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes))
}
}
return out, resp.StatusCode, nil
}

View File

@@ -1,307 +0,0 @@
package openai
import (
"encoding/json"
"strings"
"time"
"github.com/google/uuid"
"ds2api/internal/util"
)
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
detected := util.ParseToolCalls(finalText, toolNames)
finishReason := "stop"
messageObj := map[string]any{"role": "assistant", "content": finalText}
if strings.TrimSpace(finalThinking) != "" {
messageObj["reasoning_content"] = finalThinking
}
if len(detected) > 0 {
finishReason = "tool_calls"
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
messageObj["content"] = nil
}
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
return map[string]any{
"id": completionID,
"object": "chat.completion",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
"usage": map[string]any{
"prompt_tokens": promptTokens,
"completion_tokens": reasoningTokens + completionTokens,
"total_tokens": promptTokens + reasoningTokens + completionTokens,
"completion_tokens_details": map[string]any{
"reasoning_tokens": reasoningTokens,
},
},
}
}
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
// Align responses tool-call semantics with chat/completions:
// mixed prose + tool_call payloads should still be interpreted as tool calls.
detected := util.ParseToolCalls(finalText, toolNames)
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
detected = util.ParseToolCalls(finalThinking, toolNames)
}
exposedOutputText := finalText
output := make([]any, 0, 2)
if len(detected) > 0 {
exposedOutputText = ""
if strings.TrimSpace(finalThinking) != "" {
output = append(output, map[string]any{
"type": "reasoning",
"text": finalThinking,
})
}
formatted := util.FormatOpenAIToolCalls(detected)
output = append(output, toResponsesFunctionCallItems(formatted)...)
output = append(output, map[string]any{
"type": "tool_calls",
"tool_calls": formatted,
})
} else {
content := make([]any, 0, 2)
if finalThinking != "" {
content = append([]any{map[string]any{
"type": "reasoning",
"text": finalThinking,
}}, content...)
}
if strings.TrimSpace(finalText) != "" {
content = append(content, map[string]any{
"type": "output_text",
"text": finalText,
})
}
if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" {
exposedOutputText = finalThinking
}
output = append(output, map[string]any{
"type": "message",
"id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"role": "assistant",
"content": content,
})
}
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
return map[string]any{
"id": responseID,
"type": "response",
"object": "response",
"created_at": time.Now().Unix(),
"status": "completed",
"model": model,
"output": output,
"output_text": exposedOutputText,
"usage": map[string]any{
"input_tokens": promptTokens,
"output_tokens": reasoningTokens + completionTokens,
"total_tokens": promptTokens + reasoningTokens + completionTokens,
},
}
}
func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
if len(toolCalls) == 0 {
return nil
}
out := make([]any, 0, len(toolCalls))
for _, tc := range toolCalls {
callID, _ := tc["id"].(string)
if strings.TrimSpace(callID) == "" {
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
}
name := ""
args := "{}"
if fn, ok := tc["function"].(map[string]any); ok {
if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
name = n
}
if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
args = a
}
}
out = append(out, map[string]any{
"id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"type": "function_call",
"call_id": callID,
"name": name,
"arguments": normalizeJSONString(args),
"status": "completed",
})
}
return out
}
func normalizeJSONString(raw string) string {
s := strings.TrimSpace(raw)
if s == "" {
return "{}"
}
var v any
if err := json.Unmarshal([]byte(s), &v); err != nil {
return raw
}
b, err := json.Marshal(v)
if err != nil {
return raw
}
return string(b)
}
func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
return map[string]any{
"delta": delta,
"index": index,
}
}
func BuildChatStreamFinishChoice(index int, finishReason string) map[string]any {
return map[string]any{
"delta": map[string]any{},
"index": index,
"finish_reason": finishReason,
}
}
func BuildChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any {
out := map[string]any{
"id": completionID,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": choices,
}
if len(usage) > 0 {
out["usage"] = usage
}
return out
}
func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
return map[string]any{
"prompt_tokens": promptTokens,
"completion_tokens": reasoningTokens + completionTokens,
"total_tokens": promptTokens + reasoningTokens + completionTokens,
"completion_tokens_details": map[string]any{
"reasoning_tokens": reasoningTokens,
},
}
}
func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
return map[string]any{
"type": "response.created",
"id": responseID,
"response_id": responseID,
"object": "response",
"model": model,
"status": "in_progress",
}
}
func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
return map[string]any{
"type": "response.output_text.delta",
"id": responseID,
"response_id": responseID,
"delta": delta,
}
}
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
return map[string]any{
"type": "response.reasoning.delta",
"id": responseID,
"response_id": responseID,
"delta": delta,
}
}
func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
return map[string]any{
"type": "response.reasoning_text.delta",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"delta": delta,
}
}
func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
return map[string]any{
"type": "response.reasoning_text.done",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"text": text,
}
}
func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
return map[string]any{
"type": "response.output_tool_call.delta",
"id": responseID,
"response_id": responseID,
"tool_calls": toolCalls,
}
}
func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
return map[string]any{
"type": "response.output_tool_call.done",
"id": responseID,
"response_id": responseID,
"tool_calls": toolCalls,
}
}
func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
return map[string]any{
"type": "response.function_call_arguments.delta",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"call_id": callID,
"delta": delta,
}
}
func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any {
return map[string]any{
"type": "response.function_call_arguments.done",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"call_id": callID,
"name": name,
"arguments": normalizeJSONString(arguments),
}
}
func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
responseID, _ := response["id"].(string)
return map[string]any{
"type": "response.completed",
"response_id": responseID,
"response": response,
}
}

View File

@@ -0,0 +1,60 @@
package openai
import (
"strings"
"time"
"ds2api/internal/util"
)
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
detected := util.ParseToolCalls(finalText, toolNames)
finishReason := "stop"
messageObj := map[string]any{"role": "assistant", "content": finalText}
if strings.TrimSpace(finalThinking) != "" {
messageObj["reasoning_content"] = finalThinking
}
if len(detected) > 0 {
finishReason = "tool_calls"
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
messageObj["content"] = nil
}
return map[string]any{
"id": completionID,
"object": "chat.completion",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
"usage": BuildChatUsage(finalPrompt, finalThinking, finalText),
}
}
func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
return map[string]any{
"delta": delta,
"index": index,
}
}
func BuildChatStreamFinishChoice(index int, finishReason string) map[string]any {
return map[string]any{
"delta": map[string]any{},
"index": index,
"finish_reason": finishReason,
}
}
func BuildChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any {
out := map[string]any{
"id": completionID,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": choices,
}
if len(usage) > 0 {
out["usage"] = usage
}
return out
}

View File

@@ -0,0 +1,119 @@
package openai
import (
"encoding/json"
"strings"
"time"
"github.com/google/uuid"
"ds2api/internal/util"
)
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
// Align responses tool-call semantics with chat/completions:
// mixed prose + tool_call payloads should still be interpreted as tool calls.
detected := util.ParseToolCalls(finalText, toolNames)
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
detected = util.ParseToolCalls(finalThinking, toolNames)
}
exposedOutputText := finalText
output := make([]any, 0, 2)
if len(detected) > 0 {
exposedOutputText = ""
if strings.TrimSpace(finalThinking) != "" {
output = append(output, map[string]any{
"type": "reasoning",
"text": finalThinking,
})
}
formatted := util.FormatOpenAIToolCalls(detected)
output = append(output, toResponsesFunctionCallItems(formatted)...)
output = append(output, map[string]any{
"type": "tool_calls",
"tool_calls": formatted,
})
} else {
content := make([]any, 0, 2)
if finalThinking != "" {
content = append([]any{map[string]any{
"type": "reasoning",
"text": finalThinking,
}}, content...)
}
if strings.TrimSpace(finalText) != "" {
content = append(content, map[string]any{
"type": "output_text",
"text": finalText,
})
}
if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" {
exposedOutputText = finalThinking
}
output = append(output, map[string]any{
"type": "message",
"id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"role": "assistant",
"content": content,
})
}
return map[string]any{
"id": responseID,
"type": "response",
"object": "response",
"created_at": time.Now().Unix(),
"status": "completed",
"model": model,
"output": output,
"output_text": exposedOutputText,
"usage": BuildResponsesUsage(finalPrompt, finalThinking, finalText),
}
}
func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
if len(toolCalls) == 0 {
return nil
}
out := make([]any, 0, len(toolCalls))
for _, tc := range toolCalls {
callID, _ := tc["id"].(string)
if strings.TrimSpace(callID) == "" {
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
}
name := ""
args := "{}"
if fn, ok := tc["function"].(map[string]any); ok {
if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
name = n
}
if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
args = a
}
}
out = append(out, map[string]any{
"id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"type": "function_call",
"call_id": callID,
"name": name,
"arguments": normalizeJSONString(args),
"status": "completed",
})
}
return out
}
func normalizeJSONString(raw string) string {
s := strings.TrimSpace(raw)
if s == "" {
return "{}"
}
var v any
if err := json.Unmarshal([]byte(s), &v); err != nil {
return raw
}
b, err := json.Marshal(v)
if err != nil {
return raw
}
return string(b)
}

View File

@@ -0,0 +1,106 @@
package openai
func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
return map[string]any{
"type": "response.created",
"id": responseID,
"response_id": responseID,
"object": "response",
"model": model,
"status": "in_progress",
}
}
func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
return map[string]any{
"type": "response.output_text.delta",
"id": responseID,
"response_id": responseID,
"delta": delta,
}
}
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
return map[string]any{
"type": "response.reasoning.delta",
"id": responseID,
"response_id": responseID,
"delta": delta,
}
}
func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
return map[string]any{
"type": "response.reasoning_text.delta",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"delta": delta,
}
}
func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
return map[string]any{
"type": "response.reasoning_text.done",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"text": text,
}
}
func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
return map[string]any{
"type": "response.output_tool_call.delta",
"id": responseID,
"response_id": responseID,
"tool_calls": toolCalls,
}
}
func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
return map[string]any{
"type": "response.output_tool_call.done",
"id": responseID,
"response_id": responseID,
"tool_calls": toolCalls,
}
}
func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
return map[string]any{
"type": "response.function_call_arguments.delta",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"call_id": callID,
"delta": delta,
}
}
func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any {
return map[string]any{
"type": "response.function_call_arguments.done",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"call_id": callID,
"name": name,
"arguments": normalizeJSONString(arguments),
}
}
func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
responseID, _ := response["id"].(string)
return map[string]any{
"type": "response.completed",
"response_id": responseID,
"response": response,
}
}

View File

@@ -0,0 +1,28 @@
package openai
import "ds2api/internal/util"
func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
return map[string]any{
"prompt_tokens": promptTokens,
"completion_tokens": reasoningTokens + completionTokens,
"total_tokens": promptTokens + reasoningTokens + completionTokens,
"completion_tokens_details": map[string]any{
"reasoning_tokens": reasoningTokens,
},
}
}
func BuildResponsesUsage(finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
return map[string]any{
"input_tokens": promptTokens,
"output_tokens": reasoningTokens + completionTokens,
"total_tokens": promptTokens + reasoningTokens + completionTokens,
}
}

View File

@@ -1,7 +1,6 @@
package testsuite
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -125,72 +124,6 @@ func (r *Runner) caseStreamAbortRelease(ctx context.Context, cc *caseContext) er
return nil
}
func (cc *caseContext) abortStreamRequest(ctx context.Context, spec requestSpec) error {
cc.seq++
traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), cc.seq)
cc.traceIDsSet[traceID] = struct{}{}
fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID)
if err != nil {
return err
}
headers := map[string]string{}
for k, v := range spec.Headers {
headers[k] = v
}
headers["X-Ds2-Test-Trace"] = traceID
bodyBytes, _ := json.Marshal(spec.Body)
headers["Content-Type"] = "application/json"
cc.requests = append(cc.requests, requestLog{
Seq: cc.seq,
Attempt: 1,
TraceID: traceID,
Method: spec.Method,
URL: fullURL,
Headers: headers,
Body: spec.Body,
Timestamp: time.Now().Format(time.RFC3339Nano),
})
reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes))
if err != nil {
return err
}
for k, v := range headers {
req.Header.Set(k, v)
}
start := time.Now()
resp, err := cc.runner.httpClient.Do(req)
if err != nil {
cc.responses = append(cc.responses, responseLog{
Seq: cc.seq,
Attempt: 1,
TraceID: traceID,
StatusCode: 0,
DurationMS: time.Since(start).Milliseconds(),
NetworkErr: err.Error(),
ReceivedAt: time.Now().Format(time.RFC3339Nano),
})
return err
}
defer resp.Body.Close()
buf := make([]byte, 512)
_, _ = resp.Body.Read(buf)
_ = resp.Body.Close()
cc.responses = append(cc.responses, responseLog{
Seq: cc.seq,
Attempt: 1,
TraceID: traceID,
StatusCode: resp.StatusCode,
Headers: resp.Header,
BodyText: "aborted_after_first_chunk",
DurationMS: time.Since(start).Milliseconds(),
ReceivedAt: time.Now().Format(time.RFC3339Nano),
})
return nil
}
func (r *Runner) caseToolcallStreamMixed(ctx context.Context, cc *caseContext) error {
payload := toolcallPayload(true)
payload["messages"] = []map[string]any{
@@ -293,167 +226,6 @@ func (r *Runner) caseSSEJSONIntegrity(ctx context.Context, cc *caseContext) erro
return nil
}
func (r *Runner) caseInvalidModel(ctx context.Context, cc *caseContext) error {
resp, err := cc.requestOnce(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: map[string]any{
"model": "deepseek-not-exists",
"messages": []map[string]any{
{"role": "user", "content": "hi"},
},
"stream": false,
},
Retryable: false,
}, 1)
if err != nil {
return err
}
cc.assert("status_503", resp.StatusCode == http.StatusServiceUnavailable, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
e, _ := m["error"].(map[string]any)
cc.assert("error_type_service_unavailable", asString(e["type"]) == "service_unavailable_error", fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func (r *Runner) caseMissingMessages(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: map[string]any{
"model": "deepseek-chat",
"stream": false,
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_400", resp.StatusCode == http.StatusBadRequest, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
e, _ := m["error"].(map[string]any)
cc.assert("error_type_invalid_request", asString(e["type"]) == "invalid_request_error", fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func (r *Runner) caseAdminUnauthorized(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodGet,
Path: "/admin/config",
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode))
return nil
}
func (r *Runner) caseTokenRefreshManagedAccount(ctx context.Context, cc *caseContext) error {
if len(r.configRaw.Accounts) == 0 {
cc.assert("account_present", false, "no account in config")
return nil
}
acc := r.configRaw.Accounts[0]
id := strings.TrimSpace(acc.Email)
if id == "" {
id = strings.TrimSpace(acc.Mobile)
}
if id == "" {
cc.assert("account_identifier", false, "first account has no identifier")
return nil
}
if strings.TrimSpace(acc.Password) == "" {
r.warnings = append(r.warnings, "token refresh edge case skipped strict check: first account password empty")
cc.assert("account_password_present", true, "skipped strict refresh check due empty password")
return nil
}
invalidToken := "invalid-testsuite-refresh-token-" + sanitizeID(r.runID)
update := map[string]any{
"keys": r.configRaw.Keys,
"accounts": []map[string]any{
{
"email": acc.Email,
"mobile": acc.Mobile,
"password": acc.Password,
"token": invalidToken,
},
},
}
updResp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/admin/config",
Headers: map[string]string{
"Authorization": "Bearer " + r.adminJWT,
},
Body: update,
Retryable: true,
})
if err != nil {
return err
}
cc.assert("update_config_status_200", updResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", updResp.StatusCode))
chatResp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
"X-Ds2-Target-Account": id,
},
Body: map[string]any{
"model": "deepseek-chat",
"messages": []map[string]any{
{"role": "user", "content": "token refresh test"},
},
"stream": false,
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("chat_status_200", chatResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d body=%s", chatResp.StatusCode, string(chatResp.Body)))
cfgResp, err := cc.request(ctx, requestSpec{
Method: http.MethodGet,
Path: "/admin/config",
Headers: map[string]string{
"Authorization": "Bearer " + r.adminJWT,
},
Retryable: true,
})
if err != nil {
return err
}
var cfg map[string]any
_ = json.Unmarshal(cfgResp.Body, &cfg)
accounts, _ := cfg["accounts"].([]any)
preview := ""
hasToken := false
for _, item := range accounts {
m, _ := item.(map[string]any)
e := asString(m["email"])
mo := asString(m["mobile"])
if e == acc.Email && mo == acc.Mobile {
preview = asString(m["token_preview"])
hasToken, _ = m["has_token"].(bool)
break
}
}
cc.assert("has_token_after_refresh", hasToken, fmt.Sprintf("config=%s", string(cfgResp.Body)))
cc.assert("token_preview_changed_from_invalid", !strings.HasPrefix(preview, invalidToken[:20]), fmt.Sprintf("preview=%s invalid_prefix=%s", preview, invalidToken[:20]))
return nil
}
func (r *Runner) fetchQueueStatus(ctx context.Context, cc *caseContext) (map[string]any, error) {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodGet,

View File

@@ -0,0 +1,76 @@
package testsuite
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"
)
func (cc *caseContext) abortStreamRequest(ctx context.Context, spec requestSpec) error {
cc.seq++
traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), cc.seq)
cc.traceIDsSet[traceID] = struct{}{}
fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID)
if err != nil {
return err
}
headers := map[string]string{}
for k, v := range spec.Headers {
headers[k] = v
}
headers["X-Ds2-Test-Trace"] = traceID
bodyBytes, _ := json.Marshal(spec.Body)
headers["Content-Type"] = "application/json"
cc.requests = append(cc.requests, requestLog{
Seq: cc.seq,
Attempt: 1,
TraceID: traceID,
Method: spec.Method,
URL: fullURL,
Headers: headers,
Body: spec.Body,
Timestamp: time.Now().Format(time.RFC3339Nano),
})
reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes))
if err != nil {
return err
}
for k, v := range headers {
req.Header.Set(k, v)
}
start := time.Now()
resp, err := cc.runner.httpClient.Do(req)
if err != nil {
cc.responses = append(cc.responses, responseLog{
Seq: cc.seq,
Attempt: 1,
TraceID: traceID,
StatusCode: 0,
DurationMS: time.Since(start).Milliseconds(),
NetworkErr: err.Error(),
ReceivedAt: time.Now().Format(time.RFC3339Nano),
})
return err
}
defer resp.Body.Close()
buf := make([]byte, 512)
_, _ = resp.Body.Read(buf)
_ = resp.Body.Close()
cc.responses = append(cc.responses, responseLog{
Seq: cc.seq,
Attempt: 1,
TraceID: traceID,
StatusCode: resp.StatusCode,
Headers: resp.Header,
BodyText: "aborted_after_first_chunk",
DurationMS: time.Since(start).Milliseconds(),
ReceivedAt: time.Now().Format(time.RFC3339Nano),
})
return nil
}

View File

@@ -0,0 +1,170 @@
package testsuite
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
)
func (r *Runner) caseInvalidModel(ctx context.Context, cc *caseContext) error {
resp, err := cc.requestOnce(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: map[string]any{
"model": "deepseek-not-exists",
"messages": []map[string]any{
{"role": "user", "content": "hi"},
},
"stream": false,
},
Retryable: false,
}, 1)
if err != nil {
return err
}
cc.assert("status_503", resp.StatusCode == http.StatusServiceUnavailable, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
e, _ := m["error"].(map[string]any)
cc.assert("error_type_service_unavailable", asString(e["type"]) == "service_unavailable_error", fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func (r *Runner) caseMissingMessages(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: map[string]any{
"model": "deepseek-chat",
"stream": false,
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_400", resp.StatusCode == http.StatusBadRequest, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
e, _ := m["error"].(map[string]any)
cc.assert("error_type_invalid_request", asString(e["type"]) == "invalid_request_error", fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func (r *Runner) caseAdminUnauthorized(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodGet,
Path: "/admin/config",
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode))
return nil
}
func (r *Runner) caseTokenRefreshManagedAccount(ctx context.Context, cc *caseContext) error {
if len(r.configRaw.Accounts) == 0 {
cc.assert("account_present", false, "no account in config")
return nil
}
acc := r.configRaw.Accounts[0]
id := strings.TrimSpace(acc.Email)
if id == "" {
id = strings.TrimSpace(acc.Mobile)
}
if id == "" {
cc.assert("account_identifier", false, "first account has no identifier")
return nil
}
if strings.TrimSpace(acc.Password) == "" {
r.warnings = append(r.warnings, "token refresh edge case skipped strict check: first account password empty")
cc.assert("account_password_present", true, "skipped strict refresh check due empty password")
return nil
}
invalidToken := "invalid-testsuite-refresh-token-" + sanitizeID(r.runID)
update := map[string]any{
"keys": r.configRaw.Keys,
"accounts": []map[string]any{
{
"email": acc.Email,
"mobile": acc.Mobile,
"password": acc.Password,
"token": invalidToken,
},
},
}
updResp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/admin/config",
Headers: map[string]string{
"Authorization": "Bearer " + r.adminJWT,
},
Body: update,
Retryable: true,
})
if err != nil {
return err
}
cc.assert("update_config_status_200", updResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", updResp.StatusCode))
chatResp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
"X-Ds2-Target-Account": id,
},
Body: map[string]any{
"model": "deepseek-chat",
"messages": []map[string]any{
{"role": "user", "content": "token refresh test"},
},
"stream": false,
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("chat_status_200", chatResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d body=%s", chatResp.StatusCode, string(chatResp.Body)))
cfgResp, err := cc.request(ctx, requestSpec{
Method: http.MethodGet,
Path: "/admin/config",
Headers: map[string]string{
"Authorization": "Bearer " + r.adminJWT,
},
Retryable: true,
})
if err != nil {
return err
}
var cfg map[string]any
_ = json.Unmarshal(cfgResp.Body, &cfg)
accounts, _ := cfg["accounts"].([]any)
preview := ""
hasToken := false
for _, item := range accounts {
m, _ := item.(map[string]any)
e := asString(m["email"])
mo := asString(m["mobile"])
if e == acc.Email && mo == acc.Mobile {
preview = asString(m["token_preview"])
hasToken, _ = m["has_token"].(bool)
break
}
}
cc.assert("has_token_after_refresh", hasToken, fmt.Sprintf("config=%s", string(cfgResp.Body)))
cc.assert("token_preview_changed_from_invalid", !strings.HasPrefix(preview, invalidToken[:20]), fmt.Sprintf("preview=%s invalid_prefix=%s", preview, invalidToken[:20]))
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,161 @@
package testsuite
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
)
func (r *Runner) caseAdminLoginVerify(ctx context.Context, cc *caseContext) error {
loginResp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/admin/login",
Body: map[string]any{"admin_key": r.adminKey, "expire_hours": 24},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("login_status_200", loginResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", loginResp.StatusCode))
var payload map[string]any
_ = json.Unmarshal(loginResp.Body, &payload)
token := asString(payload["token"])
cc.assert("token_exists", token != "", fmt.Sprintf("body=%s", string(loginResp.Body)))
if token == "" {
return nil
}
verifyResp, err := cc.request(ctx, requestSpec{
Method: http.MethodGet,
Path: "/admin/verify",
Headers: map[string]string{
"Authorization": "Bearer " + token,
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("verify_status_200", verifyResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", verifyResp.StatusCode))
var v map[string]any
_ = json.Unmarshal(verifyResp.Body, &v)
valid, _ := v["valid"].(bool)
cc.assert("verify_valid_true", valid, fmt.Sprintf("body=%s", string(verifyResp.Body)))
return nil
}
func (r *Runner) caseAdminQueueStatus(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodGet,
Path: "/admin/queue/status",
Headers: map[string]string{
"Authorization": "Bearer " + r.adminJWT,
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
_, hasRec := m["recommended_concurrency"]
_, hasQueue := m["max_queue_size"]
cc.assert("has_recommended_concurrency", hasRec, fmt.Sprintf("body=%s", string(resp.Body)))
cc.assert("has_max_queue_size", hasQueue, fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func (r *Runner) caseAdminAccountTest(ctx context.Context, cc *caseContext) error {
if strings.TrimSpace(r.accountID) == "" {
cc.assert("account_present", false, "no account in config")
return nil
}
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/admin/accounts/test",
Headers: map[string]string{
"Authorization": "Bearer " + r.adminJWT,
},
Body: map[string]any{
"identifier": r.accountID,
"model": "deepseek-chat",
"message": "ping",
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
ok, _ := m["success"].(bool)
cc.assert("success_true", ok, fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func (r *Runner) caseConfigWriteIsolated(ctx context.Context, cc *caseContext) error {
k := "testsuite-temp-" + sanitizeID(r.runID)
add, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/admin/keys",
Headers: map[string]string{
"Authorization": "Bearer " + r.adminJWT,
},
Body: map[string]any{"key": k},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("add_key_status_200", add.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", add.StatusCode))
cfg1, err := cc.request(ctx, requestSpec{
Method: http.MethodGet,
Path: "/admin/config",
Headers: map[string]string{
"Authorization": "Bearer " + r.adminJWT,
},
Retryable: true,
})
if err != nil {
return err
}
containsAdded := strings.Contains(string(cfg1.Body), k)
cc.assert("key_present_in_isolated_config", containsAdded, "added key not found in isolated config")
delPath := "/admin/keys/" + url.PathEscape(k)
del, err := cc.request(ctx, requestSpec{
Method: http.MethodDelete,
Path: delPath,
Headers: map[string]string{
"Authorization": "Bearer " + r.adminJWT,
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("delete_key_status_200", del.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", del.StatusCode))
cfg2, err := cc.request(ctx, requestSpec{
Method: http.MethodGet,
Path: "/admin/config",
Headers: map[string]string{
"Authorization": "Bearer " + r.adminJWT,
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("key_removed_in_isolated_config", !strings.Contains(string(cfg2.Body), k), "temporary key still present")
if err := r.ensureOriginalConfigUntouched(); err != nil {
cc.assert("original_config_unchanged", false, err.Error())
} else {
cc.assert("original_config_unchanged", true, "")
}
return nil
}

View File

@@ -0,0 +1,103 @@
package testsuite
import (
"context"
"encoding/json"
"fmt"
"net/http"
)
func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
ids := extractModelIDs(resp.Body)
cc.assert("non_empty", len(ids) > 0, fmt.Sprintf("models=%v", ids))
return nil
}
func (r *Runner) caseAnthropicNonstream(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/anthropic/v1/messages",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
"anthropic-version": "2023-06-01",
"content-type": "application/json",
},
Body: map[string]any{
"model": "claude-sonnet-4-5",
"messages": []map[string]any{
{"role": "user", "content": "hello"},
},
"stream": false,
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
cc.assert("type_message", asString(m["type"]) == "message", fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func (r *Runner) caseAnthropicStream(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/anthropic/v1/messages",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
"anthropic-version": "2023-06-01",
"content-type": "application/json",
},
Body: map[string]any{
"model": "claude-sonnet-4-5",
"messages": []map[string]any{
{"role": "user", "content": "stream hello"},
},
"stream": true,
},
Stream: true,
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
events := parseClaudeStreamEvents(resp.Body)
cc.assert("has_message_start", contains(events, "message_start"), fmt.Sprintf("events=%v", events))
cc.assert("has_message_stop", contains(events, "message_stop"), fmt.Sprintf("events=%v", events))
return nil
}
func (r *Runner) caseAnthropicCountTokens(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/anthropic/v1/messages/count_tokens",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
"anthropic-version": "2023-06-01",
"content-type": "application/json",
},
Body: map[string]any{
"model": "claude-sonnet-4-5",
"messages": []map[string]any{
{"role": "user", "content": "count me"},
},
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
v := toInt(m["input_tokens"])
cc.assert("input_tokens_gt_zero", v > 0, fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}

View File

@@ -0,0 +1,221 @@
package testsuite
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
)
func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/healthz", Retryable: true})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/readyz", Retryable: true})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models", Retryable: true})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
ids := extractModelIDs(resp.Body)
cc.assert("has_deepseek_chat", contains(ids, "deepseek-chat"), strings.Join(ids, ","))
cc.assert("has_deepseek_reasoner", contains(ids, "deepseek-reasoner"), strings.Join(ids, ","))
return nil
}
func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
cc.assert("object_model", asString(m["object"]) == "model", fmt.Sprintf("body=%s", string(resp.Body)))
cc.assert("id_deepseek_chat", asString(m["id"]) == "deepseek-chat", fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func (r *Runner) caseChatNonstream(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: map[string]any{
"model": "deepseek-chat",
"messages": []map[string]any{
{"role": "user", "content": "请简单回复一句话"},
},
"stream": false,
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
cc.assert("object_chat_completion", asString(m["object"]) == "chat.completion", fmt.Sprintf("body=%s", string(resp.Body)))
choices, _ := m["choices"].([]any)
cc.assert("choices_non_empty", len(choices) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: map[string]any{
"model": "deepseek-chat",
"messages": []map[string]any{
{"role": "user", "content": "请流式回复一句话"},
},
"stream": true,
},
Stream: true,
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
frames, done := parseSSEFrames(resp.Body)
cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames)))
cc.assert("done_terminated", done, "expected [DONE]")
return nil
}
func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/responses",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: map[string]any{
"model": "gpt-4o",
"input": "请简要回答 hello",
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
cc.assert("object_response", asString(m["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body)))
responseID := asString(m["id"])
cc.assert("response_id_present", responseID != "", fmt.Sprintf("body=%s", string(resp.Body)))
if responseID != "" {
getResp, getErr := cc.request(ctx, requestSpec{
Method: http.MethodGet,
Path: "/v1/responses/" + responseID,
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Retryable: true,
})
if getErr != nil {
return getErr
}
cc.assert("get_status_200", getResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", getResp.StatusCode))
}
return nil
}
func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/responses",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: map[string]any{
"model": "gpt-4o",
"input": "请流式回答 hello",
"stream": true,
},
Stream: true,
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
frames, done := parseSSEFrames(resp.Body)
cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames)))
hasCreated := false
hasCompleted := false
for _, f := range frames {
switch asString(f["type"]) {
case "response.created":
hasCreated = true
case "response.completed":
hasCompleted = true
}
}
cc.assert("has_response_created", hasCreated, fmt.Sprintf("body=%s", string(resp.Body)))
cc.assert("has_response_completed", hasCompleted, fmt.Sprintf("body=%s", string(resp.Body)))
cc.assert("done_terminated", done, "expected [DONE]")
return nil
}
func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/embeddings",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: map[string]any{
"model": "gpt-4o",
"input": []string{"hello", "world"},
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200_or_501", resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
if resp.StatusCode == http.StatusOK {
cc.assert("object_list", asString(m["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body)))
data, _ := m["data"].([]any)
cc.assert("data_non_empty", len(data) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
errObj, _ := m["error"].(map[string]any)
_, hasCode := errObj["code"]
_, hasParam := errObj["param"]
cc.assert("error_has_code", hasCode, fmt.Sprintf("body=%s", string(resp.Body)))
cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}

View File

@@ -0,0 +1,236 @@
package testsuite
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
)
func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: map[string]any{
"model": "deepseek-reasoner",
"messages": []map[string]any{
{"role": "user", "content": "先思考后回答1+1"},
},
"stream": true,
},
Stream: true,
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
frames, done := parseSSEFrames(resp.Body)
hasReasoning := false
for _, f := range frames {
choices, _ := f["choices"].([]any)
for _, c := range choices {
ch, _ := c.(map[string]any)
delta, _ := ch["delta"].(map[string]any)
if asString(delta["reasoning_content"]) != "" {
hasReasoning = true
}
}
}
cc.assert("has_reasoning_content", hasReasoning, "reasoning_content not found")
cc.assert("done_terminated", done, "expected [DONE]")
return nil
}
func (r *Runner) caseToolcallNonstream(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: toolcallPayload(false),
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
choices, _ := m["choices"].([]any)
if len(choices) == 0 {
cc.assert("choices_non_empty", false, fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
c0, _ := choices[0].(map[string]any)
cc.assert("finish_reason_tool_calls", asString(c0["finish_reason"]) == "tool_calls", fmt.Sprintf("body=%s", string(resp.Body)))
msg, _ := c0["message"].(map[string]any)
tc, _ := msg["tool_calls"].([]any)
cc.assert("tool_calls_present", len(tc) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func (r *Runner) caseToolcallStream(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: toolcallPayload(true),
Stream: true,
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
frames, done := parseSSEFrames(resp.Body)
hasTool := false
rawLeak := false
for _, f := range frames {
choices, _ := f["choices"].([]any)
for _, c := range choices {
ch, _ := c.(map[string]any)
delta, _ := ch["delta"].(map[string]any)
if _, ok := delta["tool_calls"]; ok {
hasTool = true
}
content := asString(delta["content"])
if strings.Contains(strings.ToLower(content), `"tool_calls"`) {
rawLeak = true
}
}
}
cc.assert("tool_calls_delta_present", hasTool, "tool_calls delta missing")
cc.assert("no_raw_tool_json_leak", !rawLeak, "raw tool_calls JSON leaked in content")
cc.assert("done_terminated", done, "expected [DONE]")
return nil
}
func (r *Runner) caseConcurrencyBurst(ctx context.Context, cc *caseContext) error {
accountCount := len(r.configRaw.Accounts)
n := accountCount*2 + 2
if n < 2 {
n = 2
}
type one struct {
Status int
Err string
}
results := make([]one, n)
var wg sync.WaitGroup
for i := 0; i < n; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer " + r.apiKey,
},
Body: map[string]any{
"model": "deepseek-chat",
"messages": []map[string]any{
{"role": "user", "content": fmt.Sprintf("并发请求 #%d请回复ok", idx)},
},
"stream": true,
},
Stream: true,
Retryable: true,
})
if err != nil {
results[idx] = one{Err: err.Error()}
return
}
results[idx] = one{Status: resp.StatusCode}
}(i)
}
wg.Wait()
dist := map[int]int{}
success := 0
for _, it := range results {
if it.Status > 0 {
dist[it.Status]++
if it.Status == http.StatusOK {
success++
}
}
}
cc.assert("success_gt_zero", success > 0, fmt.Sprintf("distribution=%v", dist))
_, has5xx := has5xx(dist)
cc.assert("no_5xx", !has5xx, fmt.Sprintf("distribution=%v", dist))
if err := r.ping("/healthz"); err != nil {
cc.assert("server_alive", false, err.Error())
} else {
cc.assert("server_alive", true, "")
}
return nil
}
func (r *Runner) caseInvalidKey(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,
Path: "/v1/chat/completions",
Headers: map[string]string{
"Authorization": "Bearer invalid-testsuite-key-" + sanitizeID(r.runID),
},
Body: map[string]any{
"model": "deepseek-chat",
"messages": []map[string]any{
{"role": "user", "content": "hi"},
},
"stream": false,
},
Retryable: true,
})
if err != nil {
return err
}
cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode))
var m map[string]any
_ = json.Unmarshal(resp.Body, &m)
e, _ := m["error"].(map[string]any)
cc.assert("error_object_present", len(e) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
cc.assert("error_message_present", asString(e["message"]) != "", fmt.Sprintf("body=%s", string(resp.Body)))
return nil
}
func toolcallPayload(stream bool) map[string]any {
return map[string]any{
"model": "deepseek-chat",
"messages": []map[string]any{
{
"role": "user",
"content": "你必须调用工具 search 查询 golang并仅返回工具调用。",
},
},
"tools": []map[string]any{
{
"type": "function",
"function": map[string]any{
"name": "search",
"description": "search documents",
"parameters": map[string]any{
"type": "object",
"properties": map[string]any{
"q": map[string]any{
"type": "string",
},
},
"required": []string{"q"},
},
},
},
},
"stream": stream,
}
}

View File

@@ -0,0 +1,290 @@
package testsuite
import (
"context"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"sort"
"strings"
"sync"
"time"
)
type Options struct {
ConfigPath string
AdminKey string
OutputDir string
Port int
Timeout time.Duration
Retries int
NoPreflight bool
MaxKeepRuns int
}
type runSummary struct {
RunID string `json:"run_id"`
StartedAt string `json:"started_at"`
EndedAt string `json:"ended_at"`
DurationMS int64 `json:"duration_ms"`
Stats map[string]any `json:"stats"`
Environment map[string]any `json:"environment"`
Cases []caseResult `json:"cases"`
Warnings []string `json:"warnings,omitempty"`
}
type caseResult struct {
CaseID string `json:"case_id"`
Passed bool `json:"passed"`
DurationMS int64 `json:"duration_ms"`
TraceIDs []string `json:"trace_ids"`
StatusCodes []int `json:"status_codes"`
Error string `json:"error,omitempty"`
ArtifactPath string `json:"artifact_path"`
Assertions []assertionResult `json:"assertions"`
}
type assertionResult struct {
Name string `json:"name"`
Passed bool `json:"passed"`
Detail string `json:"detail,omitempty"`
}
type requestLog struct {
Seq int `json:"seq"`
Attempt int `json:"attempt"`
TraceID string `json:"trace_id"`
Method string `json:"method"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
Body any `json:"body,omitempty"`
Timestamp string `json:"timestamp"`
}
type responseLog struct {
Seq int `json:"seq"`
Attempt int `json:"attempt"`
TraceID string `json:"trace_id"`
StatusCode int `json:"status_code"`
Headers map[string][]string `json:"headers"`
BodyText string `json:"body_text"`
DurationMS int64 `json:"duration_ms"`
NetworkErr string `json:"network_error,omitempty"`
ReceivedAt string `json:"received_at"`
}
type caseContext struct {
runner *Runner
id string
dir string
startedAt time.Time
mu sync.Mutex
seq int
assertions []assertionResult
requests []requestLog
responses []responseLog
streamRaw strings.Builder
traceIDsSet map[string]struct{}
}
type requestSpec struct {
Method string
Path string
Headers map[string]string
Body any
Stream bool
Retryable bool
}
type responseResult struct {
StatusCode int
Headers http.Header
Body []byte
TraceID string
URL string
}
type Runner struct {
opts Options
runID string
runDir string
serverLog string
preflightLog string
baseURL string
httpClient *http.Client
serverCmd *exec.Cmd
serverLogFd *os.File
configCopyPath string
originalConfigPath string
originalConfigHash string
configRaw runConfig
apiKey string
adminKey string
adminJWT string
accountID string
warnings []string
results []caseResult
}
type runConfig struct {
Keys []string `json:"keys"`
Accounts []struct {
Email string `json:"email,omitempty"`
Mobile string `json:"mobile,omitempty"`
Password string `json:"password,omitempty"`
Token string `json:"token,omitempty"`
} `json:"accounts"`
}
func Run(ctx context.Context, opts Options) error {
r, err := newRunner(opts)
if err != nil {
return err
}
start := time.Now()
defer func() {
_ = r.stopServer()
}()
if err := r.prepareRunDir(); err != nil {
return err
}
if !r.opts.NoPreflight {
if err := r.runPreflight(ctx); err != nil {
_ = r.writeSummary(start, time.Now())
return err
}
}
if err := r.prepareConfigIsolation(); err != nil {
_ = r.writeSummary(start, time.Now())
return err
}
if err := r.startServer(ctx); err != nil {
_ = r.writeSummary(start, time.Now())
return err
}
if err := r.prepareAuth(ctx); err != nil {
r.warnings = append(r.warnings, "auth prepare failed: "+err.Error())
}
for _, c := range r.cases() {
r.runCase(ctx, c)
}
if err := r.ensureOriginalConfigUntouched(); err != nil {
r.warnings = append(r.warnings, err.Error())
}
end := time.Now()
if err := r.writeSummary(start, end); err != nil {
return err
}
// Prune old test runs, keeping only the most recent N.
if err := r.pruneOldRuns(); err != nil {
r.warnings = append(r.warnings, "prune old runs: "+err.Error())
}
failed := 0
for _, cs := range r.results {
if !cs.Passed {
failed++
}
}
if failed > 0 {
return fmt.Errorf("testsuite failed: %d case(s) failed, see %s", failed, filepath.Join(r.runDir, "summary.md"))
}
return nil
}
func newRunner(opts Options) (*Runner, error) {
if strings.TrimSpace(opts.ConfigPath) == "" {
opts.ConfigPath = "config.json"
}
if strings.TrimSpace(opts.OutputDir) == "" {
opts.OutputDir = "artifacts/testsuite"
}
if opts.Timeout <= 0 {
opts.Timeout = 120 * time.Second
}
if opts.Retries < 0 {
opts.Retries = 0
}
adminKey := strings.TrimSpace(opts.AdminKey)
if adminKey == "" {
adminKey = strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY"))
}
if adminKey == "" {
adminKey = "admin"
}
opts.AdminKey = adminKey
return &Runner{
opts: opts,
httpClient: &http.Client{
Timeout: 0,
},
runID: time.Now().UTC().Format("20060102T150405Z"),
adminKey: adminKey,
}, nil
}
func (r *Runner) runCase(ctx context.Context, c caseDef) {
caseDir := filepath.Join(r.runDir, "cases", c.ID)
_ = os.MkdirAll(caseDir, 0o755)
cc := &caseContext{
runner: r,
id: c.ID,
dir: caseDir,
startedAt: time.Now(),
traceIDsSet: map[string]struct{}{},
}
err := c.Run(ctx, cc)
duration := time.Since(cc.startedAt).Milliseconds()
if err != nil {
cc.assertions = append(cc.assertions, assertionResult{
Name: "case_error",
Passed: false,
Detail: err.Error(),
})
}
passed := err == nil
for _, a := range cc.assertions {
if !a.Passed {
passed = false
break
}
}
traceIDs := make([]string, 0, len(cc.traceIDsSet))
for t := range cc.traceIDsSet {
traceIDs = append(traceIDs, t)
}
sort.Strings(traceIDs)
statuses := uniqueStatusCodes(cc.responses)
cs := caseResult{
CaseID: c.ID,
Passed: passed,
DurationMS: duration,
TraceIDs: traceIDs,
StatusCodes: statuses,
ArtifactPath: caseDir,
Assertions: cc.assertions,
}
if err != nil {
cs.Error = err.Error()
}
_ = cc.flushArtifacts(cs)
r.results = append(r.results, cs)
}

View File

@@ -0,0 +1,20 @@
package testsuite
import (
"os"
"strings"
"time"
)
func DefaultOptions() Options {
return Options{
ConfigPath: "config.json",
AdminKey: strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")),
OutputDir: "artifacts/testsuite",
Port: 0,
Timeout: 120 * time.Second,
Retries: 2,
NoPreflight: false,
MaxKeepRuns: 5,
}
}

View File

@@ -0,0 +1,261 @@
package testsuite
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
)
func (r *Runner) prepareRunDir() error {
r.runDir = filepath.Join(r.opts.OutputDir, r.runID)
if err := os.MkdirAll(r.runDir, 0o755); err != nil {
return err
}
if err := os.MkdirAll(filepath.Join(r.runDir, "cases"), 0o755); err != nil {
return err
}
r.serverLog = filepath.Join(r.runDir, "server.log")
r.preflightLog = filepath.Join(r.runDir, "preflight.log")
return nil
}
// pruneOldRuns removes old test run directories, keeping the most recent MaxKeepRuns.
// Run IDs use the format "20060102T150405Z", so alphabetical order == chronological order.
func (r *Runner) pruneOldRuns() error {
keep := r.opts.MaxKeepRuns
if keep <= 0 {
return nil // 0 or negative means no pruning
}
entries, err := os.ReadDir(r.opts.OutputDir)
if err != nil {
return err
}
// Collect only directories (each run is a directory).
var runDirs []string
for _, e := range entries {
if !e.IsDir() {
continue
}
runDirs = append(runDirs, e.Name())
}
sort.Strings(runDirs)
if len(runDirs) <= keep {
return nil
}
// Remove oldest runs (those at the beginning of the sorted list).
toRemove := runDirs[:len(runDirs)-keep]
var errs []string
for _, name := range toRemove {
dirPath := filepath.Join(r.opts.OutputDir, name)
if err := os.RemoveAll(dirPath); err != nil {
errs = append(errs, fmt.Sprintf("remove %s: %v", name, err))
} else {
fmt.Fprintf(os.Stdout, "pruned old test run: %s\n", name)
}
}
if len(errs) > 0 {
return errors.New(strings.Join(errs, "; "))
}
return nil
}
func (r *Runner) runPreflight(ctx context.Context) error {
steps := [][]string{
{"go", "test", "./...", "-count=1"},
{"node", "--check", "api/chat-stream.js"},
{"node", "--check", "api/helpers/stream-tool-sieve.js"},
{"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js", "api/compat/js_compat_test.js"},
{"npm", "run", "build", "--prefix", "webui"},
}
f, err := os.OpenFile(r.preflightLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
return err
}
defer f.Close()
for _, step := range steps {
if _, err := fmt.Fprintf(f, "\n$ %s\n", strings.Join(step, " ")); err != nil {
return err
}
cmd := exec.CommandContext(ctx, step[0], step[1:]...)
cmd.Stdout = f
cmd.Stderr = f
if err := cmd.Run(); err != nil {
return fmt.Errorf("preflight failed at `%s`: %w", strings.Join(step, " "), err)
}
}
return nil
}
func (r *Runner) prepareConfigIsolation() error {
abs, err := filepath.Abs(r.opts.ConfigPath)
if err != nil {
return err
}
r.originalConfigPath = abs
raw, err := os.ReadFile(abs)
if err != nil {
return err
}
sum := sha256.Sum256(raw)
r.originalConfigHash = hex.EncodeToString(sum[:])
tmpDir := filepath.Join(r.runDir, "tmp")
if err := os.MkdirAll(tmpDir, 0o755); err != nil {
return err
}
r.configCopyPath = filepath.Join(tmpDir, "config.json")
if err := os.WriteFile(r.configCopyPath, raw, 0o644); err != nil {
return err
}
var cfg runConfig
if err := json.Unmarshal(raw, &cfg); err != nil {
return fmt.Errorf("parse config failed: %w", err)
}
r.configRaw = cfg
if len(cfg.Keys) > 0 {
r.apiKey = strings.TrimSpace(cfg.Keys[0])
}
for _, acc := range cfg.Accounts {
id := strings.TrimSpace(acc.Email)
if id == "" {
id = strings.TrimSpace(acc.Mobile)
}
if id != "" {
r.accountID = id
break
}
}
return nil
}
func (r *Runner) startServer(ctx context.Context) error {
port := r.opts.Port
if port <= 0 {
p, err := findFreePort()
if err != nil {
return err
}
port = p
}
r.baseURL = "http://127.0.0.1:" + strconv.Itoa(port)
logFd, err := os.OpenFile(r.serverLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
return err
}
r.serverLogFd = logFd
cmd := exec.CommandContext(ctx, "go", "run", "./cmd/ds2api")
cmd.Stdout = logFd
cmd.Stderr = logFd
cmd.Env = prepareServerEnv(os.Environ(), map[string]string{
"PORT": strconv.Itoa(port),
"DS2API_CONFIG_PATH": r.configCopyPath,
"DS2API_AUTO_BUILD_WEBUI": "false",
"DS2API_CONFIG_JSON": "",
"CONFIG_JSON": "",
})
if err := cmd.Start(); err != nil {
_ = logFd.Close()
return err
}
r.serverCmd = cmd
deadline := time.Now().Add(90 * time.Second)
for time.Now().Before(deadline) {
if r.ping("/healthz") == nil && r.ping("/readyz") == nil {
return nil
}
time.Sleep(500 * time.Millisecond)
}
return errors.New("server readiness timeout")
}
func (r *Runner) stopServer() error {
var errs []string
if r.serverCmd != nil && r.serverCmd.Process != nil {
_ = r.serverCmd.Process.Signal(os.Interrupt)
done := make(chan error, 1)
go func() { done <- r.serverCmd.Wait() }()
select {
case <-time.After(5 * time.Second):
_ = r.serverCmd.Process.Kill()
<-done
case <-done:
}
}
if r.serverLogFd != nil {
if err := r.serverLogFd.Close(); err != nil {
errs = append(errs, err.Error())
}
}
if len(errs) > 0 {
return errors.New(strings.Join(errs, "; "))
}
return nil
}
func (r *Runner) ping(path string) error {
resp, err := r.httpClient.Get(r.baseURL + path)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status=%d", resp.StatusCode)
}
return nil
}
func (r *Runner) prepareAuth(ctx context.Context) error {
reqBody := map[string]any{
"admin_key": r.adminKey,
"expire_hours": 24,
}
resp, err := r.doSimpleJSON(ctx, http.MethodPost, "/admin/login", nil, reqBody)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("admin login status=%d body=%s", resp.StatusCode, string(resp.Body))
}
var m map[string]any
if err := json.Unmarshal(resp.Body, &m); err != nil {
return err
}
token, _ := m["token"].(string)
if strings.TrimSpace(token) == "" {
return errors.New("empty admin jwt token")
}
r.adminJWT = token
return nil
}
func (r *Runner) ensureOriginalConfigUntouched() error {
raw, err := os.ReadFile(r.originalConfigPath)
if err != nil {
return err
}
sum := sha256.Sum256(raw)
current := hex.EncodeToString(sum[:])
if current != r.originalConfigHash {
return fmt.Errorf("original config changed unexpectedly: %s", r.originalConfigPath)
}
return nil
}

View File

@@ -0,0 +1,217 @@
package testsuite
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
func (cc *caseContext) assert(name string, ok bool, detail string) {
cc.mu.Lock()
defer cc.mu.Unlock()
cc.assertions = append(cc.assertions, assertionResult{
Name: name,
Passed: ok,
Detail: detail,
})
}
func (cc *caseContext) request(ctx context.Context, spec requestSpec) (*responseResult, error) {
retries := cc.runner.opts.Retries
if !spec.Retryable {
retries = 0
}
var lastErr error
for attempt := 1; attempt <= retries+1; attempt++ {
resp, err := cc.requestOnce(ctx, spec, attempt)
if err == nil && resp.StatusCode < 500 {
return resp, nil
}
if err != nil {
lastErr = err
} else if resp.StatusCode >= 500 {
lastErr = fmt.Errorf("status=%d", resp.StatusCode)
}
if attempt <= retries {
sleep := time.Duration(300*(1<<(attempt-1))) * time.Millisecond
time.Sleep(sleep)
}
}
return nil, lastErr
}
func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attempt int) (*responseResult, error) {
cc.mu.Lock()
cc.seq++
seq := cc.seq
traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), seq)
cc.traceIDsSet[traceID] = struct{}{}
cc.mu.Unlock()
fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID)
if err != nil {
return nil, err
}
headers := map[string]string{}
for k, v := range spec.Headers {
headers[k] = v
}
headers["X-Ds2-Test-Trace"] = traceID
var bodyBytes []byte
var bodyAny any
if spec.Body != nil {
b, err := json.Marshal(spec.Body)
if err != nil {
return nil, err
}
bodyBytes = b
bodyAny = spec.Body
headers["Content-Type"] = "application/json"
}
cc.mu.Lock()
cc.requests = append(cc.requests, requestLog{
Seq: seq,
Attempt: attempt,
TraceID: traceID,
Method: spec.Method,
URL: fullURL,
Headers: headers,
Body: bodyAny,
Timestamp: time.Now().Format(time.RFC3339Nano),
})
cc.mu.Unlock()
reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes))
if err != nil {
return nil, err
}
for k, v := range headers {
req.Header.Set(k, v)
}
start := time.Now()
resp, err := cc.runner.httpClient.Do(req)
if err != nil {
cc.mu.Lock()
cc.responses = append(cc.responses, responseLog{
Seq: seq,
Attempt: attempt,
TraceID: traceID,
StatusCode: 0,
DurationMS: time.Since(start).Milliseconds(),
NetworkErr: err.Error(),
ReceivedAt: time.Now().Format(time.RFC3339Nano),
})
cc.mu.Unlock()
return nil, err
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
cc.mu.Lock()
cc.responses = append(cc.responses, responseLog{
Seq: seq,
Attempt: attempt,
TraceID: traceID,
StatusCode: resp.StatusCode,
Headers: resp.Header,
BodyText: string(body),
DurationMS: time.Since(start).Milliseconds(),
ReceivedAt: time.Now().Format(time.RFC3339Nano),
})
if spec.Stream {
cc.streamRaw.WriteString(fmt.Sprintf("### trace=%s url=%s\n", traceID, fullURL))
cc.streamRaw.Write(body)
cc.streamRaw.WriteString("\n\n")
}
cc.mu.Unlock()
return &responseResult{
StatusCode: resp.StatusCode,
Headers: resp.Header,
Body: body,
TraceID: traceID,
URL: fullURL,
}, nil
}
func (cc *caseContext) flushArtifacts(cs caseResult) error {
requestPath := filepath.Join(cc.dir, "request.json")
headersPath := filepath.Join(cc.dir, "response.headers")
bodyPath := filepath.Join(cc.dir, "response.body")
streamPath := filepath.Join(cc.dir, "stream.raw")
assertPath := filepath.Join(cc.dir, "assertions.json")
metaPath := filepath.Join(cc.dir, "meta.json")
if err := writeJSONFile(requestPath, cc.requests); err != nil {
return err
}
respHeaders := make([]map[string]any, 0, len(cc.responses))
respBodies := make([]map[string]any, 0, len(cc.responses))
for _, r := range cc.responses {
respHeaders = append(respHeaders, map[string]any{
"seq": r.Seq,
"attempt": r.Attempt,
"trace_id": r.TraceID,
"status_code": r.StatusCode,
"headers": r.Headers,
})
respBodies = append(respBodies, map[string]any{
"seq": r.Seq,
"attempt": r.Attempt,
"trace_id": r.TraceID,
"status_code": r.StatusCode,
"body_text": r.BodyText,
"network_error": r.NetworkErr,
"duration_ms": r.DurationMS,
})
}
if err := writeJSONFile(headersPath, respHeaders); err != nil {
return err
}
if err := writeJSONFile(bodyPath, respBodies); err != nil {
return err
}
if err := os.WriteFile(streamPath, []byte(cc.streamRaw.String()), 0o644); err != nil {
return err
}
if err := writeJSONFile(assertPath, cc.assertions); err != nil {
return err
}
meta := map[string]any{
"case_id": cs.CaseID,
"trace_id": strings.Join(cs.TraceIDs, ","),
"attempt": len(cc.responses),
"duration_ms": cs.DurationMS,
"status": map[bool]string{true: "passed", false: "failed"}[cs.Passed],
"status_codes": cs.StatusCodes,
"assertions": cs.Assertions,
"artifact_path": cs.ArtifactPath,
}
return writeJSONFile(metaPath, meta)
}
func (r *Runner) doSimpleJSON(ctx context.Context, method, path string, headers map[string]string, body any) (*responseResult, error) {
cc := &caseContext{
runner: r,
id: "auth_prepare",
traceIDsSet: map[string]struct{}{},
}
return cc.request(ctx, requestSpec{
Method: method,
Path: path,
Headers: headers,
Body: body,
Retryable: true,
})
}

View File

@@ -0,0 +1,43 @@
package testsuite
import "context"
type caseDef struct {
ID string
Run func(context.Context, *caseContext) error
}
func (r *Runner) cases() []caseDef {
return []caseDef{
{ID: "healthz_ok", Run: r.caseHealthz},
{ID: "readyz_ok", Run: r.caseReadyz},
{ID: "models_openai", Run: r.caseModelsOpenAI},
{ID: "model_openai_by_id", Run: r.caseModelOpenAIByID},
{ID: "models_claude", Run: r.caseModelsClaude},
{ID: "admin_login_verify", Run: r.caseAdminLoginVerify},
{ID: "admin_queue_status", Run: r.caseAdminQueueStatus},
{ID: "chat_nonstream_basic", Run: r.caseChatNonstream},
{ID: "chat_stream_basic", Run: r.caseChatStream},
{ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream},
{ID: "responses_stream_basic", Run: r.caseResponsesStream},
{ID: "embeddings_contract", Run: r.caseEmbeddings},
{ID: "reasoner_stream", Run: r.caseReasonerStream},
{ID: "toolcall_nonstream", Run: r.caseToolcallNonstream},
{ID: "toolcall_stream", Run: r.caseToolcallStream},
{ID: "anthropic_messages_nonstream", Run: r.caseAnthropicNonstream},
{ID: "anthropic_messages_stream", Run: r.caseAnthropicStream},
{ID: "anthropic_count_tokens", Run: r.caseAnthropicCountTokens},
{ID: "admin_account_test_single", Run: r.caseAdminAccountTest},
{ID: "concurrency_burst", Run: r.caseConcurrencyBurst},
{ID: "concurrency_threshold_limit", Run: r.caseConcurrencyThresholdLimit},
{ID: "stream_abort_release", Run: r.caseStreamAbortRelease},
{ID: "toolcall_stream_mixed", Run: r.caseToolcallStreamMixed},
{ID: "sse_json_integrity", Run: r.caseSSEJSONIntegrity},
{ID: "error_contract_invalid_model", Run: r.caseInvalidModel},
{ID: "error_contract_missing_messages", Run: r.caseMissingMessages},
{ID: "admin_unauthorized_contract", Run: r.caseAdminUnauthorized},
{ID: "config_write_isolated", Run: r.caseConfigWriteIsolated},
{ID: "token_refresh_managed_account", Run: r.caseTokenRefreshManagedAccount},
{ID: "error_contract_invalid_key", Run: r.caseInvalidKey},
}
}

View File

@@ -0,0 +1,97 @@
package testsuite
import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"time"
)
func (r *Runner) writeSummary(start, end time.Time) error {
passed := 0
failed := 0
for _, cs := range r.results {
if cs.Passed {
passed++
} else {
failed++
}
}
summary := runSummary{
RunID: r.runID,
StartedAt: start.Format(time.RFC3339Nano),
EndedAt: end.Format(time.RFC3339Nano),
DurationMS: end.Sub(start).Milliseconds(),
Stats: map[string]any{
"total": len(r.results),
"passed": passed,
"failed": failed,
},
Environment: map[string]any{
"go_version": runtime.Version(),
"os": runtime.GOOS,
"arch": runtime.GOARCH,
"base_url": r.baseURL,
"config_source": r.originalConfigPath,
"config_isolated": r.configCopyPath,
"server_log": r.serverLog,
"preflight_log": r.preflightLog,
"retries": r.opts.Retries,
"timeout_seconds": int(r.opts.Timeout.Seconds()),
},
Cases: r.results,
Warnings: r.warnings,
}
if err := writeJSONFile(filepath.Join(r.runDir, "summary.json"), summary); err != nil {
return err
}
return os.WriteFile(filepath.Join(r.runDir, "summary.md"), []byte(r.summaryMarkdown(summary)), 0o644)
}
func (r *Runner) summaryMarkdown(s runSummary) string {
var b strings.Builder
b.WriteString("# DS2API Live Testsuite Summary\n\n")
b.WriteString("**Sensitive Notice:** this run stores full raw request/response logs. Do not share artifacts publicly.\n\n")
fmt.Fprintf(&b, "- Run ID: `%s`\n", s.RunID)
fmt.Fprintf(&b, "- Started: `%s`\n", s.StartedAt)
fmt.Fprintf(&b, "- Ended: `%s`\n", s.EndedAt)
fmt.Fprintf(&b, "- Duration: `%d ms`\n", s.DurationMS)
fmt.Fprintf(&b, "- Passed/Failed: `%d/%d`\n\n", s.Stats["passed"], s.Stats["failed"])
if len(s.Warnings) > 0 {
b.WriteString("## Warnings\n\n")
for _, w := range s.Warnings {
fmt.Fprintf(&b, "- %s\n", w)
}
b.WriteString("\n")
}
b.WriteString("## Failed Cases\n\n")
hasFailed := false
for _, c := range s.Cases {
if c.Passed {
continue
}
hasFailed = true
fmt.Fprintf(&b, "- `%s`: %s\n", c.CaseID, c.Error)
if len(c.TraceIDs) > 0 {
fmt.Fprintf(&b, " - trace_ids: `%s`\n", strings.Join(c.TraceIDs, ", "))
fmt.Fprintf(&b, " - grep: `rg \"%s\" %s`\n", c.TraceIDs[0], filepath.Join(r.runDir, "server.log"))
}
fmt.Fprintf(&b, " - artifact: `%s`\n", c.ArtifactPath)
}
if !hasFailed {
b.WriteString("- none\n")
}
b.WriteString("\n## Case Table\n\n")
b.WriteString("| case_id | status | duration_ms | statuses | artifact |\n")
b.WriteString("|---|---:|---:|---|---|\n")
for _, c := range s.Cases {
status := "PASS"
if !c.Passed {
status = "FAIL"
}
fmt.Fprintf(&b, "| %s | %s | %d | %v | `%s` |\n", c.CaseID, status, c.DurationMS, c.StatusCodes, c.ArtifactPath)
}
return b.String()
}

View File

@@ -0,0 +1,202 @@
package testsuite
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/url"
"os"
"sort"
"strings"
)
func parseSSEFrames(body []byte) ([]map[string]any, bool) {
lines := strings.Split(string(body), "\n")
frames := make([]map[string]any, 0, len(lines))
done := false
for _, line := range lines {
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "data:") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if payload == "" {
continue
}
if payload == "[DONE]" {
done = true
continue
}
var m map[string]any
if err := json.Unmarshal([]byte(payload), &m); err == nil {
frames = append(frames, m)
}
}
return frames, done
}
func parseClaudeStreamEvents(body []byte) []string {
events := []string{}
seen := map[string]bool{}
lines := strings.Split(string(body), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "data:") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if payload == "" {
continue
}
var m map[string]any
if err := json.Unmarshal([]byte(payload), &m); err != nil {
continue
}
t := asString(m["type"])
if t == "" || seen[t] {
continue
}
seen[t] = true
events = append(events, t)
}
return events
}
func extractModelIDs(body []byte) []string {
var m map[string]any
if err := json.Unmarshal(body, &m); err != nil {
return nil
}
out := []string{}
data, _ := m["data"].([]any)
for _, it := range data {
item, _ := it.(map[string]any)
id := asString(item["id"])
if id != "" {
out = append(out, id)
}
}
return out
}
func withTraceQuery(rawURL, traceID string) (string, error) {
u, err := url.Parse(rawURL)
if err != nil {
return "", err
}
q := u.Query()
q.Set("__trace_id", traceID)
u.RawQuery = q.Encode()
return u.String(), nil
}
func writeJSONFile(path string, v any) error {
b, err := json.MarshalIndent(v, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, b, 0o644)
}
func prepareServerEnv(base []string, overrides map[string]string) []string {
out := make([]string, 0, len(base)+len(overrides))
skip := map[string]struct{}{}
for k := range overrides {
skip[k] = struct{}{}
}
for _, e := range base {
parts := strings.SplitN(e, "=", 2)
if len(parts) != 2 {
continue
}
if _, ok := skip[parts[0]]; ok {
continue
}
out = append(out, e)
}
for k, v := range overrides {
out = append(out, k+"="+v)
}
return out
}
func findFreePort() (int, error) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return 0, err
}
defer ln.Close()
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
return 0, errors.New("failed to detect tcp port")
}
return addr.Port, nil
}
func uniqueStatusCodes(in []responseLog) []int {
set := map[int]struct{}{}
for _, it := range in {
if it.StatusCode > 0 {
set[it.StatusCode] = struct{}{}
}
}
out := make([]int, 0, len(set))
for k := range set {
out = append(out, k)
}
sort.Ints(out)
return out
}
func has5xx(dist map[int]int) (int, bool) {
for k := range dist {
if k >= 500 {
return k, true
}
}
return 0, false
}
func sanitizeID(s string) string {
s = strings.ReplaceAll(s, ":", "_")
s = strings.ReplaceAll(s, "/", "_")
s = strings.ReplaceAll(s, " ", "_")
return s
}
func asString(v any) string {
if v == nil {
return ""
}
switch x := v.(type) {
case string:
return strings.TrimSpace(x)
default:
return strings.TrimSpace(fmt.Sprintf("%v", v))
}
}
func toInt(v any) int {
switch x := v.(type) {
case float64:
return int(x)
case float32:
return int(x)
case int:
return x
case int64:
return int(x)
default:
return 0
}
}
func contains(xs []string, target string) bool {
for _, x := range xs {
if x == target {
return true
}
}
return false
}

View File

@@ -0,0 +1,138 @@
package util
import (
"regexp"
"strings"
)
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```")
func buildToolCallCandidates(text string) []string {
trimmed := strings.TrimSpace(text)
candidates := []string{trimmed}
// fenced code block candidates: ```json ... ```
for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) {
if len(match) >= 2 {
candidates = append(candidates, strings.TrimSpace(match[1]))
}
}
// best-effort extraction around "tool_calls" key in mixed text payloads.
candidates = append(candidates, extractToolCallObjects(trimmed)...)
// best-effort object slice: from first '{' to last '}'
first := strings.Index(trimmed, "{")
last := strings.LastIndex(trimmed, "}")
if first >= 0 && last > first {
candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1]))
}
// legacy regex extraction fallback
if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}")
}
uniq := make([]string, 0, len(candidates))
seen := map[string]struct{}{}
for _, c := range candidates {
if c == "" {
continue
}
if _, ok := seen[c]; ok {
continue
}
seen[c] = struct{}{}
uniq = append(uniq, c)
}
return uniq
}
func extractToolCallObjects(text string) []string {
if text == "" {
return nil
}
lower := strings.ToLower(text)
out := []string{}
offset := 0
for {
idx := strings.Index(lower[offset:], "tool_calls")
if idx < 0 {
break
}
idx += offset
start := strings.LastIndex(text[:idx], "{")
for start >= 0 {
candidate, end, ok := extractJSONObject(text, start)
if ok {
// Move forward to avoid repeatedly matching the same object.
offset = end
out = append(out, strings.TrimSpace(candidate))
break
}
start = strings.LastIndex(text[:start], "{")
}
if start < 0 {
offset = idx + len("tool_calls")
}
}
return out
}
func extractJSONObject(text string, start int) (string, int, bool) {
if start < 0 || start >= len(text) || text[start] != '{' {
return "", 0, false
}
depth := 0
quote := byte(0)
escaped := false
for i := start; i < len(text); i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
quote = ch
continue
}
if ch == '{' {
depth++
continue
}
if ch == '}' {
depth--
if depth == 0 {
return text[start : i+1], i + 1, true
}
}
}
return "", 0, false
}
func looksLikeToolExampleContext(text string) bool {
t := strings.ToLower(strings.TrimSpace(text))
if t == "" {
return false
}
return strings.Contains(t, "```")
}
func stripFencedCodeBlocks(text string) string {
if strings.TrimSpace(text) == "" {
return ""
}
return fencedBlockPattern.ReplaceAllString(text, " ")
}

View File

@@ -0,0 +1,41 @@
package util
import (
"encoding/json"
"strings"
"github.com/google/uuid"
)
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
out := make([]map[string]any, 0, len(calls))
for _, c := range calls {
args, _ := json.Marshal(c.Input)
out = append(out, map[string]any{
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"type": "function",
"function": map[string]any{
"name": c.Name,
"arguments": string(args),
},
})
}
return out
}
func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any {
out := make([]map[string]any, 0, len(calls))
for i, c := range calls {
args, _ := json.Marshal(c.Input)
out = append(out, map[string]any{
"index": i,
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"type": "function",
"function": map[string]any{
"name": c.Name,
"arguments": string(args),
},
})
}
return out
}

View File

@@ -2,16 +2,9 @@ package util
import (
"encoding/json"
"regexp"
"strings"
"github.com/google/uuid"
)
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```")
type ParsedToolCall struct {
Name string `json:"name"`
Input map[string]any `json:"input"`
@@ -102,47 +95,6 @@ func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []Par
return out
}
func buildToolCallCandidates(text string) []string {
trimmed := strings.TrimSpace(text)
candidates := []string{trimmed}
// fenced code block candidates: ```json ... ```
for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) {
if len(match) >= 2 {
candidates = append(candidates, strings.TrimSpace(match[1]))
}
}
// best-effort extraction around "tool_calls" key in mixed text payloads.
candidates = append(candidates, extractToolCallObjects(trimmed)...)
// best-effort object slice: from first '{' to last '}'
first := strings.Index(trimmed, "{")
last := strings.LastIndex(trimmed, "}")
if first >= 0 && last > first {
candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1]))
}
// legacy regex extraction fallback
if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}")
}
uniq := make([]string, 0, len(candidates))
seen := map[string]struct{}{}
for _, c := range candidates {
if c == "" {
continue
}
if _, ok := seen[c]; ok {
continue
}
seen[c] = struct{}{}
uniq = append(uniq, c)
}
return uniq
}
func parseToolCallsPayload(payload string) []ParsedToolCall {
var decoded any
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
@@ -243,123 +195,3 @@ func parseToolCallInput(v any) map[string]any {
return map[string]any{}
}
}
func extractToolCallObjects(text string) []string {
if text == "" {
return nil
}
lower := strings.ToLower(text)
out := []string{}
offset := 0
for {
idx := strings.Index(lower[offset:], "tool_calls")
if idx < 0 {
break
}
idx += offset
start := strings.LastIndex(text[:idx], "{")
for start >= 0 {
candidate, end, ok := extractJSONObject(text, start)
if ok {
// Move forward to avoid repeatedly matching the same object.
offset = end
out = append(out, strings.TrimSpace(candidate))
break
}
start = strings.LastIndex(text[:start], "{")
}
if start < 0 {
offset = idx + len("tool_calls")
}
}
return out
}
func extractJSONObject(text string, start int) (string, int, bool) {
if start < 0 || start >= len(text) || text[start] != '{' {
return "", 0, false
}
depth := 0
quote := byte(0)
escaped := false
for i := start; i < len(text); i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
quote = ch
continue
}
if ch == '{' {
depth++
continue
}
if ch == '}' {
depth--
if depth == 0 {
return text[start : i+1], i + 1, true
}
}
}
return "", 0, false
}
func looksLikeToolExampleContext(text string) bool {
t := strings.ToLower(strings.TrimSpace(text))
if t == "" {
return false
}
return strings.Contains(t, "```")
}
func stripFencedCodeBlocks(text string) string {
if strings.TrimSpace(text) == "" {
return ""
}
return fencedBlockPattern.ReplaceAllString(text, " ")
}
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
out := make([]map[string]any, 0, len(calls))
for _, c := range calls {
args, _ := json.Marshal(c.Input)
out = append(out, map[string]any{
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"type": "function",
"function": map[string]any{
"name": c.Name,
"arguments": string(args),
},
})
}
return out
}
func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any {
out := make([]map[string]any, 0, len(calls))
for i, c := range calls {
args, _ := json.Marshal(c.Input)
out = append(out, map[string]any{
"index": i,
"id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"type": "function",
"function": map[string]any{
"name": c.Name,
"arguments": string(args),
},
})
}
return out
}