diff --git a/TESTING.md b/TESTING.md index 5540592..ce349ec 100644 --- a/TESTING.md +++ b/TESTING.md @@ -24,7 +24,7 @@ go test ./... ``` ```bash -node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js +node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js ``` ### 端到端测试 | End-to-End Tests @@ -39,7 +39,7 @@ node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js - `go test ./... -count=1`(单元测试) - `node --check api/chat-stream.js`(语法检查) - `node --check api/helpers/stream-tool-sieve.js`(语法检查) - - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js`(Node 流式拦截单测) + - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js`(Node 流式拦截 + compat 单测) - `npm run build --prefix webui`(WebUI 构建检查) 2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程 diff --git a/api/chat-stream.js b/api/chat-stream.js index 1a8e896..c4a4cd1 100644 --- a/api/chat-stream.js +++ b/api/chat-stream.js @@ -10,31 +10,14 @@ const { parseToolCalls, formatOpenAIStreamToolCalls, } = require('./helpers/stream-tool-sieve'); +const { + BASE_HEADERS, + SKIP_PATTERNS, + SKIP_EXACT_PATHS, +} = require('./shared/deepseek-constants'); const DEEPSEEK_COMPLETION_URL = 'https://chat.deepseek.com/api/v0/chat/completion'; -const BASE_HEADERS = { - Host: 'chat.deepseek.com', - 'User-Agent': 'DeepSeek/1.6.11 Android/35', - Accept: 'application/json', - 'Content-Type': 'application/json', - 'x-client-platform': 'android', - 'x-client-version': '1.6.11', - 'x-client-locale': 'zh_CN', - 'accept-charset': 'UTF-8', -}; - -const SKIP_PATTERNS = [ - 'quasi_status', - 'elapsed_secs', - 'token_usage', - 'pending_fragment', - 'conversation_mode', - 'fragments/-1/status', - 'fragments/-2/status', - 'fragments/-3/status', -]; - module.exports = async function handler(req, res) { setCorsHeaders(res); if (req.method === 'OPTIONS') { @@ -725,7 +708,7 @@ function extractContentRecursive(items, defaultType) { } function shouldSkipPath(pathValue) { - if (pathValue === 'response/search_status') { + if (SKIP_EXACT_PATHS.has(pathValue)) { return true; } for (const p of SKIP_PATTERNS) { @@ -808,7 +791,16 @@ function estimateTokens(text) { if (!t) { return 0; } - const n = Math.floor(Array.from(t).length / 4); + let asciiChars = 0; + let nonASCIIChars = 0; + for (const ch of Array.from(t)) { + if (ch.charCodeAt(0) < 128) { + asciiChars += 1; + } else { + nonASCIIChars += 1; + } + } + const n = Math.floor(asciiChars / 4) + Math.floor((nonASCIIChars * 10 + 7) / 13); return n < 1 ? 1 : n; } @@ -972,4 +964,5 @@ module.exports.__test = { resolveToolcallPolicy, normalizePreparedToolNames, boolDefaultTrue, + estimateTokens, }; diff --git a/api/compat/js_compat_test.js b/api/compat/js_compat_test.js new file mode 100644 index 0000000..9b03b00 --- /dev/null +++ b/api/compat/js_compat_test.js @@ -0,0 +1,60 @@ +'use strict'; + +const test = require('node:test'); +const assert = require('node:assert/strict'); +const fs = require('node:fs'); +const path = require('node:path'); + +const chatStream = require('../chat-stream'); +const { parseToolCalls } = require('../helpers/stream-tool-sieve'); + +const { parseChunkForContent, estimateTokens } = chatStream.__test; + +const compatRoot = path.resolve(__dirname, '../../tests/compat'); + +function readJSON(filePath) { + return JSON.parse(fs.readFileSync(filePath, 'utf8')); +} + +test('js compat: sse fixtures', () => { + const fixtureDir = path.join(compatRoot, 'fixtures', 'sse_chunks'); + const expectedDir = path.join(compatRoot, 'expected'); + const files = fs.readdirSync(fixtureDir).filter((f) => f.endsWith('.json')).sort(); + assert.ok(files.length > 0); + + for (const file of files) { + const name = file.replace(/\.json$/i, ''); + const fixture = readJSON(path.join(fixtureDir, file)); + const expected = readJSON(path.join(expectedDir, `sse_${name}.json`)); + const got = parseChunkForContent(fixture.chunk, Boolean(fixture.thinking_enabled), fixture.current_type || 'text'); + assert.deepEqual(got.parts, expected.parts, `${name}: parts mismatch`); + assert.equal(got.finished, expected.finished, `${name}: finished mismatch`); + assert.equal(got.newType, expected.new_type, `${name}: newType mismatch`); + } +}); + +test('js compat: toolcall fixtures', () => { + const fixtureDir = path.join(compatRoot, 'fixtures', 'toolcalls'); + const expectedDir = path.join(compatRoot, 'expected'); + const files = fs.readdirSync(fixtureDir).filter((f) => f.endsWith('.json')).sort(); + assert.ok(files.length > 0); + + for (const file of files) { + const name = file.replace(/\.json$/i, ''); + const fixture = readJSON(path.join(fixtureDir, file)); + const expected = readJSON(path.join(expectedDir, `toolcalls_${name}.json`)); + const got = parseToolCalls(fixture.text, fixture.tool_names || []); + assert.deepEqual(got, expected.calls, `${name}: calls mismatch`); + } +}); + +test('js compat: token fixtures', () => { + const fixture = readJSON(path.join(compatRoot, 'fixtures', 'token_cases.json')); + const expected = readJSON(path.join(compatRoot, 'expected', 'token_cases.json')); + const expectedByName = new Map(expected.cases.map((c) => [c.name, c.tokens])); + for (const c of fixture.cases) { + assert.ok(expectedByName.has(c.name), `missing expected case: ${c.name}`); + const got = estimateTokens(c.text); + assert.equal(got, expectedByName.get(c.name), `${c.name}: tokens mismatch`); + } +}); diff --git a/api/shared/deepseek-constants.js b/api/shared/deepseek-constants.js new file mode 100644 index 0000000..1ec74f1 --- /dev/null +++ b/api/shared/deepseek-constants.js @@ -0,0 +1,66 @@ +'use strict'; + +const fs = require('fs'); +const path = require('path'); + +const DEFAULT_BASE_HEADERS = Object.freeze({ + Host: 'chat.deepseek.com', + 'User-Agent': 'DeepSeek/1.6.11 Android/35', + Accept: 'application/json', + 'Content-Type': 'application/json', + 'x-client-platform': 'android', + 'x-client-version': '1.6.11', + 'x-client-locale': 'zh_CN', + 'accept-charset': 'UTF-8', +}); + +const DEFAULT_SKIP_PATTERNS = Object.freeze([ + 'quasi_status', + 'elapsed_secs', + 'token_usage', + 'pending_fragment', + 'conversation_mode', + 'fragments/-1/status', + 'fragments/-2/status', + 'fragments/-3/status', +]); + +const DEFAULT_SKIP_EXACT_PATHS = Object.freeze([ + 'response/search_status', +]); + +function loadSharedConstants() { + const sharedPath = path.resolve(__dirname, '../../internal/deepseek/constants_shared.json'); + try { + const raw = fs.readFileSync(sharedPath, 'utf8'); + const parsed = JSON.parse(raw); + const baseHeaders = parsed && typeof parsed.base_headers === 'object' && !Array.isArray(parsed.base_headers) + ? { ...DEFAULT_BASE_HEADERS, ...parsed.base_headers } + : { ...DEFAULT_BASE_HEADERS }; + const skipPatterns = Array.isArray(parsed && parsed.skip_contains_patterns) + ? parsed.skip_contains_patterns.filter((v) => typeof v === 'string' && v !== '') + : [...DEFAULT_SKIP_PATTERNS]; + const skipExactPaths = Array.isArray(parsed && parsed.skip_exact_paths) + ? parsed.skip_exact_paths.filter((v) => typeof v === 'string' && v !== '') + : [...DEFAULT_SKIP_EXACT_PATHS]; + return { + baseHeaders, + skipPatterns, + skipExactPaths, + }; + } catch (_err) { + return { + baseHeaders: { ...DEFAULT_BASE_HEADERS }, + skipPatterns: [...DEFAULT_SKIP_PATTERNS], + skipExactPaths: [...DEFAULT_SKIP_EXACT_PATHS], + }; + } +} + +const shared = loadSharedConstants(); + +module.exports = { + BASE_HEADERS: Object.freeze(shared.baseHeaders), + SKIP_PATTERNS: Object.freeze(shared.skipPatterns), + SKIP_EXACT_PATHS: new Set(shared.skipExactPaths), +}; diff --git a/internal/account/pool.go b/internal/account/pool.go index 665bcee..12d8874 100644 --- a/internal/account/pool.go +++ b/internal/account/pool.go @@ -20,13 +20,18 @@ type Pool struct { maxInflightPerAccount int recommendedConcurrency int maxQueueSize int + globalMaxInflight int } func NewPool(store *config.Store) *Pool { + maxPer := 2 + if store != nil { + maxPer = store.RuntimeAccountMaxInflight() + } p := &Pool{ store: store, inUse: map[string]int{}, - maxInflightPerAccount: maxInflightFromEnv(), + maxInflightPerAccount: maxPer, } p.Reset() return p @@ -49,8 +54,18 @@ func (p *Pool) Reset() { ids = append(ids, id) } } + if p.store != nil { + p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight() + } else { + p.maxInflightPerAccount = maxInflightFromEnv() + } recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount) queueLimit := maxQueueFromEnv(recommended) + globalLimit := recommended + if p.store != nil { + queueLimit = p.store.RuntimeAccountMaxQueue(recommended) + globalLimit = p.store.RuntimeGlobalMaxInflight(recommended) + } p.mu.Lock() defer p.mu.Unlock() p.drainWaitersLocked() @@ -58,10 +73,12 @@ func (p *Pool) Reset() { p.inUse = map[string]int{} p.recommendedConcurrency = recommended p.maxQueueSize = queueLimit + p.globalMaxInflight = globalLimit config.Logger.Info( "[init_account_queue] initialized", "total", len(ids), "max_inflight_per_account", p.maxInflightPerAccount, + "global_max_inflight", p.globalMaxInflight, "recommended_concurrency", p.recommendedConcurrency, "max_queue_size", p.maxQueueSize, ) @@ -109,7 +126,7 @@ func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[strin func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) { if target != "" { - if exclude[target] || p.inUse[target] >= p.maxInflightPerAccount { + if exclude[target] || !p.canAcquireIDLocked(target) { return config.Account{}, false } acc, ok := p.store.FindAccount(target) @@ -133,7 +150,7 @@ func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Acc func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) { for i := 0; i < len(p.queue); i++ { id := p.queue[i] - if exclude[id] || p.inUse[id] >= p.maxInflightPerAccount { + if exclude[id] || !p.canAcquireIDLocked(id) { continue } acc, ok := p.store.FindAccount(id) @@ -205,12 +222,35 @@ func (p *Pool) Status() map[string]any { "available_accounts": available, "in_use_accounts": inUseAccounts, "max_inflight_per_account": p.maxInflightPerAccount, + "global_max_inflight": p.globalMaxInflight, "recommended_concurrency": p.recommendedConcurrency, "waiting": len(p.waiters), "max_queue_size": p.maxQueueSize, } } +func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) { + if maxInflightPerAccount <= 0 { + maxInflightPerAccount = 1 + } + if maxQueueSize < 0 { + maxQueueSize = 0 + } + if globalMaxInflight <= 0 { + globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts()) + if globalMaxInflight <= 0 { + globalMaxInflight = maxInflightPerAccount + } + } + p.mu.Lock() + defer p.mu.Unlock() + p.maxInflightPerAccount = maxInflightPerAccount + p.maxQueueSize = maxQueueSize + p.globalMaxInflight = globalMaxInflight + p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount) + p.notifyWaiterLocked() +} + func maxInflightFromEnv() int { for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} { raw := strings.TrimSpace(os.Getenv(key)) @@ -300,3 +340,24 @@ func maxQueueFromEnv(defaultSize int) int { } return defaultSize } + +func (p *Pool) canAcquireIDLocked(accountID string) bool { + if accountID == "" { + return false + } + if p.inUse[accountID] >= p.maxInflightPerAccount { + return false + } + if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight { + return false + } + return true +} + +func (p *Pool) currentInUseLocked() int { + total := 0 + for _, n := range p.inUse { + total += n + } + return total +} diff --git a/internal/adapter/claude/convert.go b/internal/adapter/claude/convert.go new file mode 100644 index 0000000..dbb5e1a --- /dev/null +++ b/internal/adapter/claude/convert.go @@ -0,0 +1,11 @@ +package claude + +import ( + "ds2api/internal/claudeconv" +) + +const defaultClaudeModel = "claude-sonnet-4-5" + +func convertClaudeToDeepSeek(claudeReq map[string]any, store ConfigReader) map[string]any { + return claudeconv.ConvertClaudeToDeepSeek(claudeReq, store, defaultClaudeModel) +} diff --git a/internal/adapter/claude/deps.go b/internal/adapter/claude/deps.go new file mode 100644 index 0000000..73203b2 --- /dev/null +++ b/internal/adapter/claude/deps.go @@ -0,0 +1,29 @@ +package claude + +import ( + "context" + "net/http" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" +) + +type AuthResolver interface { + Determine(req *http.Request) (*auth.RequestAuth, error) + Release(a *auth.RequestAuth) +} + +type DeepSeekCaller interface { + CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +type ConfigReader interface { + ClaudeMapping() map[string]string +} + +var _ AuthResolver = (*auth.Resolver)(nil) +var _ DeepSeekCaller = (*deepseek.Client)(nil) +var _ ConfigReader = (*config.Store)(nil) diff --git a/internal/adapter/claude/deps_injection_test.go b/internal/adapter/claude/deps_injection_test.go new file mode 100644 index 0000000..39dfc2f --- /dev/null +++ b/internal/adapter/claude/deps_injection_test.go @@ -0,0 +1,33 @@ +package claude + +import "testing" + +type mockClaudeConfig struct { + m map[string]string +} + +func (m mockClaudeConfig) ClaudeMapping() map[string]string { return m.m } + +func TestNormalizeClaudeRequestUsesConfigInterfaceMapping(t *testing.T) { + req := map[string]any{ + "model": "claude-opus-4-6", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + out, err := normalizeClaudeRequest(mockClaudeConfig{ + m: map[string]string{ + "fast": "deepseek-chat", + "slow": "deepseek-reasoner-search", + }, + }, req) + if err != nil { + t.Fatalf("normalizeClaudeRequest error: %v", err) + } + if out.Standard.ResolvedModel != "deepseek-reasoner-search" { + t.Fatalf("resolved model mismatch: got=%q", out.Standard.ResolvedModel) + } + if !out.Standard.Thinking || !out.Standard.Search { + t.Fatalf("unexpected flags: thinking=%v search=%v", out.Standard.Thinking, out.Standard.Search) + } +} diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index bac315f..282b569 100644 --- a/internal/adapter/claude/handler.go +++ b/internal/adapter/claude/handler.go @@ -13,7 +13,9 @@ import ( "ds2api/internal/auth" "ds2api/internal/config" "ds2api/internal/deepseek" + claudefmt "ds2api/internal/format/claude" "ds2api/internal/sse" + streamengine "ds2api/internal/stream" "ds2api/internal/util" ) @@ -21,9 +23,9 @@ import ( var writeJSON = util.WriteJSON type Handler struct { - Store *config.Store - Auth *auth.Resolver - DS *deepseek.Client + Store ConfigReader + Auth AuthResolver + DS DeepSeekCaller } var ( @@ -98,7 +100,7 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { return } result := sse.CollectStream(resp, stdReq.Thinking, true) - respBody := util.BuildClaudeMessageResponse( + respBody := claudefmt.BuildMessageResponse( fmt.Sprintf("msg_%d", time.Now().UnixNano()), stdReq.ResponseModel, norm.NormalizedMessages, @@ -169,279 +171,38 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ if !canFlush { config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered") } - send := func(event string, v any) { - b, _ := json.Marshal(v) - _, _ = w.Write([]byte("event: ")) - _, _ = w.Write([]byte(event)) - _, _ = w.Write([]byte("\n")) - _, _ = w.Write([]byte("data: ")) - _, _ = w.Write(b) - _, _ = w.Write([]byte("\n\n")) - if canFlush { - _ = rc.Flush() - } - } - sendError := func(message string) { - msg := strings.TrimSpace(message) - if msg == "" { - msg = "upstream stream error" - } - send("error", map[string]any{ - "type": "error", - "error": map[string]any{ - "type": "api_error", - "message": msg, - "code": "internal_error", - "param": nil, - }, - }) - } - messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano()) - inputTokens := util.EstimateTokens(fmt.Sprintf("%v", messages)) - send("message_start", map[string]any{ - "type": "message_start", - "message": map[string]any{ - "id": messageID, - "type": "message", - "role": "assistant", - "model": model, - "content": []any{}, - "stop_reason": nil, - "stop_sequence": nil, - "usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0}, - }, - }) + streamRuntime := newClaudeStreamRuntime( + w, + rc, + canFlush, + model, + messages, + thinkingEnabled, + searchEnabled, + toolNames, + ) + streamRuntime.sendMessageStart() initialType := "text" if thinkingEnabled { initialType = "thinking" } - parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) - bufferToolContent := len(toolNames) > 0 - hasContent := false - lastContent := time.Now() - keepaliveCount := 0 - - thinking := strings.Builder{} - text := strings.Builder{} - - nextBlockIndex := 0 - thinkingBlockOpen := false - thinkingBlockIndex := -1 - textBlockOpen := false - textBlockIndex := -1 - ended := false - - closeThinkingBlock := func() { - if !thinkingBlockOpen { - return - } - send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": thinkingBlockIndex, - }) - thinkingBlockOpen = false - thinkingBlockIndex = -1 - } - closeTextBlock := func() { - if !textBlockOpen { - return - } - send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": textBlockIndex, - }) - textBlockOpen = false - textBlockIndex = -1 - } - - finalize := func(stopReason string) { - if ended { - return - } - ended = true - - closeThinkingBlock() - closeTextBlock() - - finalThinking := thinking.String() - finalText := text.String() - - if bufferToolContent { - detected := util.ParseToolCalls(finalText, toolNames) - if len(detected) > 0 { - stopReason = "tool_use" - for i, tc := range detected { - idx := nextBlockIndex + i - send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": idx, - "content_block": map[string]any{ - "type": "tool_use", - "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx), - "name": tc.Name, - "input": tc.Input, - }, - }) - send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": idx, - }) - } - nextBlockIndex += len(detected) - } else if finalText != "" { - idx := nextBlockIndex - nextBlockIndex++ - send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": idx, - "content_block": map[string]any{ - "type": "text", - "text": "", - }, - }) - send("content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": idx, - "delta": map[string]any{ - "type": "text_delta", - "text": finalText, - }, - }) - send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": idx, - }) - } - } - - outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText) - send("message_delta", map[string]any{ - "type": "message_delta", - "delta": map[string]any{ - "stop_reason": stopReason, - "stop_sequence": nil, - }, - "usage": map[string]any{ - "output_tokens": outputTokens, - }, - }) - send("message_stop", map[string]any{"type": "message_stop"}) - } - - pingTicker := time.NewTicker(claudeStreamPingInterval) - defer pingTicker.Stop() - - for { - select { - case <-r.Context().Done(): - return - case <-pingTicker.C: - if !hasContent { - keepaliveCount++ - if keepaliveCount >= claudeStreamMaxKeepaliveCnt { - finalize("end_turn") - return - } - } - if hasContent && time.Since(lastContent) > claudeStreamIdleTimeout { - finalize("end_turn") - return - } - send("ping", map[string]any{"type": "ping"}) - case parsed, ok := <-parsedLines: - if !ok { - if err := <-done; err != nil { - sendError(err.Error()) - return - } - finalize("end_turn") - return - } - if !parsed.Parsed { - continue - } - if parsed.ErrorMessage != "" { - sendError(parsed.ErrorMessage) - return - } - if parsed.Stop { - finalize("end_turn") - return - } - - for _, p := range parsed.Parts { - if p.Text == "" { - continue - } - if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) { - continue - } - - hasContent = true - lastContent = time.Now() - keepaliveCount = 0 - - if p.Type == "thinking" { - if !thinkingEnabled { - continue - } - thinking.WriteString(p.Text) - closeTextBlock() - if !thinkingBlockOpen { - thinkingBlockIndex = nextBlockIndex - nextBlockIndex++ - send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": thinkingBlockIndex, - "content_block": map[string]any{ - "type": "thinking", - "thinking": "", - }, - }) - thinkingBlockOpen = true - } - send("content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": thinkingBlockIndex, - "delta": map[string]any{ - "type": "thinking_delta", - "thinking": p.Text, - }, - }) - continue - } - - text.WriteString(p.Text) - if bufferToolContent { - continue - } - closeThinkingBlock() - if !textBlockOpen { - textBlockIndex = nextBlockIndex - nextBlockIndex++ - send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": textBlockIndex, - "content_block": map[string]any{ - "type": "text", - "text": "", - }, - }) - textBlockOpen = true - } - send("content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": textBlockIndex, - "delta": map[string]any{ - "type": "text_delta", - "text": p.Text, - }, - }) - } - } - } + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: claudeStreamPingInterval, + IdleTimeout: claudeStreamIdleTimeout, + MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt, + }, streamengine.ConsumeHooks{ + OnKeepAlive: func() { + streamRuntime.sendPing() + }, + OnParsed: streamRuntime.onParsed, + OnFinalize: streamRuntime.onFinalize, + }) } func writeClaudeError(w http.ResponseWriter, status int, message string) { diff --git a/internal/adapter/claude/standard_request.go b/internal/adapter/claude/standard_request.go index de97c6a..cdbb675 100644 --- a/internal/adapter/claude/standard_request.go +++ b/internal/adapter/claude/standard_request.go @@ -5,6 +5,7 @@ import ( "strings" "ds2api/internal/config" + "ds2api/internal/deepseek" "ds2api/internal/util" ) @@ -13,7 +14,7 @@ type claudeNormalizedRequest struct { NormalizedMessages []any } -func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNormalizedRequest, error) { +func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNormalizedRequest, error) { model, _ := req["model"].(string) messagesRaw, _ := req["messages"].([]any) if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { @@ -30,14 +31,14 @@ func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNorm payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...) } - dsPayload := util.ConvertClaudeToDeepSeek(payload, store) + dsPayload := convertClaudeToDeepSeek(payload, store) dsModel, _ := dsPayload["model"].(string) thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel) if !ok { thinkingEnabled = false searchEnabled = false } - finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"])) + finalPrompt := deepseek.MessagesPrepare(toMessageMaps(dsPayload["messages"])) toolNames := extractClaudeToolNames(toolsRequested) return claudeNormalizedRequest{ diff --git a/internal/adapter/claude/stream_runtime.go b/internal/adapter/claude/stream_runtime.go new file mode 100644 index 0000000..01e07a9 --- /dev/null +++ b/internal/adapter/claude/stream_runtime.go @@ -0,0 +1,308 @@ +package claude + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" + "ds2api/internal/util" +) + +type claudeStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + model string + toolNames []string + messages []any + + thinkingEnabled bool + searchEnabled bool + bufferToolContent bool + + messageID string + thinking strings.Builder + text strings.Builder + + nextBlockIndex int + thinkingBlockOpen bool + thinkingBlockIndex int + textBlockOpen bool + textBlockIndex int + ended bool + upstreamErr string +} + +func newClaudeStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + model string, + messages []any, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, +) *claudeStreamRuntime { + return &claudeStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + model: model, + messages: messages, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + bufferToolContent: len(toolNames) > 0, + toolNames: toolNames, + messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), + thinkingBlockIndex: -1, + textBlockIndex: -1, + } +} + +func (s *claudeStreamRuntime) send(event string, v any) { + b, _ := json.Marshal(v) + _, _ = s.w.Write([]byte("event: ")) + _, _ = s.w.Write([]byte(event)) + _, _ = s.w.Write([]byte("\n")) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *claudeStreamRuntime) sendError(message string) { + msg := strings.TrimSpace(message) + if msg == "" { + msg = "upstream stream error" + } + s.send("error", map[string]any{ + "type": "error", + "error": map[string]any{ + "type": "api_error", + "message": msg, + "code": "internal_error", + "param": nil, + }, + }) +} + +func (s *claudeStreamRuntime) sendPing() { + s.send("ping", map[string]any{"type": "ping"}) +} + +func (s *claudeStreamRuntime) sendMessageStart() { + inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages)) + s.send("message_start", map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": s.messageID, + "type": "message", + "role": "assistant", + "model": s.model, + "content": []any{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0}, + }, + }) +} + +func (s *claudeStreamRuntime) closeThinkingBlock() { + if !s.thinkingBlockOpen { + return + } + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": s.thinkingBlockIndex, + }) + s.thinkingBlockOpen = false + s.thinkingBlockIndex = -1 +} + +func (s *claudeStreamRuntime) closeTextBlock() { + if !s.textBlockOpen { + return + } + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": s.textBlockIndex, + }) + s.textBlockOpen = false + s.textBlockIndex = -1 +} + +func (s *claudeStreamRuntime) finalize(stopReason string) { + if s.ended { + return + } + s.ended = true + + s.closeThinkingBlock() + s.closeTextBlock() + + finalThinking := s.thinking.String() + finalText := s.text.String() + + if s.bufferToolContent { + detected := util.ParseToolCalls(finalText, s.toolNames) + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + idx := s.nextBlockIndex + i + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]any{ + "type": "tool_use", + "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx), + "name": tc.Name, + "input": tc.Input, + }, + }) + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": idx, + }) + } + s.nextBlockIndex += len(detected) + } else if finalText != "" { + idx := s.nextBlockIndex + s.nextBlockIndex++ + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": idx, + "delta": map[string]any{ + "type": "text_delta", + "text": finalText, + }, + }) + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": idx, + }) + } + } + + outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText) + s.send("message_delta", map[string]any{ + "type": "message_delta", + "delta": map[string]any{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": map[string]any{ + "output_tokens": outputTokens, + }, + }) + s.send("message_stop", map[string]any{"type": "message_stop"}) +} + +func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ErrorMessage != "" { + s.upstreamErr = parsed.ErrorMessage + return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")} + } + if parsed.Stop { + return streamengine.ParsedDecision{Stop: true} + } + + contentSeen := false + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + contentSeen = true + + if p.Type == "thinking" { + if !s.thinkingEnabled { + continue + } + s.thinking.WriteString(p.Text) + s.closeTextBlock() + if !s.thinkingBlockOpen { + s.thinkingBlockIndex = s.nextBlockIndex + s.nextBlockIndex++ + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": s.thinkingBlockIndex, + "content_block": map[string]any{ + "type": "thinking", + "thinking": "", + }, + }) + s.thinkingBlockOpen = true + } + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": s.thinkingBlockIndex, + "delta": map[string]any{ + "type": "thinking_delta", + "thinking": p.Text, + }, + }) + continue + } + + s.text.WriteString(p.Text) + if s.bufferToolContent { + continue + } + s.closeThinkingBlock() + if !s.textBlockOpen { + s.textBlockIndex = s.nextBlockIndex + s.nextBlockIndex++ + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": s.textBlockIndex, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + s.textBlockOpen = true + } + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": s.textBlockIndex, + "delta": map[string]any{ + "type": "text_delta", + "text": p.Text, + }, + }) + } + + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} + +func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) { + if string(reason) == "upstream_error" { + s.sendError(s.upstreamErr) + return + } + if scannerErr != nil { + s.sendError(scannerErr.Error()) + return + } + s.finalize("end_turn") +} diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go new file mode 100644 index 0000000..0e64bc5 --- /dev/null +++ b/internal/adapter/openai/chat_stream_runtime.go @@ -0,0 +1,237 @@ +package openai + +import ( + "encoding/json" + "net/http" + "strings" + + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" + "ds2api/internal/util" +) + +type chatStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + completionID string + created int64 + model string + finalPrompt string + toolNames []string + + thinkingEnabled bool + searchEnabled bool + + firstChunkSent bool + bufferToolContent bool + emitEarlyToolDeltas bool + toolCallsEmitted bool + + toolSieve toolStreamSieveState + streamToolCallIDs map[int]string + thinking strings.Builder + text strings.Builder +} + +func newChatStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + completionID string, + created int64, + model string, + finalPrompt string, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, + bufferToolContent bool, + emitEarlyToolDeltas bool, +) *chatStreamRuntime { + return &chatStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + completionID: completionID, + created: created, + model: model, + finalPrompt: finalPrompt, + toolNames: toolNames, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + bufferToolContent: bufferToolContent, + emitEarlyToolDeltas: emitEarlyToolDeltas, + streamToolCallIDs: map[int]string{}, + } +} + +func (s *chatStreamRuntime) sendKeepAlive() { + if !s.canFlush { + return + } + _, _ = s.w.Write([]byte(": keep-alive\n\n")) + _ = s.rc.Flush() +} + +func (s *chatStreamRuntime) sendChunk(v any) { + b, _ := json.Marshal(v) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *chatStreamRuntime) sendDone() { + _, _ = s.w.Write([]byte("data: [DONE]\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *chatStreamRuntime) finalize(finishReason string) { + finalThinking := s.thinking.String() + finalText := s.text.String() + detected := util.ParseToolCalls(finalText, s.toolNames) + if len(detected) > 0 && !s.toolCallsEmitted { + finishReason = "tool_calls" + delta := map[string]any{ + "tool_calls": util.FormatOpenAIStreamToolCalls(detected), + } + if !s.firstChunkSent { + delta["role"] = "assistant" + s.firstChunkSent = true + } + s.sendChunk(openaifmt.BuildChatStreamChunk( + s.completionID, + s.created, + s.model, + []map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)}, + nil, + )) + } else if s.bufferToolContent { + for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) { + if evt.Content == "" { + continue + } + delta := map[string]any{ + "content": evt.Content, + } + if !s.firstChunkSent { + delta["role"] = "assistant" + s.firstChunkSent = true + } + s.sendChunk(openaifmt.BuildChatStreamChunk( + s.completionID, + s.created, + s.model, + []map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)}, + nil, + )) + } + } + + if len(detected) > 0 || s.toolCallsEmitted { + finishReason = "tool_calls" + } + s.sendChunk(openaifmt.BuildChatStreamChunk( + s.completionID, + s.created, + s.model, + []map[string]any{openaifmt.BuildChatStreamFinishChoice(0, finishReason)}, + openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText), + )) + s.sendDone() +} + +func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ContentFilter || parsed.ErrorMessage != "" { + return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("content_filter")} + } + if parsed.Stop { + return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReasonHandlerRequested} + } + + newChoices := make([]map[string]any, 0, len(parsed.Parts)) + contentSeen := false + for _, p := range parsed.Parts { + if s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + if p.Text == "" { + continue + } + contentSeen = true + delta := map[string]any{} + if !s.firstChunkSent { + delta["role"] = "assistant" + s.firstChunkSent = true + } + if p.Type == "thinking" { + if s.thinkingEnabled { + s.thinking.WriteString(p.Text) + delta["reasoning_content"] = p.Text + } + } else { + s.text.WriteString(p.Text) + if !s.bufferToolContent { + delta["content"] = p.Text + } else { + events := processToolSieveChunk(&s.toolSieve, p.Text, s.toolNames) + for _, evt := range events { + if len(evt.ToolCallDeltas) > 0 { + if !s.emitEarlyToolDeltas { + continue + } + s.toolCallsEmitted = true + tcDelta := map[string]any{ + "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs), + } + if !s.firstChunkSent { + tcDelta["role"] = "assistant" + s.firstChunkSent = true + } + newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)) + continue + } + if len(evt.ToolCalls) > 0 { + s.toolCallsEmitted = true + tcDelta := map[string]any{ + "tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls), + } + if !s.firstChunkSent { + tcDelta["role"] = "assistant" + s.firstChunkSent = true + } + newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)) + continue + } + if evt.Content != "" { + contentDelta := map[string]any{ + "content": evt.Content, + } + if !s.firstChunkSent { + contentDelta["role"] = "assistant" + s.firstChunkSent = true + } + newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, contentDelta)) + } + } + } + } + if len(delta) > 0 { + newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, delta)) + } + } + + if len(newChoices) > 0 { + s.sendChunk(openaifmt.BuildChatStreamChunk(s.completionID, s.created, s.model, newChoices, nil)) + } + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} diff --git a/internal/adapter/openai/deps.go b/internal/adapter/openai/deps.go new file mode 100644 index 0000000..6688756 --- /dev/null +++ b/internal/adapter/openai/deps.go @@ -0,0 +1,35 @@ +package openai + +import ( + "context" + "net/http" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" +) + +type AuthResolver interface { + Determine(req *http.Request) (*auth.RequestAuth, error) + DetermineCaller(req *http.Request) (*auth.RequestAuth, error) + Release(a *auth.RequestAuth) +} + +type DeepSeekCaller interface { + CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +type ConfigReader interface { + ModelAliases() map[string]string + CompatWideInputStrictOutput() bool + ToolcallMode() string + ToolcallEarlyEmitConfidence() string + ResponsesStoreTTLSeconds() int + EmbeddingsProvider() string +} + +var _ AuthResolver = (*auth.Resolver)(nil) +var _ DeepSeekCaller = (*deepseek.Client)(nil) +var _ ConfigReader = (*config.Store)(nil) diff --git a/internal/adapter/openai/deps_injection_test.go b/internal/adapter/openai/deps_injection_test.go new file mode 100644 index 0000000..baa0c11 --- /dev/null +++ b/internal/adapter/openai/deps_injection_test.go @@ -0,0 +1,70 @@ +package openai + +import "testing" + +type mockOpenAIConfig struct { + aliases map[string]string + wideInput bool + toolMode string + earlyEmit string + responsesTTL int + embedProv string +} + +func (m mockOpenAIConfig) ModelAliases() map[string]string { return m.aliases } +func (m mockOpenAIConfig) CompatWideInputStrictOutput() bool { + return m.wideInput +} +func (m mockOpenAIConfig) ToolcallMode() string { return m.toolMode } +func (m mockOpenAIConfig) ToolcallEarlyEmitConfidence() string { return m.earlyEmit } +func (m mockOpenAIConfig) ResponsesStoreTTLSeconds() int { return m.responsesTTL } +func (m mockOpenAIConfig) EmbeddingsProvider() string { return m.embedProv } + +func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) { + cfg := mockOpenAIConfig{ + aliases: map[string]string{ + "my-model": "deepseek-chat-search", + }, + wideInput: true, + } + req := map[string]any{ + "model": "my-model", + "messages": []any{map[string]any{"role": "user", "content": "hello"}}, + } + out, err := normalizeOpenAIChatRequest(cfg, req) + if err != nil { + t.Fatalf("normalizeOpenAIChatRequest error: %v", err) + } + if out.ResolvedModel != "deepseek-chat-search" { + t.Fatalf("resolved model mismatch: got=%q", out.ResolvedModel) + } + if !out.Search || out.Thinking { + t.Fatalf("unexpected model flags: thinking=%v search=%v", out.Thinking, out.Search) + } +} + +func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.T) { + req := map[string]any{ + "model": "deepseek-chat", + "input": "hi", + } + + _, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{ + aliases: map[string]string{}, + wideInput: false, + }, req) + if err == nil { + t.Fatal("expected error when wide input is disabled and only input is provided") + } + + out, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{ + aliases: map[string]string{}, + wideInput: true, + }, req) + if err != nil { + t.Fatalf("unexpected error when wide input is enabled: %v", err) + } + if out.Surface != "openai_responses" { + t.Fatalf("unexpected surface: %q", out.Surface) + } +} diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index 5ef6e7b..28a451c 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -16,7 +16,9 @@ import ( "ds2api/internal/auth" "ds2api/internal/config" "ds2api/internal/deepseek" + openaifmt "ds2api/internal/format/openai" "ds2api/internal/sse" + streamengine "ds2api/internal/stream" "ds2api/internal/util" ) @@ -25,9 +27,9 @@ import ( var writeJSON = util.WriteJSON type Handler struct { - Store *config.Store - Auth *auth.Resolver - DS *deepseek.Client + Store ConfigReader + Auth AuthResolver + DS DeepSeekCaller leaseMu sync.Mutex streamLeases map[string]streamLease @@ -136,7 +138,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re finalThinking := result.Thinking finalText := result.Text - respBody := util.BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames) + respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames) writeJSON(w, http.StatusOK, respBody) } @@ -158,214 +160,49 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt } created := time.Now().Unix() - firstChunkSent := false bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() - var toolSieve toolStreamSieveState - toolCallsEmitted := false - streamToolCallIDs := map[int]string{} initialType := "text" if thinkingEnabled { initialType = "thinking" } - parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) - thinking := strings.Builder{} - text := strings.Builder{} - lastContent := time.Now() - hasContent := false - keepaliveTicker := time.NewTicker(time.Duration(deepseek.KeepAliveTimeout) * time.Second) - defer keepaliveTicker.Stop() - keepaliveCountWithoutContent := 0 - sendChunk := func(v any) { - b, _ := json.Marshal(v) - _, _ = w.Write([]byte("data: ")) - _, _ = w.Write(b) - _, _ = w.Write([]byte("\n\n")) - if canFlush { - _ = rc.Flush() - } - } - sendDone := func() { - _, _ = w.Write([]byte("data: [DONE]\n\n")) - if canFlush { - _ = rc.Flush() - } - } + streamRuntime := newChatStreamRuntime( + w, + rc, + canFlush, + completionID, + created, + model, + finalPrompt, + thinkingEnabled, + searchEnabled, + toolNames, + bufferToolContent, + emitEarlyToolDeltas, + ) - finalize := func(finishReason string) { - finalThinking := thinking.String() - finalText := text.String() - detected := util.ParseToolCalls(finalText, toolNames) - if len(detected) > 0 && !toolCallsEmitted { - finishReason = "tool_calls" - delta := map[string]any{ - "tool_calls": util.FormatOpenAIStreamToolCalls(detected), - } - if !firstChunkSent { - delta["role"] = "assistant" - firstChunkSent = true - } - sendChunk(util.BuildOpenAIChatStreamChunk( - completionID, - created, - model, - []map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)}, - nil, - )) - } else if bufferToolContent { - for _, evt := range flushToolSieve(&toolSieve, toolNames) { - if evt.Content == "" { - continue - } - delta := map[string]any{ - "content": evt.Content, - } - if !firstChunkSent { - delta["role"] = "assistant" - firstChunkSent = true - } - sendChunk(util.BuildOpenAIChatStreamChunk( - completionID, - created, - model, - []map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)}, - nil, - )) - } - } - if len(detected) > 0 || toolCallsEmitted { - finishReason = "tool_calls" - } - sendChunk(util.BuildOpenAIChatStreamChunk( - completionID, - created, - model, - []map[string]any{util.BuildOpenAIChatStreamFinishChoice(0, finishReason)}, - util.BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText), - )) - sendDone() - } - - for { - select { - case <-r.Context().Done(): - return - case <-keepaliveTicker.C: - if !hasContent { - keepaliveCountWithoutContent++ - if keepaliveCountWithoutContent >= deepseek.MaxKeepaliveCount { - finalize("stop") - return - } - } - if hasContent && time.Since(lastContent) > time.Duration(deepseek.StreamIdleTimeout)*time.Second { - finalize("stop") + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second, + IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, + MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, + }, streamengine.ConsumeHooks{ + OnKeepAlive: func() { + streamRuntime.sendKeepAlive() + }, + OnParsed: streamRuntime.onParsed, + OnFinalize: func(reason streamengine.StopReason, _ error) { + if string(reason) == "content_filter" { + streamRuntime.finalize("content_filter") return } - if canFlush { - _, _ = w.Write([]byte(": keep-alive\n\n")) - _ = rc.Flush() - } - case parsed, ok := <-parsedLines: - if !ok { - // Ensure scanner completion is observed only after all queued - // SSE lines are drained, avoiding early finalize races. - _ = <-done - finalize("stop") - return - } - if !parsed.Parsed { - continue - } - if parsed.ContentFilter || parsed.ErrorMessage != "" { - finalize("content_filter") - return - } - if parsed.Stop { - finalize("stop") - return - } - newChoices := make([]map[string]any, 0, len(parsed.Parts)) - for _, p := range parsed.Parts { - if searchEnabled && sse.IsCitation(p.Text) { - continue - } - if p.Text == "" { - continue - } - hasContent = true - lastContent = time.Now() - keepaliveCountWithoutContent = 0 - delta := map[string]any{} - if !firstChunkSent { - delta["role"] = "assistant" - firstChunkSent = true - } - if p.Type == "thinking" { - if thinkingEnabled { - thinking.WriteString(p.Text) - delta["reasoning_content"] = p.Text - } - } else { - text.WriteString(p.Text) - if !bufferToolContent { - delta["content"] = p.Text - } else { - events := processToolSieveChunk(&toolSieve, p.Text, toolNames) - if len(events) == 0 { - // Keep thinking delta only frame. - } - for _, evt := range events { - if len(evt.ToolCallDeltas) > 0 { - if !emitEarlyToolDeltas { - continue - } - toolCallsEmitted = true - tcDelta := map[string]any{ - "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs), - } - if !firstChunkSent { - tcDelta["role"] = "assistant" - firstChunkSent = true - } - newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta)) - continue - } - if len(evt.ToolCalls) > 0 { - toolCallsEmitted = true - tcDelta := map[string]any{ - "tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls), - } - if !firstChunkSent { - tcDelta["role"] = "assistant" - firstChunkSent = true - } - newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta)) - continue - } - if evt.Content != "" { - contentDelta := map[string]any{ - "content": evt.Content, - } - if !firstChunkSent { - contentDelta["role"] = "assistant" - firstChunkSent = true - } - newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, contentDelta)) - } - } - } - } - if len(delta) > 0 { - newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, delta)) - } - } - if len(newChoices) > 0 { - sendChunk(util.BuildOpenAIChatStreamChunk(completionID, created, model, newChoices, nil)) - } - } - } + streamRuntime.finalize("stop") + }, + }) } func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) { diff --git a/internal/adapter/openai/prompt_build.go b/internal/adapter/openai/prompt_build.go index a7bbc92..f83963f 100644 --- a/internal/adapter/openai/prompt_build.go +++ b/internal/adapter/openai/prompt_build.go @@ -1,6 +1,8 @@ package openai -import "ds2api/internal/util" +import ( + "ds2api/internal/deepseek" +) func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) { messages := normalizeOpenAIMessagesForPrompt(messagesRaw) @@ -8,5 +10,5 @@ func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 { messages, toolNames = injectToolPrompt(messages, tools) } - return util.MessagesPrepare(messages), toolNames + return deepseek.MessagesPrepare(messages), toolNames } diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index e04fb5f..e767b2b 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -6,13 +6,16 @@ import ( "io" "net/http" "strings" + "time" "github.com/go-chi/chi/v5" "github.com/google/uuid" "ds2api/internal/auth" + "ds2api/internal/deepseek" + openaifmt "ds2api/internal/format/openai" "ds2api/internal/sse" - "ds2api/internal/util" + streamengine "ds2api/internal/stream" ) func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) { @@ -108,7 +111,7 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res return } result := sse.CollectStream(resp, thinkingEnabled, true) - responseObj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames) + responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames) h.getResponseStore().put(owner, responseID, responseObj) writeJSON(w, http.StatusOK, responseObj) } @@ -127,114 +130,45 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, rc := http.NewResponseController(w) canFlush := rc.Flush() == nil - sendEvent := func(event string, payload map[string]any) { - b, _ := json.Marshal(payload) - _, _ = w.Write([]byte("event: " + event + "\n")) - _, _ = w.Write([]byte("data: ")) - _, _ = w.Write(b) - _, _ = w.Write([]byte("\n\n")) - if canFlush { - _ = rc.Flush() - } - } - - sendEvent("response.created", util.BuildOpenAIResponsesCreatedPayload(responseID, model)) - initialType := "text" if thinkingEnabled { initialType = "thinking" } - parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() - var sieve toolStreamSieveState - thinking := strings.Builder{} - text := strings.Builder{} - toolCallsEmitted := false - streamToolCallIDs := map[int]string{} - finalize := func() { - finalThinking := thinking.String() - finalText := text.String() - if bufferToolContent { - for _, evt := range flushToolSieve(&sieve, toolNames) { - if evt.Content != "" { - sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content)) - } - if len(evt.ToolCalls) > 0 { - toolCallsEmitted = true - sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) - } - } - } - obj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText, toolNames) - if toolCallsEmitted { - obj["status"] = "completed" - } - h.getResponseStore().put(owner, responseID, obj) - sendEvent("response.completed", util.BuildOpenAIResponsesCompletedPayload(obj)) - _, _ = w.Write([]byte("data: [DONE]\n\n")) - if canFlush { - _ = rc.Flush() - } - } + streamRuntime := newResponsesStreamRuntime( + w, + rc, + canFlush, + responseID, + model, + finalPrompt, + thinkingEnabled, + searchEnabled, + toolNames, + bufferToolContent, + emitEarlyToolDeltas, + func(obj map[string]any) { + h.getResponseStore().put(owner, responseID, obj) + }, + ) + streamRuntime.sendCreated() - for { - select { - case <-r.Context().Done(): - return - case parsed, ok := <-parsedLines: - if !ok { - _ = <-done - finalize() - return - } - if !parsed.Parsed { - continue - } - if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { - finalize() - return - } - for _, p := range parsed.Parts { - if p.Text == "" { - continue - } - if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) { - continue - } - if p.Type == "thinking" { - if !thinkingEnabled { - continue - } - thinking.WriteString(p.Text) - sendEvent("response.reasoning.delta", util.BuildOpenAIResponsesReasoningDeltaPayload(responseID, p.Text)) - continue - } - text.WriteString(p.Text) - if !bufferToolContent { - sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, p.Text)) - continue - } - for _, evt := range processToolSieveChunk(&sieve, p.Text, toolNames) { - if evt.Content != "" { - sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content)) - } - if len(evt.ToolCallDeltas) > 0 { - if !emitEarlyToolDeltas { - continue - } - toolCallsEmitted = true - sendEvent("response.output_tool_call.delta", util.BuildOpenAIResponsesToolCallDeltaPayload(responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs))) - } - if len(evt.ToolCalls) > 0 { - toolCallsEmitted = true - sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) - } - } - } - } - } + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second, + IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, + MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, + }, streamengine.ConsumeHooks{ + OnParsed: streamRuntime.onParsed, + OnFinalize: func(_ streamengine.StopReason, _ error) { + streamRuntime.finalize() + }, + }) } func responsesMessagesFromRequest(req map[string]any) []any { diff --git a/internal/adapter/openai/responses_stream_runtime.go b/internal/adapter/openai/responses_stream_runtime.go new file mode 100644 index 0000000..f7e8b20 --- /dev/null +++ b/internal/adapter/openai/responses_stream_runtime.go @@ -0,0 +1,168 @@ +package openai + +import ( + "encoding/json" + "net/http" + "strings" + + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" + "ds2api/internal/util" +) + +type responsesStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + responseID string + model string + finalPrompt string + toolNames []string + + thinkingEnabled bool + searchEnabled bool + + bufferToolContent bool + emitEarlyToolDeltas bool + toolCallsEmitted bool + + sieve toolStreamSieveState + thinking strings.Builder + text strings.Builder + streamToolCallIDs map[int]string + + persistResponse func(obj map[string]any) +} + +func newResponsesStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + responseID string, + model string, + finalPrompt string, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, + bufferToolContent bool, + emitEarlyToolDeltas bool, + persistResponse func(obj map[string]any), +) *responsesStreamRuntime { + return &responsesStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + responseID: responseID, + model: model, + finalPrompt: finalPrompt, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + toolNames: toolNames, + bufferToolContent: bufferToolContent, + emitEarlyToolDeltas: emitEarlyToolDeltas, + streamToolCallIDs: map[int]string{}, + persistResponse: persistResponse, + } +} + +func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) { + b, _ := json.Marshal(payload) + _, _ = s.w.Write([]byte("event: " + event + "\n")) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *responsesStreamRuntime) sendCreated() { + s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model)) +} + +func (s *responsesStreamRuntime) sendDone() { + _, _ = s.w.Write([]byte("data: [DONE]\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *responsesStreamRuntime) finalize() { + finalThinking := s.thinking.String() + finalText := s.text.String() + if s.bufferToolContent { + for _, evt := range flushToolSieve(&s.sieve, s.toolNames) { + if evt.Content != "" { + s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content)) + } + if len(evt.ToolCalls) > 0 { + s.toolCallsEmitted = true + s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) + } + } + } + + obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames) + if s.toolCallsEmitted { + obj["status"] = "completed" + } + if s.persistResponse != nil { + s.persistResponse(obj) + } + s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj)) + s.sendDone() +} + +func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { + return streamengine.ParsedDecision{Stop: true} + } + + contentSeen := false + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + contentSeen = true + if p.Type == "thinking" { + if !s.thinkingEnabled { + continue + } + s.thinking.WriteString(p.Text) + s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text)) + continue + } + + s.text.WriteString(p.Text) + if !s.bufferToolContent { + s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text)) + continue + } + for _, evt := range processToolSieveChunk(&s.sieve, p.Text, s.toolNames) { + if evt.Content != "" { + s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content)) + } + if len(evt.ToolCallDeltas) > 0 { + if !s.emitEarlyToolDeltas { + continue + } + s.toolCallsEmitted = true + s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs))) + } + if len(evt.ToolCalls) > 0 { + s.toolCallsEmitted = true + s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) + } + } + } + + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} diff --git a/internal/adapter/openai/standard_request.go b/internal/adapter/openai/standard_request.go index 52344d4..5883d03 100644 --- a/internal/adapter/openai/standard_request.go +++ b/internal/adapter/openai/standard_request.go @@ -8,7 +8,7 @@ import ( "ds2api/internal/util" ) -func normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) { +func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) { model, _ := req["model"].(string) messagesRaw, _ := req["messages"].([]any) if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { @@ -41,7 +41,7 @@ func normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.S }, nil } -func normalizeOpenAIResponsesRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) { +func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) { model, _ := req["model"].(string) model = strings.TrimSpace(model) if model == "" { diff --git a/internal/admin/deps.go b/internal/admin/deps.go new file mode 100644 index 0000000..e92c37b --- /dev/null +++ b/internal/admin/deps.go @@ -0,0 +1,46 @@ +package admin + +import ( + "context" + "net/http" + + "ds2api/internal/account" + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" +) + +type ConfigStore interface { + Snapshot() config.Config + Keys() []string + Accounts() []config.Account + FindAccount(identifier string) (config.Account, bool) + UpdateAccountToken(identifier, token string) error + Update(mutator func(*config.Config) error) error + ExportJSONAndBase64() (string, string, error) + IsEnvBacked() bool + SetVercelSync(hash string, ts int64) error + AdminPasswordHash() string + AdminJWTExpireHours() int + AdminJWTValidAfterUnix() int64 + RuntimeAccountMaxInflight() int + RuntimeAccountMaxQueue(defaultSize int) int + RuntimeGlobalMaxInflight(defaultSize int) int +} + +type PoolController interface { + Reset() + Status() map[string]any + ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) +} + +type DeepSeekCaller interface { + Login(ctx context.Context, acc config.Account) (string, error) + CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +var _ ConfigStore = (*config.Store)(nil) +var _ PoolController = (*account.Pool)(nil) +var _ DeepSeekCaller = (*deepseek.Client)(nil) diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 9d6151e..829b657 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -2,16 +2,12 @@ package admin import ( "github.com/go-chi/chi/v5" - - "ds2api/internal/account" - "ds2api/internal/config" - "ds2api/internal/deepseek" ) type Handler struct { - Store *config.Store - Pool *account.Pool - DS *deepseek.Client + Store ConfigStore + Pool PoolController + DS DeepSeekCaller } func RegisterRoutes(r chi.Router, h *Handler) { @@ -22,6 +18,11 @@ func RegisterRoutes(r chi.Router, h *Handler) { pr.Get("/vercel/config", h.getVercelConfig) pr.Get("/config", h.getConfig) pr.Post("/config", h.updateConfig) + pr.Get("/settings", h.getSettings) + pr.Put("/settings", h.updateSettings) + pr.Post("/settings/password", h.updateSettingsPassword) + pr.Post("/config/import", h.configImport) + pr.Get("/config/export", h.configExport) pr.Post("/keys", h.addKey) pr.Delete("/keys/{key}", h.deleteKey) pr.Get("/accounts", h.listAccounts) diff --git a/internal/admin/handler_auth.go b/internal/admin/handler_auth.go index 0d3ec1f..9b96b2f 100644 --- a/internal/admin/handler_auth.go +++ b/internal/admin/handler_auth.go @@ -12,7 +12,7 @@ import ( func (h *Handler) requireAdmin(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := authn.VerifyAdminRequest(r); err != nil { + if err := authn.VerifyAdminRequestWithStore(r, h.Store); err != nil { writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()}) return } @@ -25,18 +25,18 @@ func (h *Handler) login(w http.ResponseWriter, r *http.Request) { _ = json.NewDecoder(r.Body).Decode(&req) adminKey, _ := req["admin_key"].(string) expireHours := intFrom(req["expire_hours"]) - if expireHours <= 0 { - expireHours = 24 - } - if adminKey != authn.AdminKey() { + if !authn.VerifyAdminCredential(adminKey, h.Store) { writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "Invalid admin key"}) return } - token, err := authn.CreateJWT(expireHours) + token, err := authn.CreateJWTWithStore(expireHours, h.Store) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) return } + if expireHours <= 0 { + expireHours = h.Store.AdminJWTExpireHours() + } writeJSON(w, http.StatusOK, map[string]any{"success": true, "token": token, "expires_in": expireHours * 3600}) } @@ -47,7 +47,7 @@ func (h *Handler) verify(w http.ResponseWriter, r *http.Request) { return } token := strings.TrimSpace(header[7:]) - payload, err := authn.VerifyJWT(token) + payload, err := authn.VerifyJWTWithStore(token, h.Store) if err != nil { writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()}) return diff --git a/internal/admin/handler_config.go b/internal/admin/handler_config.go index 2b672c3..dfbd005 100644 --- a/internal/admin/handler_config.go +++ b/internal/admin/handler_config.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "net/http" - "sort" "strings" "github.com/go-chi/chi/v5" @@ -204,38 +203,191 @@ func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) { } func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) { + h.configExport(w, nil) +} + +func (h *Handler) configExport(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() jsonStr, b64, err := h.Store.ExportJSONAndBase64() if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) return } - writeJSON(w, http.StatusOK, map[string]any{"json": jsonStr, "base64": b64}) + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "config": snap, + "json": jsonStr, + "base64": b64, + }) +} + +func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + + mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode"))) + if mode == "" { + mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode"))) + } + if mode == "" { + mode = "merge" + } + if mode != "merge" && mode != "replace" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"}) + return + } + + payload := req + if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 { + payload = raw + } + rawJSON, err := json.Marshal(payload) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"}) + return + } + var incoming config.Config + if err := json.Unmarshal(rawJSON, &incoming); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + + importedKeys, importedAccounts := 0, 0 + err = h.Store.Update(func(c *config.Config) error { + next := c.Clone() + if mode == "replace" { + next = incoming.Clone() + next.VercelSyncHash = c.VercelSyncHash + next.VercelSyncTime = c.VercelSyncTime + importedKeys = len(next.Keys) + importedAccounts = len(next.Accounts) + } else { + existingKeys := map[string]struct{}{} + for _, k := range next.Keys { + existingKeys[k] = struct{}{} + } + for _, k := range incoming.Keys { + key := strings.TrimSpace(k) + if key == "" { + continue + } + if _, ok := existingKeys[key]; ok { + continue + } + existingKeys[key] = struct{}{} + next.Keys = append(next.Keys, key) + importedKeys++ + } + + existingAccounts := map[string]struct{}{} + for _, acc := range next.Accounts { + existingAccounts[acc.Identifier()] = struct{}{} + } + for _, acc := range incoming.Accounts { + id := acc.Identifier() + if id == "" { + continue + } + if _, ok := existingAccounts[id]; ok { + continue + } + existingAccounts[id] = struct{}{} + next.Accounts = append(next.Accounts, acc) + importedAccounts++ + } + + if len(incoming.ClaudeMapping) > 0 { + if next.ClaudeMapping == nil { + next.ClaudeMapping = map[string]string{} + } + for k, v := range incoming.ClaudeMapping { + next.ClaudeMapping[k] = v + } + } + if len(incoming.ClaudeModelMap) > 0 { + if next.ClaudeModelMap == nil { + next.ClaudeModelMap = map[string]string{} + } + for k, v := range incoming.ClaudeModelMap { + next.ClaudeModelMap[k] = v + } + } + + if len(incoming.ModelAliases) > 0 { + if next.ModelAliases == nil { + next.ModelAliases = map[string]string{} + } + for k, v := range incoming.ModelAliases { + next.ModelAliases[k] = v + } + } + if strings.TrimSpace(incoming.Toolcall.Mode) != "" { + next.Toolcall.Mode = incoming.Toolcall.Mode + } + if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" { + next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence + } + if incoming.Responses.StoreTTLSeconds > 0 { + next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds + } + if strings.TrimSpace(incoming.Embeddings.Provider) != "" { + next.Embeddings.Provider = incoming.Embeddings.Provider + } + if strings.TrimSpace(incoming.Admin.PasswordHash) != "" { + next.Admin.PasswordHash = incoming.Admin.PasswordHash + } + if incoming.Admin.JWTExpireHours > 0 { + next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours + } + if incoming.Admin.JWTValidAfterUnix > 0 { + next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix + } + if incoming.Runtime.AccountMaxInflight > 0 { + next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight + } + if incoming.Runtime.AccountMaxQueue > 0 { + next.Runtime.AccountMaxQueue = incoming.Runtime.AccountMaxQueue + } + if incoming.Runtime.GlobalMaxInflight > 0 { + next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight + } + } + + normalizeSettingsConfig(&next) + if err := validateSettingsConfig(next); err != nil { + return newRequestError(err.Error()) + } + + *c = next + return nil + }) + if err != nil { + if detail, ok := requestErrorDetail(err); ok { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail}) + return + } + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "mode": mode, + "imported_keys": importedKeys, + "imported_accounts": importedAccounts, + "message": "config imported", + }) } func (h *Handler) computeSyncHash() string { - snap := h.Store.Snapshot() - syncable := map[string]any{"keys": snap.Keys, "accounts": []map[string]any{}} - accounts := make([]map[string]any, 0, len(snap.Accounts)) - for _, a := range snap.Accounts { - m := map[string]any{} - if a.Email != "" { - m["email"] = a.Email - } - if a.Mobile != "" { - m["mobile"] = a.Mobile - } - if a.Password != "" { - m["password"] = a.Password - } - accounts = append(accounts, m) - } - sort.Slice(accounts, func(i, j int) bool { - ai := fmt.Sprintf("%v%v", accounts[i]["email"], accounts[i]["mobile"]) - aj := fmt.Sprintf("%v%v", accounts[j]["email"], accounts[j]["mobile"]) - return ai < aj - }) - syncable["accounts"] = accounts - b, _ := json.Marshal(syncable) + snap := h.Store.Snapshot().Clone() + snap.VercelSyncHash = "" + snap.VercelSyncTime = 0 + b, _ := json.Marshal(snap) sum := md5.Sum(b) return fmt.Sprintf("%x", sum) } diff --git a/internal/admin/handler_settings.go b/internal/admin/handler_settings.go new file mode 100644 index 0000000..06c234c --- /dev/null +++ b/internal/admin/handler_settings.go @@ -0,0 +1,321 @@ +package admin + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + authn "ds2api/internal/auth" + "ds2api/internal/config" +) + +func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() + recommended := defaultRuntimeRecommended(len(snap.Accounts), h.Store.RuntimeAccountMaxInflight()) + needsSync := config.IsVercel() && snap.VercelSyncHash != "" && snap.VercelSyncHash != h.computeSyncHash() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "admin": map[string]any{ + "has_password_hash": strings.TrimSpace(snap.Admin.PasswordHash) != "", + "jwt_expire_hours": h.Store.AdminJWTExpireHours(), + "jwt_valid_after_unix": snap.Admin.JWTValidAfterUnix, + "default_password_warning": authn.UsingDefaultAdminKey(h.Store), + }, + "runtime": map[string]any{ + "account_max_inflight": h.Store.RuntimeAccountMaxInflight(), + "account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended), + "global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended), + }, + "toolcall": snap.Toolcall, + "responses": snap.Responses, + "embeddings": snap.Embeddings, + "claude_mapping": settingsClaudeMapping(snap), + "model_aliases": snap.ModelAliases, + "env_backed": h.Store.IsEnvBacked(), + "needs_vercel_sync": needsSync, + }) +} + +func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + + adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + if runtimeCfg != nil { + if err := validateMergedRuntimeSettings(h.Store.Snapshot().Runtime, runtimeCfg); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + } + + if err := h.Store.Update(func(c *config.Config) error { + if adminCfg != nil { + if adminCfg.JWTExpireHours > 0 { + c.Admin.JWTExpireHours = adminCfg.JWTExpireHours + } + } + if runtimeCfg != nil { + if runtimeCfg.AccountMaxInflight > 0 { + c.Runtime.AccountMaxInflight = runtimeCfg.AccountMaxInflight + } + if runtimeCfg.AccountMaxQueue > 0 { + c.Runtime.AccountMaxQueue = runtimeCfg.AccountMaxQueue + } + if runtimeCfg.GlobalMaxInflight > 0 { + c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight + } + } + if toolcallCfg != nil { + if strings.TrimSpace(toolcallCfg.Mode) != "" { + c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode) + } + if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" { + c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) + } + } + if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 { + c.Responses.StoreTTLSeconds = responsesCfg.StoreTTLSeconds + } + if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" { + c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider) + } + if claudeMap != nil { + c.ClaudeMapping = claudeMap + c.ClaudeModelMap = nil + } + if aliasMap != nil { + c.ModelAliases = aliasMap + } + return nil + }); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + h.applyRuntimeSettings() + needsSync := config.IsVercel() || h.Store.IsEnvBacked() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "message": "settings updated and hot reloaded", + "env_backed": h.Store.IsEnvBacked(), + "needs_vercel_sync": needsSync, + "manual_sync_message": "配置已保存。Vercel 部署请在 Vercel Sync 页面手动同步。", + }) +} + +func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *config.RuntimeConfig) error { + merged := current + if incoming != nil { + if incoming.AccountMaxInflight > 0 { + merged.AccountMaxInflight = incoming.AccountMaxInflight + } + if incoming.AccountMaxQueue > 0 { + merged.AccountMaxQueue = incoming.AccountMaxQueue + } + if incoming.GlobalMaxInflight > 0 { + merged.GlobalMaxInflight = incoming.GlobalMaxInflight + } + } + return validateRuntimeSettings(merged) +} + +func (h *Handler) updateSettingsPassword(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + newPassword := strings.TrimSpace(fieldString(req, "new_password")) + if newPassword == "" { + newPassword = strings.TrimSpace(fieldString(req, "password")) + } + if len(newPassword) < 4 { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "new password must be at least 4 characters"}) + return + } + + now := time.Now().Unix() + hash := authn.HashAdminPassword(newPassword) + if err := h.Store.Update(func(c *config.Config) error { + c.Admin.PasswordHash = hash + c.Admin.JWTValidAfterUnix = now + return nil + }); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "message": "password updated", + "force_relogin": true, + "jwt_valid_after_unix": now, + }) +} + +func (h *Handler) applyRuntimeSettings() { + if h == nil || h.Store == nil || h.Pool == nil { + return + } + accountCount := len(h.Store.Accounts()) + maxPer := h.Store.RuntimeAccountMaxInflight() + recommended := defaultRuntimeRecommended(accountCount, maxPer) + maxQueue := h.Store.RuntimeAccountMaxQueue(recommended) + global := h.Store.RuntimeGlobalMaxInflight(recommended) + h.Pool.ApplyRuntimeLimits(maxPer, maxQueue, global) +} + +func defaultRuntimeRecommended(accountCount, maxPer int) int { + if maxPer <= 0 { + maxPer = 1 + } + if accountCount <= 0 { + return maxPer + } + return accountCount * maxPer +} + +func settingsClaudeMapping(c config.Config) map[string]string { + if len(c.ClaudeMapping) > 0 { + return c.ClaudeMapping + } + if len(c.ClaudeModelMap) > 0 { + return c.ClaudeModelMap + } + return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"} +} + +func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) { + var ( + adminCfg *config.AdminConfig + runtimeCfg *config.RuntimeConfig + toolcallCfg *config.ToolcallConfig + respCfg *config.ResponsesConfig + embCfg *config.EmbeddingsConfig + claudeMap map[string]string + aliasMap map[string]string + ) + + if raw, ok := req["admin"].(map[string]any); ok { + cfg := &config.AdminConfig{} + if v, exists := raw["jwt_expire_hours"]; exists { + n := intFrom(v) + if n < 1 || n > 720 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720") + } + cfg.JWTExpireHours = n + } + adminCfg = cfg + } + + if raw, ok := req["runtime"].(map[string]any); ok { + cfg := &config.RuntimeConfig{} + if v, exists := raw["account_max_inflight"]; exists { + n := intFrom(v) + if n < 1 || n > 256 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256") + } + cfg.AccountMaxInflight = n + } + if v, exists := raw["account_max_queue"]; exists { + n := intFrom(v) + if n < 1 || n > 200000 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000") + } + cfg.AccountMaxQueue = n + } + if v, exists := raw["global_max_inflight"]; exists { + n := intFrom(v) + if n < 1 || n > 200000 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000") + } + cfg.GlobalMaxInflight = n + } + if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") + } + runtimeCfg = cfg + } + + if raw, ok := req["toolcall"].(map[string]any); ok { + cfg := &config.ToolcallConfig{} + if v, exists := raw["mode"]; exists { + mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) + switch mode { + case "feature_match", "off": + cfg.Mode = mode + default: + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off") + } + } + if v, exists := raw["early_emit_confidence"]; exists { + level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) + switch level { + case "high", "low", "off": + cfg.EarlyEmitConfidence = level + default: + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off") + } + } + toolcallCfg = cfg + } + + if raw, ok := req["responses"].(map[string]any); ok { + cfg := &config.ResponsesConfig{} + if v, exists := raw["store_ttl_seconds"]; exists { + n := intFrom(v) + if n < 30 || n > 86400 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400") + } + cfg.StoreTTLSeconds = n + } + respCfg = cfg + } + + if raw, ok := req["embeddings"].(map[string]any); ok { + cfg := &config.EmbeddingsConfig{} + if v, exists := raw["provider"]; exists { + p := strings.TrimSpace(fmt.Sprintf("%v", v)) + if p == "" { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty") + } + cfg.Provider = p + } + embCfg = cfg + } + + if raw, ok := req["claude_mapping"].(map[string]any); ok { + claudeMap = map[string]string{} + for k, v := range raw { + key := strings.TrimSpace(k) + val := strings.TrimSpace(fmt.Sprintf("%v", v)) + if key == "" || val == "" { + continue + } + claudeMap[key] = val + } + } + + if raw, ok := req["model_aliases"].(map[string]any); ok { + aliasMap = map[string]string{} + for k, v := range raw { + key := strings.TrimSpace(k) + val := strings.TrimSpace(fmt.Sprintf("%v", v)) + if key == "" || val == "" { + continue + } + aliasMap[key] = val + } + } + + return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil +} diff --git a/internal/admin/handler_settings_test.go b/internal/admin/handler_settings_test.go new file mode 100644 index 0000000..3eb5114 --- /dev/null +++ b/internal/admin/handler_settings_test.go @@ -0,0 +1,267 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + authn "ds2api/internal/auth" +) + +func TestGetSettingsDefaultPasswordWarning(t *testing.T) { + t.Setenv("DS2API_ADMIN_KEY", "") + h := newAdminTestHandler(t, `{"keys":["k1"]}`) + req := httptest.NewRequest(http.MethodGet, "/admin/settings", nil) + rec := httptest.NewRecorder() + h.getSettings(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + var body map[string]any + _ = json.Unmarshal(rec.Body.Bytes(), &body) + admin, _ := body["admin"].(map[string]any) + warn, _ := admin["default_password_warning"].(bool) + if !warn { + t.Fatalf("expected default password warning true, body=%v", body) + } +} + +func TestUpdateSettingsValidation(t *testing.T) { + h := newAdminTestHandler(t, `{"keys":["k1"]}`) + payload := map[string]any{ + "runtime": map[string]any{ + "account_max_inflight": 0, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } +} + +func TestUpdateSettingsValidationWithMergedRuntimeSnapshot(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "runtime":{ + "account_max_inflight":8, + "global_max_inflight":8 + } + }`) + payload := map[string]any{ + "runtime": map[string]any{ + "account_max_inflight": 16, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.global_max_inflight")) { + t.Fatalf("expected merged runtime validation detail, got %s", rec.Body.String()) + } +} + +func TestUpdateSettingsWithoutRuntimeSkipsMergedRuntimeValidation(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "runtime":{ + "account_max_inflight":8, + "global_max_inflight":4 + } + }`) + payload := map[string]any{ + "responses": map[string]any{ + "store_ttl_seconds": 600, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if got := h.Store.Snapshot().Responses.StoreTTLSeconds; got != 600 { + t.Fatalf("store_ttl_seconds=%d want=600", got) + } +} + +func TestUpdateSettingsHotReloadRuntime(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "accounts":[{"email":"a@test.com","token":"t1"},{"email":"b@test.com","token":"t2"}] + }`) + + payload := map[string]any{ + "runtime": map[string]any{ + "account_max_inflight": 3, + "account_max_queue": 20, + "global_max_inflight": 5, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + status := h.Pool.Status() + if got := intFrom(status["max_inflight_per_account"]); got != 3 { + t.Fatalf("max_inflight_per_account=%d want=3", got) + } + if got := intFrom(status["max_queue_size"]); got != 20 { + t.Fatalf("max_queue_size=%d want=20", got) + } + if got := intFrom(status["global_max_inflight"]); got != 5 { + t.Fatalf("global_max_inflight=%d want=5", got) + } +} + +func TestUpdateSettingsPasswordInvalidatesOldJWT(t *testing.T) { + hash := authn.HashAdminPassword("old-password") + h := newAdminTestHandler(t, `{"admin":{"password_hash":"`+hash+`"}}`) + + token, err := authn.CreateJWTWithStore(1, h.Store) + if err != nil { + t.Fatalf("create jwt failed: %v", err) + } + if _, err := authn.VerifyJWTWithStore(token, h.Store); err != nil { + t.Fatalf("verify before update failed: %v", err) + } + + body := map[string]any{"new_password": "new-password"} + b, _ := json.Marshal(body) + req := httptest.NewRequest(http.MethodPost, "/admin/settings/password", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettingsPassword(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + + if _, err := authn.VerifyJWTWithStore(token, h.Store); err == nil { + t.Fatal("expected old token to be invalid after password update") + } + if !authn.VerifyAdminCredential("new-password", h.Store) { + t.Fatal("expected new password credential to be accepted") + } +} + +func TestConfigImportMergeAndReplace(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "accounts":[{"email":"a@test.com","password":"p1"}] + }`) + + merge := map[string]any{ + "mode": "merge", + "config": map[string]any{ + "keys": []any{"k1", "k2"}, + "accounts": []any{ + map[string]any{"email": "a@test.com", "password": "p1"}, + map[string]any{"email": "b@test.com", "password": "p2"}, + }, + }, + } + mergeBytes, _ := json.Marshal(merge) + mergeReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(mergeBytes)) + mergeRec := httptest.NewRecorder() + h.configImport(mergeRec, mergeReq) + if mergeRec.Code != http.StatusOK { + t.Fatalf("merge status=%d body=%s", mergeRec.Code, mergeRec.Body.String()) + } + if got := len(h.Store.Keys()); got != 2 { + t.Fatalf("keys after merge=%d want=2", got) + } + if got := len(h.Store.Accounts()); got != 2 { + t.Fatalf("accounts after merge=%d want=2", got) + } + + replace := map[string]any{ + "mode": "replace", + "config": map[string]any{ + "keys": []any{"k9"}, + }, + } + replaceBytes, _ := json.Marshal(replace) + replaceReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=replace", bytes.NewReader(replaceBytes)) + replaceRec := httptest.NewRecorder() + h.configImport(replaceRec, replaceReq) + if replaceRec.Code != http.StatusOK { + t.Fatalf("replace status=%d body=%s", replaceRec.Code, replaceRec.Body.String()) + } + keys := h.Store.Keys() + if len(keys) != 1 || keys[0] != "k9" { + t.Fatalf("unexpected keys after replace: %#v", keys) + } + if got := len(h.Store.Accounts()); got != 0 { + t.Fatalf("accounts after replace=%d want=0", got) + } +} + +func TestConfigImportRejectsInvalidRuntimeBounds(t *testing.T) { + h := newAdminTestHandler(t, `{"keys":["k1"]}`) + payload := map[string]any{ + "mode": "replace", + "config": map[string]any{ + "keys": []any{"k2"}, + "runtime": map[string]any{ + "account_max_inflight": 300, + }, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=replace", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.configImport(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.account_max_inflight")) { + t.Fatalf("expected runtime bound detail, got %s", rec.Body.String()) + } + keys := h.Store.Keys() + if len(keys) != 1 || keys[0] != "k1" { + t.Fatalf("store should remain unchanged, keys=%v", keys) + } +} + +func TestConfigImportRejectsMergedRuntimeConflict(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "runtime":{ + "account_max_inflight":8, + "global_max_inflight":8 + } + }`) + payload := map[string]any{ + "mode": "merge", + "config": map[string]any{ + "runtime": map[string]any{ + "account_max_inflight": 16, + }, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.configImport(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.global_max_inflight")) { + t.Fatalf("expected merged runtime validation detail, got %s", rec.Body.String()) + } + snap := h.Store.Snapshot() + if snap.Runtime.AccountMaxInflight != 8 || snap.Runtime.GlobalMaxInflight != 8 { + t.Fatalf("runtime should remain unchanged, runtime=%+v", snap.Runtime) + } +} diff --git a/internal/admin/handler_vercel.go b/internal/admin/handler_vercel.go index 189d8cc..2c6356c 100644 --- a/internal/admin/handler_vercel.go +++ b/internal/admin/handler_vercel.go @@ -3,8 +3,8 @@ package admin import ( "bytes" "context" - "encoding/base64" "encoding/json" + "fmt" "io" "net/http" "net/url" @@ -19,6 +19,62 @@ func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) return } + opts, err := parseVercelSyncOptions(req) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + validated, failed := h.validateAccountsForVercelSync(r.Context(), opts.AutoValidate) + _, cfgB64, err := h.Store.ExportJSONAndBase64() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + client := &http.Client{Timeout: 30 * time.Second} + params := buildVercelParams(opts.TeamID) + headers := map[string]string{"Authorization": "Bearer " + opts.VercelToken} + + envResp, status, err := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+opts.ProjectID+"/env", params, headers, nil) + if err != nil || status != http.StatusOK { + writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "获取环境变量失败"}) + return + } + envs, _ := envResp["envs"].([]any) + status, err = upsertVercelEnv(r.Context(), client, opts.ProjectID, params, headers, envs, "DS2API_CONFIG_JSON", cfgB64) + if err != nil || (status != http.StatusOK && status != http.StatusCreated) { + writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "更新环境变量失败"}) + return + } + savedCreds := h.saveVercelProjectCredentials(r.Context(), client, opts, params, headers, envs) + manual, deployURL := triggerVercelDeployment(r.Context(), client, opts.ProjectID, params, headers) + _ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix()) + result := map[string]any{"success": true, "validated_accounts": validated} + if manual { + result["message"] = "配置已同步到 Vercel,请手动触发重新部署" + result["manual_deploy_required"] = true + } else { + result["message"] = "配置已同步,正在重新部署..." + result["deployment_url"] = deployURL + } + if len(failed) > 0 { + result["failed_accounts"] = failed + } + if len(savedCreds) > 0 { + result["saved_credentials"] = savedCreds + } + writeJSON(w, http.StatusOK, result) +} + +type vercelSyncOptions struct { + VercelToken string + ProjectID string + TeamID string + AutoValidate bool + SaveCreds bool + UsePreconfig bool +} + +func parseVercelSyncOptions(req map[string]any) (vercelSyncOptions, error) { vercelToken, _ := req["vercel_token"].(string) projectID, _ := req["project_id"].(string) teamID, _ := req["team_id"].(string) @@ -40,108 +96,117 @@ func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) { if strings.TrimSpace(teamID) == "" { teamID = strings.TrimSpace(os.Getenv("VERCEL_TEAM_ID")) } + vercelToken = strings.TrimSpace(vercelToken) + projectID = strings.TrimSpace(projectID) + teamID = strings.TrimSpace(teamID) if vercelToken == "" || projectID == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 Vercel Token 和 Project ID"}) - return + return vercelSyncOptions{}, fmt.Errorf("需要 Vercel Token 和 Project ID") + } + return vercelSyncOptions{ + VercelToken: vercelToken, + ProjectID: projectID, + TeamID: teamID, + AutoValidate: autoValidate, + SaveCreds: saveCreds, + UsePreconfig: usePreconfig, + }, nil +} + +func buildVercelParams(teamID string) url.Values { + params := url.Values{} + if strings.TrimSpace(teamID) != "" { + params.Set("teamId", strings.TrimSpace(teamID)) + } + return params +} + +func (h *Handler) validateAccountsForVercelSync(ctx context.Context, enabled bool) (int, []string) { + if !enabled { + return 0, nil } validated, failed := 0, []string{} - if autoValidate { - for _, acc := range h.Store.Snapshot().Accounts { - if strings.TrimSpace(acc.Token) != "" { - continue - } - token, err := h.DS.Login(r.Context(), acc) - if err != nil { - failed = append(failed, acc.Identifier()) - } else { - validated++ - _ = h.Store.UpdateAccountToken(acc.Identifier(), token) - } - time.Sleep(500 * time.Millisecond) + for _, acc := range h.Store.Snapshot().Accounts { + if strings.TrimSpace(acc.Token) != "" { + continue } + token, err := h.DS.Login(ctx, acc) + if err != nil { + failed = append(failed, acc.Identifier()) + } else { + validated++ + _ = h.Store.UpdateAccountToken(acc.Identifier(), token) + } + time.Sleep(500 * time.Millisecond) } + return validated, failed +} - cfgJSON, _, err := h.Store.ExportJSONAndBase64() - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return +func upsertVercelEnv(ctx context.Context, client *http.Client, projectID string, params url.Values, headers map[string]string, envs []any, key, value string) (int, error) { + existingID := findEnvID(envs, key) + if existingID != "" { + _, status, err := vercelRequest(ctx, client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+existingID, params, headers, map[string]any{"value": value}) + return status, err } - cfgB64 := base64.StdEncoding.EncodeToString([]byte(cfgJSON)) - client := &http.Client{Timeout: 30 * time.Second} - params := url.Values{} - if teamID != "" { - params.Set("teamId", teamID) + _, status, err := vercelRequest(ctx, client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{ + "key": key, + "value": value, + "type": "encrypted", + "target": []string{"production", "preview"}, + }) + return status, err +} + +func (h *Handler) saveVercelProjectCredentials(ctx context.Context, client *http.Client, opts vercelSyncOptions, params url.Values, headers map[string]string, envs []any) []string { + if !opts.SaveCreds || opts.UsePreconfig { + return nil } - headers := map[string]string{"Authorization": "Bearer " + vercelToken} - envResp, status, err := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID+"/env", params, headers, nil) - if err != nil || status != http.StatusOK { - writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "获取环境变量失败"}) - return + saved := []string{} + creds := [][2]string{{"VERCEL_TOKEN", opts.VercelToken}, {"VERCEL_PROJECT_ID", opts.ProjectID}} + if opts.TeamID != "" { + creds = append(creds, [2]string{"VERCEL_TEAM_ID", opts.TeamID}) } - envs, _ := envResp["envs"].([]any) - existingEnvID := findEnvID(envs, "DS2API_CONFIG_JSON") - if existingEnvID != "" { - _, status, err = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+existingEnvID, params, headers, map[string]any{"value": cfgB64}) - } else { - _, status, err = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": "DS2API_CONFIG_JSON", "value": cfgB64, "type": "encrypted", "target": []string{"production", "preview"}}) - } - if err != nil || (status != http.StatusOK && status != http.StatusCreated) { - writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "更新环境变量失败"}) - return - } - savedCreds := []string{} - if saveCreds && !usePreconfig { - creds := [][2]string{{"VERCEL_TOKEN", vercelToken}, {"VERCEL_PROJECT_ID", projectID}} - if teamID != "" { - creds = append(creds, [2]string{"VERCEL_TEAM_ID", teamID}) - } - for _, kv := range creds { - id := findEnvID(envs, kv[0]) - if id != "" { - _, status, _ = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+id, params, headers, map[string]any{"value": kv[1]}) - } else { - _, status, _ = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": kv[0], "value": kv[1], "type": "encrypted", "target": []string{"production", "preview"}}) - } - if status == http.StatusOK || status == http.StatusCreated { - savedCreds = append(savedCreds, kv[0]) - } + for _, kv := range creds { + status, _ := upsertVercelEnv(ctx, client, opts.ProjectID, params, headers, envs, kv[0], kv[1]) + if status == http.StatusOK || status == http.StatusCreated { + saved = append(saved, kv[0]) } } - projectResp, status, _ := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID, params, headers, nil) - manual := true - deployURL := "" - if status == http.StatusOK { - if link, ok := projectResp["link"].(map[string]any); ok { - if linkType, _ := link["type"].(string); linkType == "github" { - repoID := intFrom(link["repoId"]) - ref, _ := link["productionBranch"].(string) - if ref == "" { - ref = "main" - } - depResp, depStatus, _ := vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v13/deployments", params, headers, map[string]any{"name": projectID, "project": projectID, "target": "production", "gitSource": map[string]any{"type": "github", "repoId": repoID, "ref": ref}}) - if depStatus == http.StatusOK || depStatus == http.StatusCreated { - deployURL, _ = depResp["url"].(string) - manual = false - } - } - } + return saved +} + +func triggerVercelDeployment(ctx context.Context, client *http.Client, projectID string, params url.Values, headers map[string]string) (bool, string) { + projectResp, status, _ := vercelRequest(ctx, client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID, params, headers, nil) + if status != http.StatusOK { + return true, "" } - _ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix()) - result := map[string]any{"success": true, "validated_accounts": validated} - if manual { - result["message"] = "配置已同步到 Vercel,请手动触发重新部署" - result["manual_deploy_required"] = true - } else { - result["message"] = "配置已同步,正在重新部署..." - result["deployment_url"] = deployURL + link, ok := projectResp["link"].(map[string]any) + if !ok { + return true, "" } - if len(failed) > 0 { - result["failed_accounts"] = failed + linkType, _ := link["type"].(string) + if linkType != "github" { + return true, "" } - if len(savedCreds) > 0 { - result["saved_credentials"] = savedCreds + repoID := intFrom(link["repoId"]) + ref, _ := link["productionBranch"].(string) + if ref == "" { + ref = "main" } - writeJSON(w, http.StatusOK, result) + depResp, depStatus, _ := vercelRequest(ctx, client, http.MethodPost, "https://api.vercel.com/v13/deployments", params, headers, map[string]any{ + "name": projectID, + "project": projectID, + "target": "production", + "gitSource": map[string]any{ + "type": "github", + "repoId": repoID, + "ref": ref, + }, + }) + if depStatus != http.StatusOK && depStatus != http.StatusCreated { + return true, "" + } + deployURL, _ := depResp["url"].(string) + return false, deployURL } func (h *Handler) vercelStatus(w http.ResponseWriter, _ *http.Request) { diff --git a/internal/admin/helpers.go b/internal/admin/helpers.go index d7d1198..2e00323 100644 --- a/internal/admin/helpers.go +++ b/internal/admin/helpers.go @@ -96,7 +96,7 @@ func accountMatchesIdentifier(acc config.Account, identifier string) bool { return acc.Identifier() == id } -func findAccountByIdentifier(store *config.Store, identifier string) (config.Account, bool) { +func findAccountByIdentifier(store ConfigStore, identifier string) (config.Account, bool) { id := strings.TrimSpace(identifier) if id == "" { return config.Account{}, false diff --git a/internal/admin/request_error.go b/internal/admin/request_error.go new file mode 100644 index 0000000..5431a3d --- /dev/null +++ b/internal/admin/request_error.go @@ -0,0 +1,23 @@ +package admin + +import "errors" + +type requestError struct { + detail string +} + +func (e *requestError) Error() string { + return e.detail +} + +func newRequestError(detail string) error { + return &requestError{detail: detail} +} + +func requestErrorDetail(err error) (string, bool) { + var reqErr *requestError + if errors.As(err, &reqErr) { + return reqErr.detail, true + } + return "", false +} diff --git a/internal/admin/settings_validation.go b/internal/admin/settings_validation.go new file mode 100644 index 0000000..f9d4c2f --- /dev/null +++ b/internal/admin/settings_validation.go @@ -0,0 +1,64 @@ +package admin + +import ( + "fmt" + "strings" + + "ds2api/internal/config" +) + +func normalizeSettingsConfig(c *config.Config) { + if c == nil { + return + } + c.Admin.PasswordHash = strings.TrimSpace(c.Admin.PasswordHash) + c.Toolcall.Mode = strings.ToLower(strings.TrimSpace(c.Toolcall.Mode)) + c.Toolcall.EarlyEmitConfidence = strings.ToLower(strings.TrimSpace(c.Toolcall.EarlyEmitConfidence)) + c.Embeddings.Provider = strings.TrimSpace(c.Embeddings.Provider) +} + +func validateSettingsConfig(c config.Config) error { + if c.Admin.JWTExpireHours != 0 && (c.Admin.JWTExpireHours < 1 || c.Admin.JWTExpireHours > 720) { + return fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720") + } + if err := validateRuntimeSettings(c.Runtime); err != nil { + return err + } + if c.Responses.StoreTTLSeconds != 0 && (c.Responses.StoreTTLSeconds < 30 || c.Responses.StoreTTLSeconds > 86400) { + return fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400") + } + if mode := strings.TrimSpace(c.Toolcall.Mode); mode != "" { + switch mode { + case "feature_match", "off": + default: + return fmt.Errorf("toolcall.mode must be feature_match or off") + } + } + if level := strings.TrimSpace(c.Toolcall.EarlyEmitConfidence); level != "" { + switch level { + case "high", "low", "off": + default: + return fmt.Errorf("toolcall.early_emit_confidence must be high, low or off") + } + } + if c.Embeddings.Provider != "" && strings.TrimSpace(c.Embeddings.Provider) == "" { + return fmt.Errorf("embeddings.provider cannot be empty") + } + return nil +} + +func validateRuntimeSettings(runtime config.RuntimeConfig) error { + if runtime.AccountMaxInflight != 0 && (runtime.AccountMaxInflight < 1 || runtime.AccountMaxInflight > 256) { + return fmt.Errorf("runtime.account_max_inflight must be between 1 and 256") + } + if runtime.AccountMaxQueue != 0 && (runtime.AccountMaxQueue < 1 || runtime.AccountMaxQueue > 200000) { + return fmt.Errorf("runtime.account_max_queue must be between 1 and 200000") + } + if runtime.GlobalMaxInflight != 0 && (runtime.GlobalMaxInflight < 1 || runtime.GlobalMaxInflight > 200000) { + return fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000") + } + if runtime.AccountMaxInflight > 0 && runtime.GlobalMaxInflight > 0 && runtime.GlobalMaxInflight < runtime.AccountMaxInflight { + return fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") + } + return nil +} diff --git a/internal/auth/admin.go b/internal/auth/admin.go index 3a52f6c..8f1d276 100644 --- a/internal/auth/admin.go +++ b/internal/auth/admin.go @@ -3,7 +3,9 @@ package auth import ( "crypto/hmac" "crypto/sha256" + "crypto/subtle" "encoding/base64" + "encoding/hex" "encoding/json" "errors" "log/slog" @@ -17,7 +19,22 @@ import ( var warnOnce sync.Once +type AdminConfigReader interface { + AdminPasswordHash() string + AdminJWTExpireHours() int + AdminJWTValidAfterUnix() int64 +} + func AdminKey() string { + return effectiveAdminKey(nil) +} + +func effectiveAdminKey(store AdminConfigReader) string { + if store != nil { + if hash := strings.TrimSpace(store.AdminPasswordHash()); hash != "" { + return "" + } + } if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" { return v } @@ -27,14 +44,24 @@ func AdminKey() string { return "admin" } -func jwtSecret() string { +func jwtSecret(store AdminConfigReader) string { if v := strings.TrimSpace(os.Getenv("DS2API_JWT_SECRET")); v != "" { return v } - return AdminKey() + if store != nil { + if hash := strings.TrimSpace(store.AdminPasswordHash()); hash != "" { + return hash + } + } + return effectiveAdminKey(store) } -func jwtExpireHours() int { +func jwtExpireHours(store AdminConfigReader) int { + if store != nil { + if n := store.AdminJWTExpireHours(); n > 0 { + return n + } + } if v := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); v != "" { if n, err := strconv.Atoi(v); err == nil && n > 0 { return n @@ -44,27 +71,44 @@ func jwtExpireHours() int { } func CreateJWT(expireHours int) (string, error) { + return CreateJWTWithStore(expireHours, nil) +} + +func CreateJWTWithStore(expireHours int, store AdminConfigReader) (string, error) { if expireHours <= 0 { - expireHours = jwtExpireHours() + expireHours = jwtExpireHours(store) } + issuedAt := time.Now().Unix() + // If sessions were invalidated in this same second, move iat forward by + // one second so newly minted tokens remain valid with strict cutoff checks. + if store != nil { + if validAfter := store.AdminJWTValidAfterUnix(); validAfter >= issuedAt { + issuedAt = validAfter + 1 + } + } + expireAt := time.Unix(issuedAt, 0).Add(time.Duration(expireHours) * time.Hour).Unix() header := map[string]any{"alg": "HS256", "typ": "JWT"} - payload := map[string]any{"iat": time.Now().Unix(), "exp": time.Now().Add(time.Duration(expireHours) * time.Hour).Unix(), "role": "admin"} + payload := map[string]any{"iat": issuedAt, "exp": expireAt, "role": "admin"} h, _ := json.Marshal(header) p, _ := json.Marshal(payload) headerB64 := rawB64Encode(h) payloadB64 := rawB64Encode(p) msg := headerB64 + "." + payloadB64 - sig := signHS256(msg) + sig := signHS256(msg, store) return msg + "." + rawB64Encode(sig), nil } func VerifyJWT(token string) (map[string]any, error) { + return VerifyJWTWithStore(token, nil) +} + +func VerifyJWTWithStore(token string, store AdminConfigReader) (map[string]any, error) { parts := strings.Split(token, ".") if len(parts) != 3 { return nil, errors.New("invalid token format") } msg := parts[0] + "." + parts[1] - expected := signHS256(msg) + expected := signHS256(msg, store) actual, err := rawB64Decode(parts[2]) if err != nil { return nil, errors.New("invalid signature") @@ -84,10 +128,23 @@ func VerifyJWT(token string) (map[string]any, error) { if int64(exp) < time.Now().Unix() { return nil, errors.New("token expired") } + if store != nil { + validAfter := store.AdminJWTValidAfterUnix() + if validAfter > 0 { + iat, _ := payload["iat"].(float64) + if int64(iat) <= validAfter { + return nil, errors.New("token expired") + } + } + } return payload, nil } func VerifyAdminRequest(r *http.Request) error { + return VerifyAdminRequestWithStore(r, nil) +} + +func VerifyAdminRequestWithStore(r *http.Request, store AdminConfigReader) error { authHeader := strings.TrimSpace(r.Header.Get("Authorization")) if !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { return errors.New("authentication required") @@ -96,17 +153,65 @@ func VerifyAdminRequest(r *http.Request) error { if token == "" { return errors.New("authentication required") } - if token == AdminKey() { + if VerifyAdminCredential(token, store) { return nil } - if _, err := VerifyJWT(token); err == nil { + if _, err := VerifyJWTWithStore(token, store); err == nil { return nil } return errors.New("invalid credentials") } -func signHS256(msg string) []byte { - h := hmac.New(sha256.New, []byte(jwtSecret())) +func VerifyAdminCredential(candidate string, store AdminConfigReader) bool { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + return false + } + if store != nil { + hash := strings.TrimSpace(store.AdminPasswordHash()) + if hash != "" { + return verifyAdminPasswordHash(candidate, hash) + } + } + key := effectiveAdminKey(store) + if key == "" { + return false + } + return subtle.ConstantTimeCompare([]byte(candidate), []byte(key)) == 1 +} + +func UsingDefaultAdminKey(store AdminConfigReader) bool { + if store != nil && strings.TrimSpace(store.AdminPasswordHash()) != "" { + return false + } + return strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")) == "" +} + +func HashAdminPassword(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + sum := sha256.Sum256([]byte(raw)) + return "sha256:" + hex.EncodeToString(sum[:]) +} + +func verifyAdminPasswordHash(candidate, encoded string) bool { + encoded = strings.TrimSpace(strings.ToLower(encoded)) + if encoded == "" { + return false + } + if strings.HasPrefix(encoded, "sha256:") { + want := strings.TrimPrefix(encoded, "sha256:") + sum := sha256.Sum256([]byte(candidate)) + got := hex.EncodeToString(sum[:]) + return subtle.ConstantTimeCompare([]byte(got), []byte(want)) == 1 + } + return subtle.ConstantTimeCompare([]byte(candidate), []byte(encoded)) == 1 +} + +func signHS256(msg string, store AdminConfigReader) []byte { + h := hmac.New(sha256.New, []byte(jwtSecret(store))) _, _ = h.Write([]byte(msg)) return h.Sum(nil) } diff --git a/internal/auth/admin_test.go b/internal/auth/admin_test.go index 7489074..bfbd4c3 100644 --- a/internal/auth/admin_test.go +++ b/internal/auth/admin_test.go @@ -3,6 +3,8 @@ package auth import ( "net/http" "testing" + + "ds2api/internal/config" ) func TestJWTCreateVerify(t *testing.T) { @@ -27,3 +29,58 @@ func TestVerifyAdminRequest(t *testing.T) { t.Fatalf("expected token accepted: %v", err) } } + +func TestVerifyJWTWithStoreValidAfter(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"admin":{"password_hash":"`+HashAdminPassword("oldpass")+`"}}`) + store := config.LoadStore() + token, err := CreateJWTWithStore(1, store) + if err != nil { + t.Fatalf("create jwt failed: %v", err) + } + if _, err := VerifyJWTWithStore(token, store); err != nil { + t.Fatalf("verify before invalidation failed: %v", err) + } + if err := store.Update(func(c *config.Config) error { + c.Admin.JWTValidAfterUnix = 1<<62 - 1 + return nil + }); err != nil { + t.Fatalf("set valid-after failed: %v", err) + } + if _, err := VerifyJWTWithStore(token, store); err == nil { + t.Fatal("expected token invalid after valid-after update") + } +} + +func TestVerifyJWTWithStoreSameSecondInvalidationAndRelogin(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"admin":{"password_hash":"`+HashAdminPassword("oldpass")+`"}}`) + store := config.LoadStore() + + oldToken, err := CreateJWTWithStore(1, store) + if err != nil { + t.Fatalf("create old jwt failed: %v", err) + } + oldPayload, err := VerifyJWTWithStore(oldToken, store) + if err != nil { + t.Fatalf("verify old jwt before invalidation failed: %v", err) + } + oldIAT, _ := oldPayload["iat"].(float64) + + if err := store.Update(func(c *config.Config) error { + c.Admin.JWTValidAfterUnix = int64(oldIAT) + return nil + }); err != nil { + t.Fatalf("set valid-after failed: %v", err) + } + + if _, err := VerifyJWTWithStore(oldToken, store); err == nil { + t.Fatal("expected old token invalid when iat == valid-after") + } + + newToken, err := CreateJWTWithStore(1, store) + if err != nil { + t.Fatalf("create new jwt failed: %v", err) + } + if _, err := VerifyJWTWithStore(newToken, store); err != nil { + t.Fatalf("expected new token valid after invalidation cutoff: %v", err) + } +} diff --git a/internal/claudeconv/convert.go b/internal/claudeconv/convert.go new file mode 100644 index 0000000..1ce1f01 --- /dev/null +++ b/internal/claudeconv/convert.go @@ -0,0 +1,48 @@ +package claudeconv + +import "strings" + +type ClaudeMappingProvider interface { + ClaudeMapping() map[string]string +} + +func ConvertClaudeToDeepSeek(claudeReq map[string]any, mappingProvider ClaudeMappingProvider, defaultClaudeModel string) map[string]any { + messages, _ := claudeReq["messages"].([]any) + model, _ := claudeReq["model"].(string) + if model == "" { + model = defaultClaudeModel + } + + mapping := map[string]string{} + if mappingProvider != nil { + mapping = mappingProvider.ClaudeMapping() + } + dsModel := mapping["fast"] + if dsModel == "" { + dsModel = "deepseek-chat" + } + + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "opus") || strings.Contains(modelLower, "reasoner") || strings.Contains(modelLower, "slow") { + if slow := mapping["slow"]; slow != "" { + dsModel = slow + } + } + + convertedMessages := make([]any, 0, len(messages)+1) + if system, ok := claudeReq["system"].(string); ok && system != "" { + convertedMessages = append(convertedMessages, map[string]any{"role": "system", "content": system}) + } + convertedMessages = append(convertedMessages, messages...) + + out := map[string]any{"model": dsModel, "messages": convertedMessages} + for _, k := range []string{"temperature", "top_p", "stream"} { + if v, ok := claudeReq[k]; ok { + out[k] = v + } + } + if stopSeq, ok := claudeReq["stop_sequences"]; ok { + out["stop"] = stopSeq + } + return out +} diff --git a/internal/compat/go_compat_test.go b/internal/compat/go_compat_test.go new file mode 100644 index 0000000..024e7ba --- /dev/null +++ b/internal/compat/go_compat_test.go @@ -0,0 +1,142 @@ +package compat + +import ( + "encoding/json" + "os" + "path/filepath" + "reflect" + "testing" + + "ds2api/internal/sse" + "ds2api/internal/util" +) + +func TestGoCompatSSEFixtures(t *testing.T) { + files, err := filepath.Glob(compatPath("fixtures", "sse_chunks", "*.json")) + if err != nil { + t.Fatalf("glob fixtures failed: %v", err) + } + if len(files) == 0 { + t.Fatal("no sse fixtures found") + } + for _, fixturePath := range files { + name := trimExt(filepath.Base(fixturePath)) + expectedPath := compatPath("expected", "sse_"+name+".json") + + var fixture struct { + Chunk map[string]any `json:"chunk"` + ThinkingEnable bool `json:"thinking_enabled"` + CurrentType string `json:"current_type"` + } + mustLoadJSON(t, fixturePath, &fixture) + + var expected struct { + Parts []map[string]any `json:"parts"` + Finished bool `json:"finished"` + NewType string `json:"new_type"` + } + mustLoadJSON(t, expectedPath, &expected) + + parts, finished, newType := sse.ParseSSEChunkForContent(fixture.Chunk, fixture.ThinkingEnable, fixture.CurrentType) + gotParts := make([]map[string]any, 0, len(parts)) + for _, p := range parts { + gotParts = append(gotParts, map[string]any{ + "text": p.Text, + "type": p.Type, + }) + } + if !reflect.DeepEqual(gotParts, expected.Parts) || finished != expected.Finished || newType != expected.NewType { + t.Fatalf("fixture %s mismatch:\n got parts=%#v finished=%v newType=%q\nwant parts=%#v finished=%v newType=%q", + name, gotParts, finished, newType, expected.Parts, expected.Finished, expected.NewType) + } + } +} + +func TestGoCompatToolcallFixtures(t *testing.T) { + files, err := filepath.Glob(compatPath("fixtures", "toolcalls", "*.json")) + if err != nil { + t.Fatalf("glob toolcall fixtures failed: %v", err) + } + if len(files) == 0 { + t.Fatal("no toolcall fixtures found") + } + for _, fixturePath := range files { + name := trimExt(filepath.Base(fixturePath)) + expectedPath := compatPath("expected", "toolcalls_"+name+".json") + + var fixture struct { + Text string `json:"text"` + ToolNames []string `json:"tool_names"` + } + mustLoadJSON(t, fixturePath, &fixture) + + var expected struct { + Calls []util.ParsedToolCall `json:"calls"` + } + mustLoadJSON(t, expectedPath, &expected) + + got := util.ParseToolCalls(fixture.Text, fixture.ToolNames) + if len(got) == 0 && len(expected.Calls) == 0 { + continue + } + if !reflect.DeepEqual(got, expected.Calls) { + t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected.Calls) + } + } +} + +func TestGoCompatTokenFixtures(t *testing.T) { + var fixture struct { + Cases []struct { + Name string `json:"name"` + Text string `json:"text"` + } `json:"cases"` + } + mustLoadJSON(t, compatPath("fixtures", "token_cases.json"), &fixture) + + var expected struct { + Cases []struct { + Name string `json:"name"` + Tokens int `json:"tokens"` + } `json:"cases"` + } + mustLoadJSON(t, compatPath("expected", "token_cases.json"), &expected) + + expectByName := map[string]int{} + for _, c := range expected.Cases { + expectByName[c.Name] = c.Tokens + } + for _, c := range fixture.Cases { + want, ok := expectByName[c.Name] + if !ok { + t.Fatalf("missing expected token case: %s", c.Name) + } + got := util.EstimateTokens(c.Text) + if got != want { + t.Fatalf("token fixture %s mismatch: got=%d want=%d", c.Name, got, want) + } + } +} + +func mustLoadJSON(t *testing.T, path string, out any) { + t.Helper() + b, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read %s failed: %v", path, err) + } + if err := json.Unmarshal(b, out); err != nil { + t.Fatalf("decode %s failed: %v", path, err) + } +} + +func trimExt(name string) string { + if len(name) > 5 && name[len(name)-5:] == ".json" { + return name[:len(name)-5] + } + return name +} + +func compatPath(parts ...string) string { + prefix := []string{"..", "..", "tests", "compat"} + return filepath.Join(append(prefix, parts...)...) +} diff --git a/internal/config/config.go b/internal/config/config.go index d391462..3bc0409 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,6 +11,7 @@ import ( "os" "path/filepath" "slices" + "strconv" "strings" "sync" ) @@ -63,6 +64,8 @@ type Config struct { ClaudeMapping map[string]string `json:"claude_mapping,omitempty"` ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"` ModelAliases map[string]string `json:"model_aliases,omitempty"` + Admin AdminConfig `json:"admin,omitempty"` + Runtime RuntimeConfig `json:"runtime,omitempty"` Compat CompatConfig `json:"compat,omitempty"` Toolcall ToolcallConfig `json:"toolcall,omitempty"` Responses ResponsesConfig `json:"responses,omitempty"` @@ -76,6 +79,18 @@ type CompatConfig struct { WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"` } +type AdminConfig struct { + PasswordHash string `json:"password_hash,omitempty"` + JWTExpireHours int `json:"jwt_expire_hours,omitempty"` + JWTValidAfterUnix int64 `json:"jwt_valid_after_unix,omitempty"` +} + +type RuntimeConfig struct { + AccountMaxInflight int `json:"account_max_inflight,omitempty"` + AccountMaxQueue int `json:"account_max_queue,omitempty"` + GlobalMaxInflight int `json:"global_max_inflight,omitempty"` +} + type ToolcallConfig struct { Mode string `json:"mode,omitempty"` EarlyEmitConfidence string `json:"early_emit_confidence,omitempty"` @@ -109,6 +124,12 @@ func (c Config) MarshalJSON() ([]byte, error) { if len(c.ModelAliases) > 0 { m["model_aliases"] = c.ModelAliases } + if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 { + m["admin"] = c.Admin + } + if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 { + m["runtime"] = c.Runtime + } if c.Compat.WideInputStrictOutput != nil { m["compat"] = c.Compat } @@ -158,6 +179,14 @@ func (c *Config) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(v, &c.ModelAliases); err != nil { return fmt.Errorf("invalid field %q: %w", k, err) } + case "admin": + if err := json.Unmarshal(v, &c.Admin); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "runtime": + if err := json.Unmarshal(v, &c.Runtime); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "compat": if err := json.Unmarshal(v, &c.Compat); err != nil { return fmt.Errorf("invalid field %q: %w", k, err) @@ -199,6 +228,8 @@ func (c Config) Clone() Config { ClaudeMapping: cloneStringMap(c.ClaudeMapping), ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), ModelAliases: cloneStringMap(c.ModelAliases), + Admin: c.Admin, + Runtime: c.Runtime, Compat: CompatConfig{ WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput), }, @@ -621,3 +652,92 @@ func (s *Store) EmbeddingsProvider() string { defer s.mu.RUnlock() return strings.TrimSpace(s.cfg.Embeddings.Provider) } + +func (s *Store) AdminPasswordHash() string { + s.mu.RLock() + defer s.mu.RUnlock() + return strings.TrimSpace(s.cfg.Admin.PasswordHash) +} + +func (s *Store) AdminJWTExpireHours() int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Admin.JWTExpireHours > 0 { + return s.cfg.Admin.JWTExpireHours + } + if raw := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); raw != "" { + if n, err := strconv.Atoi(raw); err == nil && n > 0 { + return n + } + } + return 24 +} + +func (s *Store) AdminJWTValidAfterUnix() int64 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.cfg.Admin.JWTValidAfterUnix +} + +func (s *Store) RuntimeAccountMaxInflight() int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Runtime.AccountMaxInflight > 0 { + return s.cfg.Runtime.AccountMaxInflight + } + for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n > 0 { + return n + } + } + return 2 +} + +func (s *Store) RuntimeAccountMaxQueue(defaultSize int) int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Runtime.AccountMaxQueue > 0 { + return s.cfg.Runtime.AccountMaxQueue + } + for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n >= 0 { + return n + } + } + if defaultSize < 0 { + return 0 + } + return defaultSize +} + +func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Runtime.GlobalMaxInflight > 0 { + return s.cfg.Runtime.GlobalMaxInflight + } + for _, key := range []string{"DS2API_GLOBAL_MAX_INFLIGHT", "DS2API_MAX_INFLIGHT"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n > 0 { + return n + } + } + if defaultSize < 0 { + return 0 + } + return defaultSize +} diff --git a/internal/config/models.go b/internal/config/models.go index 017a2ee..a2ec899 100644 --- a/internal/config/models.go +++ b/internal/config/models.go @@ -10,6 +10,10 @@ type ModelInfo struct { Permission []any `json:"permission,omitempty"` } +type ModelAliasReader interface { + ModelAliases() map[string]string +} + var DeepSeekModels = []ModelInfo{ {ID: "deepseek-chat", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}}, {ID: "deepseek-reasoner", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}}, @@ -104,7 +108,7 @@ func DefaultModelAliases() map[string]string { } } -func ResolveModel(store *Store, requested string) (string, bool) { +func ResolveModel(store ModelAliasReader, requested string) (string, bool) { model := lower(strings.TrimSpace(requested)) if model == "" { return "", false @@ -172,7 +176,7 @@ func OpenAIModelsResponse() map[string]any { return map[string]any{"object": "list", "data": DeepSeekModels} } -func OpenAIModelByID(store *Store, id string) (ModelInfo, bool) { +func OpenAIModelByID(store ModelAliasReader, id string) (ModelInfo, bool) { canonical, ok := ResolveModel(store, id) if !ok { return ModelInfo{}, false diff --git a/internal/deepseek/constants.go b/internal/deepseek/constants.go index 1e7d25f..042ec29 100644 --- a/internal/deepseek/constants.go +++ b/internal/deepseek/constants.go @@ -1,5 +1,10 @@ package deepseek +import ( + _ "embed" + "encoding/json" +) + const ( DeepSeekHost = "chat.deepseek.com" DeepSeekLoginURL = "https://chat.deepseek.com/api/v0/users/login" @@ -8,7 +13,7 @@ const ( DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion" ) -var BaseHeaders = map[string]string{ +var defaultBaseHeaders = map[string]string{ "Host": "chat.deepseek.com", "User-Agent": "DeepSeek/1.6.11 Android/35", "Accept": "application/json", @@ -19,6 +24,75 @@ var BaseHeaders = map[string]string{ "accept-charset": "UTF-8", } +var defaultSkipContainsPatterns = []string{ + "quasi_status", + "elapsed_secs", + "token_usage", + "pending_fragment", + "conversation_mode", + "fragments/-1/status", + "fragments/-2/status", + "fragments/-3/status", +} + +var defaultSkipExactPaths = []string{ + "response/search_status", +} + +var BaseHeaders = cloneStringMap(defaultBaseHeaders) +var SkipContainsPatterns = cloneStringSlice(defaultSkipContainsPatterns) +var SkipExactPathSet = toStringSet(defaultSkipExactPaths) + +type sharedConstants struct { + BaseHeaders map[string]string `json:"base_headers"` + SkipContainsPattern []string `json:"skip_contains_patterns"` + SkipExactPaths []string `json:"skip_exact_paths"` +} + +//go:embed constants_shared.json +var sharedConstantsJSON []byte + +func init() { + cfg := sharedConstants{} + if err := json.Unmarshal(sharedConstantsJSON, &cfg); err != nil { + return + } + if len(cfg.BaseHeaders) > 0 { + BaseHeaders = cloneStringMap(cfg.BaseHeaders) + } + if len(cfg.SkipContainsPattern) > 0 { + SkipContainsPatterns = cloneStringSlice(cfg.SkipContainsPattern) + } + if len(cfg.SkipExactPaths) > 0 { + SkipExactPathSet = toStringSet(cfg.SkipExactPaths) + } +} + +func cloneStringMap(in map[string]string) map[string]string { + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneStringSlice(in []string) []string { + out := make([]string, len(in)) + copy(out, in) + return out +} + +func toStringSet(in []string) map[string]struct{} { + out := make(map[string]struct{}, len(in)) + for _, v := range in { + if v == "" { + continue + } + out[v] = struct{}{} + } + return out +} + const ( KeepAliveTimeout = 5 StreamIdleTimeout = 30 diff --git a/internal/deepseek/constants_shared.json b/internal/deepseek/constants_shared.json new file mode 100644 index 0000000..a71ca02 --- /dev/null +++ b/internal/deepseek/constants_shared.json @@ -0,0 +1,25 @@ +{ + "base_headers": { + "Host": "chat.deepseek.com", + "User-Agent": "DeepSeek/1.6.11 Android/35", + "Accept": "application/json", + "Content-Type": "application/json", + "x-client-platform": "android", + "x-client-version": "1.6.11", + "x-client-locale": "zh_CN", + "accept-charset": "UTF-8" + }, + "skip_contains_patterns": [ + "quasi_status", + "elapsed_secs", + "token_usage", + "pending_fragment", + "conversation_mode", + "fragments/-1/status", + "fragments/-2/status", + "fragments/-3/status" + ], + "skip_exact_paths": [ + "response/search_status" + ] +} diff --git a/internal/deepseek/constants_test.go b/internal/deepseek/constants_test.go new file mode 100644 index 0000000..03c6788 --- /dev/null +++ b/internal/deepseek/constants_test.go @@ -0,0 +1,15 @@ +package deepseek + +import "testing" + +func TestSharedConstantsLoaded(t *testing.T) { + if BaseHeaders["x-client-platform"] != "android" { + t.Fatalf("unexpected base header x-client-platform=%q", BaseHeaders["x-client-platform"]) + } + if len(SkipContainsPatterns) == 0 { + t.Fatal("expected skip contains patterns to be loaded") + } + if _, ok := SkipExactPathSet["response/search_status"]; !ok { + t.Fatal("expected response/search_status in exact skip path set") + } +} diff --git a/internal/deepseek/prompt.go b/internal/deepseek/prompt.go new file mode 100644 index 0000000..2410390 --- /dev/null +++ b/internal/deepseek/prompt.go @@ -0,0 +1,7 @@ +package deepseek + +import "ds2api/internal/prompt" + +func MessagesPrepare(messages []map[string]any) string { + return prompt.MessagesPrepare(messages) +} diff --git a/internal/format/claude/render.go b/internal/format/claude/render.go new file mode 100644 index 0000000..fdba055 --- /dev/null +++ b/internal/format/claude/render.go @@ -0,0 +1,46 @@ +package claude + +import ( + "fmt" + "time" + + "ds2api/internal/util" +) + +func BuildMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + content := make([]map[string]any, 0, 4) + if finalThinking != "" { + content = append(content, map[string]any{"type": "thinking", "thinking": finalThinking}) + } + stopReason := "end_turn" + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + content = append(content, map[string]any{ + "type": "tool_use", + "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), + "name": tc.Name, + "input": tc.Input, + }) + } + } else { + if finalText == "" { + finalText = "抱歉,没有生成有效的响应内容。" + } + content = append(content, map[string]any{"type": "text", "text": finalText}) + } + return map[string]any{ + "id": messageID, + "type": "message", + "role": "assistant", + "model": model, + "content": content, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalizedMessages)), + "output_tokens": util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText), + }, + } +} diff --git a/internal/format/openai/render.go b/internal/format/openai/render.go new file mode 100644 index 0000000..fc7473f --- /dev/null +++ b/internal/format/openai/render.go @@ -0,0 +1,193 @@ +package openai + +import ( + "strings" + "time" + + "github.com/google/uuid" + + "ds2api/internal/util" +) + +func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + finishReason := "stop" + messageObj := map[string]any{"role": "assistant", "content": finalText} + if strings.TrimSpace(finalThinking) != "" { + messageObj["reasoning_content"] = finalThinking + } + if len(detected) > 0 { + finishReason = "tool_calls" + messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected) + messageObj["content"] = nil + } + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + + return map[string]any{ + "id": completionID, + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}}, + "usage": map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + "completion_tokens_details": map[string]any{ + "reasoning_tokens": reasoningTokens, + }, + }, + } +} + +func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + exposedOutputText := finalText + output := make([]any, 0, 2) + if len(detected) > 0 { + exposedOutputText = "" + toolCalls := make([]any, 0, len(detected)) + for _, tc := range detected { + toolCalls = append(toolCalls, map[string]any{ + "type": "tool_call", + "name": tc.Name, + "arguments": tc.Input, + }) + } + output = append(output, map[string]any{ + "type": "tool_calls", + "tool_calls": toolCalls, + }) + } else { + content := []any{ + map[string]any{ + "type": "output_text", + "text": finalText, + }, + } + if finalThinking != "" { + content = append([]any{map[string]any{ + "type": "reasoning", + "text": finalThinking, + }}, content...) + } + output = append(output, map[string]any{ + "type": "message", + "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "role": "assistant", + "content": content, + }) + } + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "id": responseID, + "type": "response", + "object": "response", + "created_at": time.Now().Unix(), + "status": "completed", + "model": model, + "output": output, + "output_text": exposedOutputText, + "usage": map[string]any{ + "input_tokens": promptTokens, + "output_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + }, + } +} + +func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any { + return map[string]any{ + "delta": delta, + "index": index, + } +} + +func BuildChatStreamFinishChoice(index int, finishReason string) map[string]any { + return map[string]any{ + "delta": map[string]any{}, + "index": index, + "finish_reason": finishReason, + } +} + +func BuildChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any { + out := map[string]any{ + "id": completionID, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": choices, + } + if len(usage) > 0 { + out["usage"] = usage + } + return out +} + +func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + "completion_tokens_details": map[string]any{ + "reasoning_tokens": reasoningTokens, + }, + } +} + +func BuildResponsesCreatedPayload(responseID, model string) map[string]any { + return map[string]any{ + "type": "response.created", + "id": responseID, + "object": "response", + "model": model, + "status": "in_progress", + } +} + +func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any { + return map[string]any{ + "type": "response.output_text.delta", + "id": responseID, + "delta": delta, + } +} + +func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { + return map[string]any{ + "type": "response.reasoning.delta", + "id": responseID, + "delta": delta, + } +} + +func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_tool_call.delta", + "id": responseID, + "tool_calls": toolCalls, + } +} + +func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_tool_call.done", + "id": responseID, + "tool_calls": toolCalls, + } +} + +func BuildResponsesCompletedPayload(response map[string]any) map[string]any { + return map[string]any{ + "type": "response.completed", + "response": response, + } +} diff --git a/internal/prompt/messages.go b/internal/prompt/messages.go new file mode 100644 index 0000000..69cfe5a --- /dev/null +++ b/internal/prompt/messages.go @@ -0,0 +1,84 @@ +package prompt + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" +) + +var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) + +func MessagesPrepare(messages []map[string]any) string { + type block struct { + Role string + Text string + } + processed := make([]block, 0, len(messages)) + for _, m := range messages { + role, _ := m["role"].(string) + text := NormalizeContent(m["content"]) + processed = append(processed, block{Role: role, Text: text}) + } + if len(processed) == 0 { + return "" + } + merged := make([]block, 0, len(processed)) + for _, msg := range processed { + if len(merged) > 0 && merged[len(merged)-1].Role == msg.Role { + merged[len(merged)-1].Text += "\n\n" + msg.Text + continue + } + merged = append(merged, msg) + } + parts := make([]string, 0, len(merged)) + for i, m := range merged { + switch m.Role { + case "assistant": + parts = append(parts, "<|Assistant|>"+m.Text+"<|end▁of▁sentence|>") + case "user", "system": + if i > 0 { + parts = append(parts, "<|User|>"+m.Text) + } else { + parts = append(parts, m.Text) + } + default: + parts = append(parts, m.Text) + } + } + out := strings.Join(parts, "") + return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`) +} + +func NormalizeContent(v any) string { + switch x := v.(type) { + case string: + return x + case []any: + parts := make([]string, 0, len(x)) + for _, item := range x { + m, ok := item.(map[string]any) + if !ok { + continue + } + typeStr, _ := m["type"].(string) + typeStr = strings.ToLower(strings.TrimSpace(typeStr)) + if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" { + if txt, ok := m["text"].(string); ok { + parts = append(parts, txt) + continue + } + if txt, ok := m["content"].(string); ok { + parts = append(parts, txt) + } + } + } + return strings.Join(parts, "\n") + default: + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%v", v) + } + return string(b) + } +} diff --git a/internal/sse/parser.go b/internal/sse/parser.go index 38429d9..c20bc79 100644 --- a/internal/sse/parser.go +++ b/internal/sse/parser.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/json" "strings" + + "ds2api/internal/deepseek" ) type ContentPart struct { @@ -11,11 +13,6 @@ type ContentPart struct { Type string } -var skipPatterns = []string{ - "quasi_status", "elapsed_secs", "token_usage", "pending_fragment", "conversation_mode", - "fragments/-1/status", "fragments/-2/status", "fragments/-3/status", -} - func ParseDeepSeekSSELine(raw []byte) (map[string]any, bool, bool) { line := strings.TrimSpace(string(raw)) if line == "" || !strings.HasPrefix(line, "data:") { @@ -33,10 +30,10 @@ func ParseDeepSeekSSELine(raw []byte) (map[string]any, bool, bool) { } func shouldSkipPath(path string) bool { - if path == "response/search_status" { + if _, ok := deepseek.SkipExactPathSet[path]; ok { return true } - for _, p := range skipPatterns { + for _, p := range deepseek.SkipContainsPatterns { if strings.Contains(path, p) { return true } @@ -60,126 +57,159 @@ func ParseSSEChunkForContent(chunk map[string]any, thinkingEnabled bool, current } newType := currentFragmentType parts := make([]ContentPart, 0, 8) + collectDirectFragments(path, chunk, v, &newType, &parts) + updateTypeFromNestedResponse(path, v, &newType) + partType := resolvePartType(path, thinkingEnabled, newType) + finished := appendChunkValueContent(v, partType, &newType, &parts, path) + if finished { + return nil, true, newType + } + return parts, false, newType +} - // Newer DeepSeek responses may emit fragment APPEND directly on - // path "response/fragments" instead of wrapping it in path "response". - if path == "response/fragments" { - if op, _ := chunk["o"].(string); strings.EqualFold(op, "APPEND") { - if frags, ok := v.([]any); ok { - for _, frag := range frags { - fm, ok := frag.(map[string]any) - if !ok { - continue - } - t, _ := fm["type"].(string) - content, _ := fm["content"].(string) - t = strings.ToUpper(t) - switch t { - case "THINK", "THINKING": - newType = "thinking" - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "thinking"}) - } - case "RESPONSE": - newType = "text" - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "text"}) - } - default: - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "text"}) - } - } - } +func collectDirectFragments(path string, chunk map[string]any, v any, newType *string, parts *[]ContentPart) { + if path != "response/fragments" { + return + } + op, _ := chunk["o"].(string) + if !strings.EqualFold(op, "APPEND") { + return + } + frags, ok := v.([]any) + if !ok { + return + } + for _, frag := range frags { + m, ok := frag.(map[string]any) + if !ok { + continue + } + typeName, content, fragType := parseFragmentTypeContent(m) + if typeName == "" { + typeName = fragType + } + switch typeName { + case "THINK", "THINKING": + *newType = "thinking" + appendContentPart(parts, content, "thinking") + case "RESPONSE": + *newType = "text" + appendContentPart(parts, content, "text") + default: + appendContentPart(parts, content, "text") + } + } +} + +func updateTypeFromNestedResponse(path string, v any, newType *string) { + if path != "response" { + return + } + arr, ok := v.([]any) + if !ok { + return + } + for _, it := range arr { + m, ok := it.(map[string]any) + if !ok || m["p"] != "fragments" || m["o"] != "APPEND" { + continue + } + frags, ok := m["v"].([]any) + if !ok { + continue + } + for _, frag := range frags { + fm, ok := frag.(map[string]any) + if !ok { + continue + } + typeName, _, _ := parseFragmentTypeContent(fm) + switch typeName { + case "THINK", "THINKING": + *newType = "thinking" + case "RESPONSE": + *newType = "text" } } } +} - if path == "response" { - if arr, ok := v.([]any); ok { - for _, it := range arr { - m, ok := it.(map[string]any) - if !ok { - continue - } - if m["p"] == "fragments" && m["o"] == "APPEND" { - if frags, ok := m["v"].([]any); ok { - for _, frag := range frags { - fm, ok := frag.(map[string]any) - if !ok { - continue - } - t, _ := fm["type"].(string) - t = strings.ToUpper(t) - if t == "THINK" || t == "THINKING" { - newType = "thinking" - } else if t == "RESPONSE" { - newType = "text" - } - } - } - } - } - } - } - partType := "text" +func resolvePartType(path string, thinkingEnabled bool, newType string) string { switch { case path == "response/thinking_content": - partType = "thinking" + return "thinking" case path == "response/content": - partType = "text" + return "text" case strings.Contains(path, "response/fragments") && strings.Contains(path, "/content"): - partType = newType - case path == "": - if thinkingEnabled { - partType = newType - } + return newType + case path == "" && thinkingEnabled: + return newType + default: + return "text" } +} + +func appendChunkValueContent(v any, partType string, newType *string, parts *[]ContentPart, path string) bool { switch val := v.(type) { case string: if val == "FINISHED" && (path == "" || path == "status") { - return nil, true, newType - } - if val != "" { - parts = append(parts, ContentPart{Text: val, Type: partType}) + return true } + appendContentPart(parts, val, partType) case []any: pp, finished := extractContentRecursive(val, partType) if finished { - return nil, true, newType + return true } - parts = append(parts, pp...) + *parts = append(*parts, pp...) case map[string]any: - resp := val - if wrapped, ok := val["response"].(map[string]any); ok { - resp = wrapped + appendWrappedFragments(val, partType, newType, parts) + } + return false +} + +func appendWrappedFragments(val map[string]any, partType string, newType *string, parts *[]ContentPart) { + resp := val + if wrapped, ok := val["response"].(map[string]any); ok { + resp = wrapped + } + frags, ok := resp["fragments"].([]any) + if !ok { + return + } + for _, item := range frags { + m, ok := item.(map[string]any) + if !ok { + continue } - if frags, ok := resp["fragments"].([]any); ok { - for _, item := range frags { - m, ok := item.(map[string]any) - if !ok { - continue - } - t, _ := m["type"].(string) - content, _ := m["content"].(string) - t = strings.ToUpper(t) - if t == "THINK" || t == "THINKING" { - newType = "thinking" - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "thinking"}) - } - } else if t == "RESPONSE" { - newType = "text" - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "text"}) - } - } else if content != "" { - parts = append(parts, ContentPart{Text: content, Type: partType}) - } - } + typeName, content, fragType := parseFragmentTypeContent(m) + if typeName == "" { + typeName = fragType + } + switch typeName { + case "THINK", "THINKING": + *newType = "thinking" + appendContentPart(parts, content, "thinking") + case "RESPONSE": + *newType = "text" + appendContentPart(parts, content, "text") + default: + appendContentPart(parts, content, partType) } } - return parts, false, newType +} + +func parseFragmentTypeContent(m map[string]any) (string, string, string) { + typeName, _ := m["type"].(string) + content, _ := m["content"].(string) + return strings.ToUpper(typeName), content, strings.ToUpper(typeName) +} + +func appendContentPart(parts *[]ContentPart, content, kind string) { + if content == "" { + return + } + *parts = append(*parts, ContentPart{Text: content, Type: kind}) } func extractContentRecursive(items []any, defaultType string) ([]ContentPart, bool) { diff --git a/internal/stream/engine.go b/internal/stream/engine.go new file mode 100644 index 0000000..c63cd7b --- /dev/null +++ b/internal/stream/engine.go @@ -0,0 +1,128 @@ +package stream + +import ( + "context" + "io" + "time" + + "ds2api/internal/sse" +) + +type StopReason string + +const ( + StopReasonNone StopReason = "" + StopReasonContextCancelled StopReason = "context_cancelled" + StopReasonNoContentTimeout StopReason = "no_content_timeout" + StopReasonIdleTimeout StopReason = "idle_timeout" + StopReasonUpstreamCompleted StopReason = "upstream_completed" + StopReasonHandlerRequested StopReason = "handler_requested" +) + +type ConsumeConfig struct { + Context context.Context + Body io.Reader + ThinkingEnabled bool + InitialType string + KeepAliveInterval time.Duration + IdleTimeout time.Duration + MaxKeepAliveNoInput int +} + +type ParsedDecision struct { + Stop bool + StopReason StopReason + ContentSeen bool +} + +type ConsumeHooks struct { + OnParsed func(parsed sse.LineResult) ParsedDecision + OnKeepAlive func() + OnFinalize func(reason StopReason, scannerErr error) + OnContextDone func() +} + +func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) { + if cfg.Context == nil { + cfg.Context = context.Background() + } + initialType := cfg.InitialType + if initialType == "" { + if cfg.ThinkingEnabled { + initialType = "thinking" + } else { + initialType = "text" + } + } + parsedLines, done := sse.StartParsedLinePump(cfg.Context, cfg.Body, cfg.ThinkingEnabled, initialType) + + var ticker *time.Ticker + if cfg.KeepAliveInterval > 0 { + ticker = time.NewTicker(cfg.KeepAliveInterval) + defer ticker.Stop() + } + + hasContent := false + lastContent := time.Now() + keepaliveCount := 0 + + finalize := func(reason StopReason, scannerErr error) { + if hooks.OnFinalize != nil { + hooks.OnFinalize(reason, scannerErr) + } + } + + for { + select { + case <-cfg.Context.Done(): + if hooks.OnContextDone != nil { + hooks.OnContextDone() + } + return + case <-tickCh(ticker): + if !hasContent { + keepaliveCount++ + if cfg.MaxKeepAliveNoInput > 0 && keepaliveCount >= cfg.MaxKeepAliveNoInput { + finalize(StopReasonNoContentTimeout, nil) + return + } + } + if hasContent && cfg.IdleTimeout > 0 && time.Since(lastContent) > cfg.IdleTimeout { + finalize(StopReasonIdleTimeout, nil) + return + } + if hooks.OnKeepAlive != nil { + hooks.OnKeepAlive() + } + case parsed, ok := <-parsedLines: + if !ok { + finalize(StopReasonUpstreamCompleted, <-done) + return + } + if hooks.OnParsed == nil { + continue + } + decision := hooks.OnParsed(parsed) + if decision.ContentSeen { + hasContent = true + lastContent = time.Now() + keepaliveCount = 0 + } + if decision.Stop { + reason := decision.StopReason + if reason == StopReasonNone { + reason = StopReasonHandlerRequested + } + finalize(reason, nil) + return + } + } + } +} + +func tickCh(ticker *time.Ticker) <-chan time.Time { + if ticker == nil { + return nil + } + return ticker.C +} diff --git a/internal/testsuite/runner.go b/internal/testsuite/runner.go index e6ae9a6..33e7580 100644 --- a/internal/testsuite/runner.go +++ b/internal/testsuite/runner.go @@ -327,7 +327,7 @@ func (r *Runner) runPreflight(ctx context.Context) error { {"go", "test", "./...", "-count=1"}, {"node", "--check", "api/chat-stream.js"}, {"node", "--check", "api/helpers/stream-tool-sieve.js"}, - {"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js"}, + {"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js", "api/compat/js_compat_test.js"}, {"npm", "run", "build", "--prefix", "webui"}, } f, err := os.OpenFile(r.preflightLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) diff --git a/internal/util/messages.go b/internal/util/messages.go index fcc9484..b6920c0 100644 --- a/internal/util/messages.go +++ b/internal/util/messages.go @@ -1,16 +1,11 @@ package util import ( - "encoding/json" - "fmt" - "regexp" - "strings" - + "ds2api/internal/claudeconv" "ds2api/internal/config" + "ds2api/internal/prompt" ) -var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) - const ClaudeDefaultModel = "claude-sonnet-4-5" type Message struct { @@ -19,112 +14,15 @@ type Message struct { } func MessagesPrepare(messages []map[string]any) string { - type block struct { - Role string - Text string - } - processed := make([]block, 0, len(messages)) - for _, m := range messages { - role, _ := m["role"].(string) - text := normalizeContent(m["content"]) - processed = append(processed, block{Role: role, Text: text}) - } - if len(processed) == 0 { - return "" - } - merged := make([]block, 0, len(processed)) - for _, msg := range processed { - if len(merged) > 0 && merged[len(merged)-1].Role == msg.Role { - merged[len(merged)-1].Text += "\n\n" + msg.Text - continue - } - merged = append(merged, msg) - } - parts := make([]string, 0, len(merged)) - for i, m := range merged { - switch m.Role { - case "assistant": - parts = append(parts, "<|Assistant|>"+m.Text+"<|end▁of▁sentence|>") - case "user", "system": - if i > 0 { - parts = append(parts, "<|User|>"+m.Text) - } else { - parts = append(parts, m.Text) - } - default: - parts = append(parts, m.Text) - } - } - out := strings.Join(parts, "") - return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`) + return prompt.MessagesPrepare(messages) } func normalizeContent(v any) string { - switch x := v.(type) { - case string: - return x - case []any: - parts := make([]string, 0, len(x)) - for _, item := range x { - m, ok := item.(map[string]any) - if !ok { - continue - } - typeStr, _ := m["type"].(string) - typeStr = strings.ToLower(strings.TrimSpace(typeStr)) - if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" { - if txt, ok := m["text"].(string); ok { - parts = append(parts, txt) - continue - } - if txt, ok := m["content"].(string); ok { - parts = append(parts, txt) - } - } - } - return strings.Join(parts, "\n") - default: - b, err := json.Marshal(v) - if err != nil { - return fmt.Sprintf("%v", v) - } - return string(b) - } + return prompt.NormalizeContent(v) } func ConvertClaudeToDeepSeek(claudeReq map[string]any, store *config.Store) map[string]any { - messages, _ := claudeReq["messages"].([]any) - model, _ := claudeReq["model"].(string) - if model == "" { - model = ClaudeDefaultModel - } - mapping := store.ClaudeMapping() - dsModel := mapping["fast"] - if dsModel == "" { - dsModel = "deepseek-chat" - } - modelLower := strings.ToLower(model) - if strings.Contains(modelLower, "opus") || strings.Contains(modelLower, "reasoner") || strings.Contains(modelLower, "slow") { - if slow := mapping["slow"]; slow != "" { - dsModel = slow - } - } - convertedMessages := make([]any, 0, len(messages)+1) - if system, ok := claudeReq["system"].(string); ok && system != "" { - convertedMessages = append(convertedMessages, map[string]any{"role": "system", "content": system}) - } - convertedMessages = append(convertedMessages, messages...) - - out := map[string]any{"model": dsModel, "messages": convertedMessages} - for _, k := range []string{"temperature", "top_p", "stream"} { - if v, ok := claudeReq[k]; ok { - out[k] = v - } - } - if stopSeq, ok := claudeReq["stop_sequences"]; ok { - out["stop"] = stopSeq - } - return out + return claudeconv.ConvertClaudeToDeepSeek(claudeReq, store, ClaudeDefaultModel) } // EstimateTokens provides a rough token count approximation. diff --git a/internal/util/render.go b/internal/util/render.go index b5e0a79..fff8501 100644 --- a/internal/util/render.go +++ b/internal/util/render.go @@ -8,6 +8,8 @@ import ( "github.com/google/uuid" ) +// BuildOpenAIChatCompletion is kept for backward compatibility. +// Prefer internal/format/openai.BuildChatCompletion for new code. func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { detected := ParseToolCalls(finalText, toolNames) finishReason := "stop" @@ -41,6 +43,8 @@ func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, } } +// BuildOpenAIResponseObject is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponseObject for new code. func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { detected := ParseToolCalls(finalText, toolNames) exposedOutputText := finalText @@ -101,6 +105,8 @@ func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, fi } } +// BuildClaudeMessageResponse is kept for backward compatibility. +// Prefer internal/format/claude.BuildMessageResponse for new code. func BuildClaudeMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any { detected := ParseToolCalls(finalText, toolNames) content := make([]map[string]any, 0, 4) diff --git a/internal/util/render_stream.go b/internal/util/render_stream.go index 716c158..b5699ba 100644 --- a/internal/util/render_stream.go +++ b/internal/util/render_stream.go @@ -1,5 +1,7 @@ package util +// BuildOpenAIChatStreamDeltaChoice is kept for backward compatibility. +// Prefer internal/format/openai.BuildChatStreamDeltaChoice for new code. func BuildOpenAIChatStreamDeltaChoice(index int, delta map[string]any) map[string]any { return map[string]any{ "delta": delta, @@ -7,6 +9,8 @@ func BuildOpenAIChatStreamDeltaChoice(index int, delta map[string]any) map[strin } } +// BuildOpenAIChatStreamFinishChoice is kept for backward compatibility. +// Prefer internal/format/openai.BuildChatStreamFinishChoice for new code. func BuildOpenAIChatStreamFinishChoice(index int, finishReason string) map[string]any { return map[string]any{ "delta": map[string]any{}, @@ -15,6 +19,8 @@ func BuildOpenAIChatStreamFinishChoice(index int, finishReason string) map[strin } } +// BuildOpenAIChatStreamChunk is kept for backward compatibility. +// Prefer internal/format/openai.BuildChatStreamChunk for new code. func BuildOpenAIChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any { out := map[string]any{ "id": completionID, @@ -29,6 +35,8 @@ func BuildOpenAIChatStreamChunk(completionID string, created int64, model string return out } +// BuildOpenAIChatUsage is kept for backward compatibility. +// Prefer internal/format/openai.BuildChatUsage for new code. func BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText string) map[string]any { promptTokens := EstimateTokens(finalPrompt) reasoningTokens := EstimateTokens(finalThinking) @@ -43,6 +51,8 @@ func BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText string) map[stri } } +// BuildOpenAIResponsesCreatedPayload is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponsesCreatedPayload for new code. func BuildOpenAIResponsesCreatedPayload(responseID, model string) map[string]any { return map[string]any{ "type": "response.created", @@ -53,6 +63,8 @@ func BuildOpenAIResponsesCreatedPayload(responseID, model string) map[string]any } } +// BuildOpenAIResponsesTextDeltaPayload is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponsesTextDeltaPayload for new code. func BuildOpenAIResponsesTextDeltaPayload(responseID, delta string) map[string]any { return map[string]any{ "type": "response.output_text.delta", @@ -61,6 +73,8 @@ func BuildOpenAIResponsesTextDeltaPayload(responseID, delta string) map[string]a } } +// BuildOpenAIResponsesReasoningDeltaPayload is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponsesReasoningDeltaPayload for new code. func BuildOpenAIResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { return map[string]any{ "type": "response.reasoning.delta", @@ -69,6 +83,8 @@ func BuildOpenAIResponsesReasoningDeltaPayload(responseID, delta string) map[str } } +// BuildOpenAIResponsesToolCallDeltaPayload is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponsesToolCallDeltaPayload for new code. func BuildOpenAIResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any { return map[string]any{ "type": "response.output_tool_call.delta", @@ -77,6 +93,8 @@ func BuildOpenAIResponsesToolCallDeltaPayload(responseID string, toolCalls []map } } +// BuildOpenAIResponsesToolCallDonePayload is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponsesToolCallDonePayload for new code. func BuildOpenAIResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any { return map[string]any{ "type": "response.output_tool_call.done", @@ -85,6 +103,8 @@ func BuildOpenAIResponsesToolCallDonePayload(responseID string, toolCalls []map[ } } +// BuildOpenAIResponsesCompletedPayload is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponsesCompletedPayload for new code. func BuildOpenAIResponsesCompletedPayload(response map[string]any) map[string]any { return map[string]any{ "type": "response.completed", diff --git a/tests/compat/expected/sse_fragments_append.json b/tests/compat/expected/sse_fragments_append.json new file mode 100644 index 0000000..8647f3a --- /dev/null +++ b/tests/compat/expected/sse_fragments_append.json @@ -0,0 +1,8 @@ +{ + "parts": [ + {"text": "思考中", "type": "thinking"}, + {"text": "结论", "type": "text"} + ], + "finished": false, + "new_type": "text" +} diff --git a/tests/compat/expected/sse_nested_finished.json b/tests/compat/expected/sse_nested_finished.json new file mode 100644 index 0000000..7d588f7 --- /dev/null +++ b/tests/compat/expected/sse_nested_finished.json @@ -0,0 +1,5 @@ +{ + "parts": [], + "finished": true, + "new_type": "text" +} diff --git a/tests/compat/expected/sse_split_tool_json.json b/tests/compat/expected/sse_split_tool_json.json new file mode 100644 index 0000000..2afed2a --- /dev/null +++ b/tests/compat/expected/sse_split_tool_json.json @@ -0,0 +1,8 @@ +{ + "parts": [ + {"text": "{\"", "type": "text"}, + {"text": "tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}", "type": "text"} + ], + "finished": false, + "new_type": "text" +} diff --git a/tests/compat/expected/token_cases.json b/tests/compat/expected/token_cases.json new file mode 100644 index 0000000..69694eb --- /dev/null +++ b/tests/compat/expected/token_cases.json @@ -0,0 +1,7 @@ +{ + "cases": [ + {"name": "ascii_short", "tokens": 1}, + {"name": "cjk", "tokens": 3}, + {"name": "mixed", "tokens": 4} + ] +} diff --git a/tests/compat/expected/toolcalls_fenced_json.json b/tests/compat/expected/toolcalls_fenced_json.json new file mode 100644 index 0000000..97646bf --- /dev/null +++ b/tests/compat/expected/toolcalls_fenced_json.json @@ -0,0 +1,3 @@ +{ + "calls": [] +} diff --git a/tests/compat/expected/toolcalls_unknown_name.json b/tests/compat/expected/toolcalls_unknown_name.json new file mode 100644 index 0000000..8f79875 --- /dev/null +++ b/tests/compat/expected/toolcalls_unknown_name.json @@ -0,0 +1,5 @@ +{ + "calls": [ + {"name": "unknown_tool", "input": {"x": 1}} + ] +} diff --git a/tests/compat/fixtures/sse_chunks/fragments_append.json b/tests/compat/fixtures/sse_chunks/fragments_append.json new file mode 100644 index 0000000..c6f8ae6 --- /dev/null +++ b/tests/compat/fixtures/sse_chunks/fragments_append.json @@ -0,0 +1,12 @@ +{ + "chunk": { + "p": "response/fragments", + "o": "APPEND", + "v": [ + {"type": "THINK", "content": "思考中"}, + {"type": "RESPONSE", "content": "结论"} + ] + }, + "thinking_enabled": true, + "current_type": "thinking" +} diff --git a/tests/compat/fixtures/sse_chunks/nested_finished.json b/tests/compat/fixtures/sse_chunks/nested_finished.json new file mode 100644 index 0000000..da76280 --- /dev/null +++ b/tests/compat/fixtures/sse_chunks/nested_finished.json @@ -0,0 +1,10 @@ +{ + "chunk": { + "p": "response", + "v": [ + {"p": "status", "v": "FINISHED"} + ] + }, + "thinking_enabled": false, + "current_type": "text" +} diff --git a/tests/compat/fixtures/sse_chunks/split_tool_json.json b/tests/compat/fixtures/sse_chunks/split_tool_json.json new file mode 100644 index 0000000..e915fbb --- /dev/null +++ b/tests/compat/fixtures/sse_chunks/split_tool_json.json @@ -0,0 +1,11 @@ +{ + "chunk": { + "p": "response", + "v": [ + {"p": "response/content", "v": "{\""}, + {"p": "response/content", "v": "tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"} + ] + }, + "thinking_enabled": false, + "current_type": "text" +} diff --git a/tests/compat/fixtures/token_cases.json b/tests/compat/fixtures/token_cases.json new file mode 100644 index 0000000..3887356 --- /dev/null +++ b/tests/compat/fixtures/token_cases.json @@ -0,0 +1,7 @@ +{ + "cases": [ + {"name": "ascii_short", "text": "abcd"}, + {"name": "cjk", "text": "你好世界"}, + {"name": "mixed", "text": "Hello 你好世界"} + ] +} diff --git a/tests/compat/fixtures/toolcalls/fenced_json.json b/tests/compat/fixtures/toolcalls/fenced_json.json new file mode 100644 index 0000000..8d75cc1 --- /dev/null +++ b/tests/compat/fixtures/toolcalls/fenced_json.json @@ -0,0 +1,4 @@ +{ + "text": "```json\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}\n```", + "tool_names": ["read_file"] +} diff --git a/tests/compat/fixtures/toolcalls/unknown_name.json b/tests/compat/fixtures/toolcalls/unknown_name.json new file mode 100644 index 0000000..0ba9e76 --- /dev/null +++ b/tests/compat/fixtures/toolcalls/unknown_name.json @@ -0,0 +1,4 @@ +{ + "text": "{\"tool_calls\":[{\"name\":\"unknown_tool\",\"input\":{\"x\":1}}]}", + "tool_names": ["read_file"] +} diff --git a/webui/src/App.jsx b/webui/src/App.jsx index 53d0b4a..3f6ad27 100644 --- a/webui/src/App.jsx +++ b/webui/src/App.jsx @@ -11,6 +11,7 @@ import { Key, Upload, Cloud, + Settings as SettingsIcon, LogOut, Menu, X, @@ -23,12 +24,13 @@ import AccountManager from './components/AccountManager' import ApiTester from './components/ApiTester' import BatchImport from './components/BatchImport' import VercelSync from './components/VercelSync' +import Settings from './components/Settings' import Login from './components/Login' import LandingPage from './components/LandingPage' import LanguageToggle from './components/LanguageToggle' import { useI18n } from './i18n' -function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message }) { +function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout }) { const { t } = useI18n() const [activeTab, setActiveTab] = useState('accounts') const [sidebarOpen, setSidebarOpen] = useState(false) @@ -39,6 +41,7 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message { id: 'test', label: t('nav.test.label'), icon: Server, description: t('nav.test.desc') }, { id: 'import', label: t('nav.import.label'), icon: Upload, description: t('nav.import.desc') }, { id: 'vercel', label: t('nav.vercel.label'), icon: Cloud, description: t('nav.vercel.desc') }, + { id: 'settings', label: t('nav.settings.label'), icon: SettingsIcon, description: t('nav.settings.desc') }, ] const authFetch = async (url, options = {}) => { @@ -65,6 +68,8 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message return case 'vercel': return + case 'settings': + return default: return null } @@ -314,6 +319,7 @@ export default function App() { fetchConfig={fetchConfig} showMessage={showMessage} message={message} + onForceLogout={handleLogout} /> ) : (
diff --git a/webui/src/components/Settings.jsx b/webui/src/components/Settings.jsx new file mode 100644 index 0000000..b257ed5 --- /dev/null +++ b/webui/src/components/Settings.jsx @@ -0,0 +1,376 @@ +import { useCallback, useEffect, useMemo, useState } from 'react' +import { AlertTriangle, Download, Lock, Save, Upload } from 'lucide-react' +import { useI18n } from '../i18n' + +export default function Settings({ onRefresh, onMessage, authFetch, onForceLogout }) { + const { t } = useI18n() + const apiFetch = authFetch || fetch + + const [loading, setLoading] = useState(false) + const [saving, setSaving] = useState(false) + const [changingPassword, setChangingPassword] = useState(false) + const [importing, setImporting] = useState(false) + const [exportData, setExportData] = useState(null) + const [importMode, setImportMode] = useState('merge') + const [importText, setImportText] = useState('') + const [newPassword, setNewPassword] = useState('') + const [settingsMeta, setSettingsMeta] = useState({ default_password_warning: false, env_backed: false, needs_vercel_sync: false }) + + const [form, setForm] = useState({ + admin: { jwt_expire_hours: 24 }, + runtime: { account_max_inflight: 2, account_max_queue: 10, global_max_inflight: 10 }, + toolcall: { mode: 'feature_match', early_emit_confidence: 'high' }, + responses: { store_ttl_seconds: 900 }, + embeddings: { provider: '' }, + claude_mapping_text: '{\n "fast": "deepseek-chat",\n "slow": "deepseek-reasoner"\n}', + model_aliases_text: '{}', + }) + + const parseJSONMap = (raw, fieldName) => { + const text = String(raw || '').trim() + if (!text) { + return {} + } + let parsed + try { + parsed = JSON.parse(text) + } catch (_e) { + throw new Error(t('settings.invalidJsonField', { field: fieldName })) + } + if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) { + throw new Error(t('settings.invalidJsonField', { field: fieldName })) + } + return parsed + } + + const loadSettings = useCallback(async () => { + setLoading(true) + try { + const res = await apiFetch('/admin/settings') + const data = await res.json() + if (!res.ok) { + onMessage('error', data.detail || t('settings.loadFailed')) + return + } + setSettingsMeta({ + default_password_warning: Boolean(data.admin?.default_password_warning), + env_backed: Boolean(data.env_backed), + needs_vercel_sync: Boolean(data.needs_vercel_sync), + }) + setForm({ + admin: { jwt_expire_hours: Number(data.admin?.jwt_expire_hours || 24) }, + runtime: { + account_max_inflight: Number(data.runtime?.account_max_inflight || 2), + account_max_queue: Number(data.runtime?.account_max_queue || 10), + global_max_inflight: Number(data.runtime?.global_max_inflight || 10), + }, + toolcall: { + mode: data.toolcall?.mode || 'feature_match', + early_emit_confidence: data.toolcall?.early_emit_confidence || 'high', + }, + responses: { + store_ttl_seconds: Number(data.responses?.store_ttl_seconds || 900), + }, + embeddings: { + provider: data.embeddings?.provider || '', + }, + claude_mapping_text: JSON.stringify(data.claude_mapping || {}, null, 2), + model_aliases_text: JSON.stringify(data.model_aliases || {}, null, 2), + }) + } catch (e) { + onMessage('error', t('settings.loadFailed')) + // eslint-disable-next-line no-console + console.error(e) + } finally { + setLoading(false) + } + }, [apiFetch, onMessage, t]) + + useEffect(() => { + loadSettings() + }, [loadSettings]) + + const saveSettings = async () => { + let claudeMapping = {} + let modelAliases = {} + try { + claudeMapping = parseJSONMap(form.claude_mapping_text, 'claude_mapping') + modelAliases = parseJSONMap(form.model_aliases_text, 'model_aliases') + } catch (e) { + onMessage('error', e.message) + return + } + + const payload = { + admin: { jwt_expire_hours: Number(form.admin.jwt_expire_hours) }, + runtime: { + account_max_inflight: Number(form.runtime.account_max_inflight), + account_max_queue: Number(form.runtime.account_max_queue), + global_max_inflight: Number(form.runtime.global_max_inflight), + }, + toolcall: { + mode: String(form.toolcall.mode || '').trim(), + early_emit_confidence: String(form.toolcall.early_emit_confidence || '').trim(), + }, + responses: { store_ttl_seconds: Number(form.responses.store_ttl_seconds) }, + embeddings: { provider: String(form.embeddings.provider || '').trim() }, + claude_mapping: claudeMapping, + model_aliases: modelAliases, + } + + setSaving(true) + try { + const res = await apiFetch('/admin/settings', { + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(payload), + }) + const data = await res.json() + if (!res.ok) { + onMessage('error', data.detail || t('settings.saveFailed')) + return + } + onMessage('success', t('settings.saveSuccess')) + if (typeof onRefresh === 'function') { + onRefresh() + } + await loadSettings() + } catch (e) { + onMessage('error', t('settings.saveFailed')) + // eslint-disable-next-line no-console + console.error(e) + } finally { + setSaving(false) + } + } + + const updatePassword = async () => { + if (String(newPassword || '').trim().length < 4) { + onMessage('error', t('settings.passwordTooShort')) + return + } + setChangingPassword(true) + try { + const res = await apiFetch('/admin/settings/password', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ new_password: newPassword.trim() }), + }) + const data = await res.json() + if (!res.ok) { + onMessage('error', data.detail || t('settings.passwordUpdateFailed')) + return + } + onMessage('success', t('settings.passwordUpdated')) + setNewPassword('') + if (typeof onForceLogout === 'function') { + onForceLogout() + } + } catch (e) { + onMessage('error', t('settings.passwordUpdateFailed')) + } finally { + setChangingPassword(false) + } + } + + const loadExportData = async () => { + try { + const res = await apiFetch('/admin/config/export') + const data = await res.json() + if (!res.ok) { + onMessage('error', data.detail || t('settings.exportFailed')) + return + } + setExportData(data) + onMessage('success', t('settings.exportLoaded')) + } catch (e) { + onMessage('error', t('settings.exportFailed')) + } + } + + const doImport = async () => { + if (!String(importText || '').trim()) { + onMessage('error', t('settings.importEmpty')) + return + } + let parsed + try { + parsed = JSON.parse(importText) + } catch (_e) { + onMessage('error', t('settings.importInvalidJson')) + return + } + setImporting(true) + try { + const res = await apiFetch(`/admin/config/import?mode=${encodeURIComponent(importMode)}`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ config: parsed, mode: importMode }), + }) + const data = await res.json() + if (!res.ok) { + onMessage('error', data.detail || t('settings.importFailed')) + return + } + onMessage('success', t('settings.importSuccess', { mode: importMode })) + if (typeof onRefresh === 'function') { + onRefresh() + } + await loadSettings() + } catch (e) { + onMessage('error', t('settings.importFailed')) + } finally { + setImporting(false) + } + } + + const syncHintVisible = useMemo(() => settingsMeta.env_backed || settingsMeta.needs_vercel_sync, [settingsMeta.env_backed, settingsMeta.needs_vercel_sync]) + + return ( +
+ {settingsMeta.default_password_warning && ( +
+ + {t('settings.defaultPasswordWarning')} +
+ )} + {syncHintVisible && ( +
+ + {t('settings.vercelSyncHint')} +
+ )} + +
+

{t('settings.securityTitle')}

+
+ + +
+
+ +
+

{t('settings.runtimeTitle')}

+
+ + + +
+
+ +
+

{t('settings.behaviorTitle')}

+
+ + + + +
+
+ +
+

{t('settings.modelTitle')}

+
+