From 416b9939fc0de6f55dc06e60319f26b53cf8aaad Mon Sep 17 00:00:00 2001 From: CJACK Date: Tue, 17 Feb 2026 01:35:10 +0800 Subject: [PATCH] Refactor admin handlers into specialized files and introduce OpenAI tool sieving and Vercel streaming capabilities. --- api/chat-stream.js | 440 +----------- api/helpers/stream-tool-sieve.js | 456 ++++++++++++ internal/adapter/openai/handler.go | 516 -------------- internal/adapter/openai/tool_sieve.go | 236 +++++++ internal/adapter/openai/vercel_stream.go | 277 ++++++++ internal/admin/handler.go | 864 ----------------------- internal/admin/handler_accounts.go | 310 ++++++++ internal/admin/handler_auth.go | 69 ++ internal/admin/handler_config.go | 240 +++++++ internal/admin/handler_vercel.go | 197 ++++++ internal/admin/helpers.go | 98 +++ 11 files changed, 1890 insertions(+), 1813 deletions(-) create mode 100644 api/helpers/stream-tool-sieve.js create mode 100644 internal/adapter/openai/tool_sieve.go create mode 100644 internal/adapter/openai/vercel_stream.go create mode 100644 internal/admin/handler_accounts.go create mode 100644 internal/admin/handler_auth.go create mode 100644 internal/admin/handler_config.go create mode 100644 internal/admin/handler_vercel.go create mode 100644 internal/admin/helpers.go diff --git a/api/chat-stream.js b/api/chat-stream.js index 93cb05f..f69300e 100644 --- a/api/chat-stream.js +++ b/api/chat-stream.js @@ -1,6 +1,12 @@ 'use strict'; -const crypto = require('crypto'); +const { + extractToolNames, + createToolSieveState, + processToolSieveChunk, + flushToolSieve, + formatOpenAIStreamToolCalls, +} = require('./helpers/stream-tool-sieve'); const DEEPSEEK_COMPLETION_URL = 'https://chat.deepseek.com/api/v0/chat/completion'; @@ -675,435 +681,3 @@ function asString(v) { } return String(v).trim(); } - -function extractToolNames(tools) { - if (!Array.isArray(tools) || tools.length === 0) { - return []; - } - const out = []; - for (const t of tools) { - if (!t || typeof t !== 'object') { - continue; - } - 
const fn = t.function && typeof t.function === 'object' ? t.function : t; - const name = asString(fn.name); - if (name) { - out.push(name); - } - } - return out; -} - -function createToolSieveState() { - return { - pending: '', - capture: '', - capturing: false, - }; -} - -function processToolSieveChunk(state, chunk, toolNames) { - if (!state) { - return []; - } - if (chunk) { - state.pending += chunk; - } - const events = []; - // eslint-disable-next-line no-constant-condition - while (true) { - if (state.capturing) { - if (state.pending) { - state.capture += state.pending; - state.pending = ''; - } - const consumed = consumeToolCapture(state.capture, toolNames); - if (!consumed.ready) { - break; - } - state.capture = ''; - state.capturing = false; - if (consumed.prefix) { - events.push({ type: 'text', text: consumed.prefix }); - } - if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { - events.push({ type: 'tool_calls', calls: consumed.calls }); - } - if (consumed.suffix) { - state.pending += consumed.suffix; - } - continue; - } - - if (!state.pending) { - break; - } - - const start = findToolSegmentStart(state.pending); - if (start >= 0) { - const prefix = state.pending.slice(0, start); - if (prefix) { - events.push({ type: 'text', text: prefix }); - } - state.capture = state.pending.slice(start); - state.pending = ''; - state.capturing = true; - continue; - } - - const [safe, hold] = splitSafeContentForToolDetection(state.pending); - if (!safe) { - break; - } - state.pending = hold; - events.push({ type: 'text', text: safe }); - } - return events; -} - -function flushToolSieve(state, toolNames) { - if (!state) { - return []; - } - const events = processToolSieveChunk(state, '', toolNames); - if (state.capturing) { - const consumed = consumeToolCapture(state.capture, toolNames); - if (consumed.ready) { - if (consumed.prefix) { - events.push({ type: 'text', text: consumed.prefix }); - } - if (Array.isArray(consumed.calls) && consumed.calls.length > 
0) { - events.push({ type: 'tool_calls', calls: consumed.calls }); - } - if (consumed.suffix) { - events.push({ type: 'text', text: consumed.suffix }); - } - } else if (state.capture) { - events.push({ type: 'text', text: state.capture }); - } - state.capture = ''; - state.capturing = false; - } - if (state.pending) { - events.push({ type: 'text', text: state.pending }); - state.pending = ''; - } - return events; -} - -function splitSafeContentForToolDetection(s) { - const text = s || ''; - if (!text) { - return ['', '']; - } - const suspiciousStart = findSuspiciousPrefixStart(text); - if (suspiciousStart < 0) { - return [text, '']; - } - if (suspiciousStart > 0) { - return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)]; - } - const chars = Array.from(text); - const maxHold = 128; - if (chars.length <= maxHold) { - return ['', text]; - } - return [chars.slice(0, chars.length - maxHold).join(''), chars.slice(chars.length - maxHold).join('')]; -} - -function findSuspiciousPrefixStart(s) { - let start = -1; - for (const needle of ['{', '[', '```']) { - const idx = s.lastIndexOf(needle); - if (idx > start) { - start = idx; - } - } - return start; -} - -function findToolSegmentStart(s) { - if (!s) { - return -1; - } - const lower = s.toLowerCase(); - const keyIdx = lower.indexOf('tool_calls'); - if (keyIdx < 0) { - return -1; - } - const start = s.slice(0, keyIdx).lastIndexOf('{'); - return start >= 0 ? 
start : keyIdx; -} - -function consumeToolCapture(captured, toolNames) { - if (!captured) { - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const lower = captured.toLowerCase(); - const keyIdx = lower.indexOf('tool_calls'); - if (keyIdx < 0) { - if (Array.from(captured).length >= 256) { - return { ready: true, prefix: captured, calls: [], suffix: '' }; - } - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const start = captured.slice(0, keyIdx).lastIndexOf('{'); - if (start < 0) { - if (Array.from(captured).length >= 512) { - return { ready: true, prefix: captured, calls: [], suffix: '' }; - } - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const obj = extractJSONObjectFrom(captured, start); - if (!obj.ok) { - if (Array.from(captured).length >= 4096) { - return { ready: true, prefix: captured, calls: [], suffix: '' }; - } - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const parsed = parseToolCalls(captured.slice(start, obj.end), toolNames); - if (parsed.length === 0) { - return { - ready: true, - prefix: captured.slice(0, obj.end), - calls: [], - suffix: captured.slice(obj.end), - }; - } - return { - ready: true, - prefix: captured.slice(0, start), - calls: parsed, - suffix: captured.slice(obj.end), - }; -} - -function extractJSONObjectFrom(text, start) { - if (!text || start < 0 || start >= text.length || text[start] !== '{') { - return { ok: false, end: 0 }; - } - let depth = 0; - let quote = ''; - let escaped = false; - for (let i = start; i < text.length; i += 1) { - const ch = text[i]; - if (quote) { - if (escaped) { - escaped = false; - continue; - } - if (ch === '\\') { - escaped = true; - continue; - } - if (ch === quote) { - quote = ''; - } - continue; - } - if (ch === '"' || ch === "'") { - quote = ch; - continue; - } - if (ch === '{') { - depth += 1; - continue; - } - if (ch === '}') { - depth -= 1; - if (depth === 0) { - return { ok: true, end: i + 1 }; - } - } - } - return { 
ok: false, end: 0 }; -} - -function parseToolCalls(text, toolNames) { - if (!asString(text)) { - return []; - } - const candidates = buildToolCallCandidates(text); - let parsed = []; - for (const c of candidates) { - parsed = parseToolCallsPayload(c); - if (parsed.length > 0) { - break; - } - } - if (parsed.length === 0) { - return []; - } - const allowed = new Set((toolNames || []).filter(Boolean)); - const out = []; - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - if (allowed.size > 0 && !allowed.has(tc.name)) { - continue; - } - out.push({ name: tc.name, input: tc.input || {} }); - } - if (out.length === 0 && parsed.length > 0) { - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - out.push({ name: tc.name, input: tc.input || {} }); - } - } - return out; -} - -function buildToolCallCandidates(text) { - const trimmed = asString(text); - const candidates = [trimmed]; - const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || []; - for (const block of fenced) { - const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); - if (m && m[1]) { - candidates.push(asString(m[1])); - } - } - const keyIdx = trimmed.toLowerCase().indexOf('tool_calls'); - if (keyIdx >= 0) { - const start = trimmed.slice(0, keyIdx).lastIndexOf('{'); - if (start >= 0) { - const obj = extractJSONObjectFrom(trimmed, start); - if (obj.ok) { - candidates.push(asString(trimmed.slice(start, obj.end))); - } - } - } - const first = trimmed.indexOf('{'); - const last = trimmed.lastIndexOf('}'); - if (first >= 0 && last > first) { - candidates.push(asString(trimmed.slice(first, last + 1))); - } - return [...new Set(candidates.filter(Boolean))]; -} - -function parseToolCallsPayload(payload) { - let decoded; - try { - decoded = JSON.parse(payload); - } catch (_err) { - return []; - } - if (Array.isArray(decoded)) { - return parseToolCallList(decoded); - } - if (!decoded || typeof decoded !== 'object') { - return []; - } - if (decoded.tool_calls) { - 
return parseToolCallList(decoded.tool_calls); - } - const one = parseToolCallItem(decoded); - return one ? [one] : []; -} - -function parseToolCallList(v) { - if (!Array.isArray(v)) { - return []; - } - const out = []; - for (const item of v) { - if (!item || typeof item !== 'object') { - continue; - } - const one = parseToolCallItem(item); - if (one) { - out.push(one); - } - } - return out; -} - -function parseToolCallItem(m) { - let name = asString(m.name); - let inputRaw = m.input; - let hasInput = Object.prototype.hasOwnProperty.call(m, 'input'); - const fn = m.function && typeof m.function === 'object' ? m.function : null; - if (fn) { - if (!name) { - name = asString(fn.name); - } - if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) { - inputRaw = fn.arguments; - hasInput = true; - } - } - if (!hasInput) { - for (const k of ['arguments', 'args', 'parameters', 'params']) { - if (Object.prototype.hasOwnProperty.call(m, k)) { - inputRaw = m[k]; - hasInput = true; - break; - } - } - } - if (!name) { - return null; - } - return { - name, - input: parseToolCallInput(inputRaw), - }; -} - -function parseToolCallInput(v) { - if (v == null) { - return {}; - } - if (typeof v === 'string') { - const raw = asString(v); - if (!raw) { - return {}; - } - try { - const parsed = JSON.parse(raw); - if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { - return parsed; - } - } catch (_err) { - return { _raw: raw }; - } - return {}; - } - if (typeof v === 'object' && !Array.isArray(v)) { - return v; - } - try { - const parsed = JSON.parse(JSON.stringify(v)); - if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { - return parsed; - } - } catch (_err) { - return {}; - } - return {}; -} - -function formatOpenAIStreamToolCalls(calls) { - if (!Array.isArray(calls) || calls.length === 0) { - return []; - } - return calls.map((c, idx) => ({ - index: idx, - id: `call_${newCallID()}`, - type: 'function', - function: { - name: c.name, - 
arguments: JSON.stringify(c.input || {}), - }, - })); -} - -function newCallID() { - if (typeof crypto.randomUUID === 'function') { - return crypto.randomUUID().replace(/-/g, ''); - } - return `${Date.now()}${Math.floor(Math.random() * 1e9)}`; -} diff --git a/api/helpers/stream-tool-sieve.js b/api/helpers/stream-tool-sieve.js new file mode 100644 index 0000000..07d8cad --- /dev/null +++ b/api/helpers/stream-tool-sieve.js @@ -0,0 +1,456 @@ +'use strict'; + +const crypto = require('crypto'); + +function extractToolNames(tools) { + if (!Array.isArray(tools) || tools.length === 0) { + return []; + } + const out = []; + for (const t of tools) { + if (!t || typeof t !== 'object') { + continue; + } + const fn = t.function && typeof t.function === 'object' ? t.function : t; + const name = toStringSafe(fn.name); + if (name) { + out.push(name); + } + } + return out; +} + +function createToolSieveState() { + return { + pending: '', + capture: '', + capturing: false, + }; +} + +function processToolSieveChunk(state, chunk, toolNames) { + if (!state) { + return []; + } + if (chunk) { + state.pending += chunk; + } + const events = []; + // eslint-disable-next-line no-constant-condition + while (true) { + if (state.capturing) { + if (state.pending) { + state.capture += state.pending; + state.pending = ''; + } + const consumed = consumeToolCapture(state.capture, toolNames); + if (!consumed.ready) { + break; + } + state.capture = ''; + state.capturing = false; + if (consumed.prefix) { + events.push({ type: 'text', text: consumed.prefix }); + } + if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { + events.push({ type: 'tool_calls', calls: consumed.calls }); + } + if (consumed.suffix) { + state.pending += consumed.suffix; + } + continue; + } + + if (!state.pending) { + break; + } + + const start = findToolSegmentStart(state.pending); + if (start >= 0) { + const prefix = state.pending.slice(0, start); + if (prefix) { + events.push({ type: 'text', text: prefix }); + } + 
state.capture = state.pending.slice(start); + state.pending = ''; + state.capturing = true; + continue; + } + + const [safe, hold] = splitSafeContentForToolDetection(state.pending); + if (!safe) { + break; + } + state.pending = hold; + events.push({ type: 'text', text: safe }); + } + return events; +} + +function flushToolSieve(state, toolNames) { + if (!state) { + return []; + } + const events = processToolSieveChunk(state, '', toolNames); + if (state.capturing) { + const consumed = consumeToolCapture(state.capture, toolNames); + if (consumed.ready) { + if (consumed.prefix) { + events.push({ type: 'text', text: consumed.prefix }); + } + if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { + events.push({ type: 'tool_calls', calls: consumed.calls }); + } + if (consumed.suffix) { + events.push({ type: 'text', text: consumed.suffix }); + } + } else if (state.capture) { + events.push({ type: 'text', text: state.capture }); + } + state.capture = ''; + state.capturing = false; + } + if (state.pending) { + events.push({ type: 'text', text: state.pending }); + state.pending = ''; + } + return events; +} + +function splitSafeContentForToolDetection(s) { + const text = s || ''; + if (!text) { + return ['', '']; + } + const suspiciousStart = findSuspiciousPrefixStart(text); + if (suspiciousStart < 0) { + return [text, '']; + } + if (suspiciousStart > 0) { + return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)]; + } + const chars = Array.from(text); + const maxHold = 128; + if (chars.length <= maxHold) { + return ['', text]; + } + return [chars.slice(0, chars.length - maxHold).join(''), chars.slice(chars.length - maxHold).join('')]; +} + +function findSuspiciousPrefixStart(s) { + let start = -1; + for (const needle of ['{', '[', '```']) { + const idx = s.lastIndexOf(needle); + if (idx > start) { + start = idx; + } + } + return start; +} + +function findToolSegmentStart(s) { + if (!s) { + return -1; + } + const lower = s.toLowerCase(); + const 
keyIdx = lower.indexOf('tool_calls'); + if (keyIdx < 0) { + return -1; + } + const start = s.slice(0, keyIdx).lastIndexOf('{'); + return start >= 0 ? start : keyIdx; +} + +function consumeToolCapture(captured, toolNames) { + if (!captured) { + return { ready: false, prefix: '', calls: [], suffix: '' }; + } + const lower = captured.toLowerCase(); + const keyIdx = lower.indexOf('tool_calls'); + if (keyIdx < 0) { + if (Array.from(captured).length >= 256) { + return { ready: true, prefix: captured, calls: [], suffix: '' }; + } + return { ready: false, prefix: '', calls: [], suffix: '' }; + } + const start = captured.slice(0, keyIdx).lastIndexOf('{'); + if (start < 0) { + if (Array.from(captured).length >= 512) { + return { ready: true, prefix: captured, calls: [], suffix: '' }; + } + return { ready: false, prefix: '', calls: [], suffix: '' }; + } + const obj = extractJSONObjectFrom(captured, start); + if (!obj.ok) { + if (Array.from(captured).length >= 4096) { + return { ready: true, prefix: captured, calls: [], suffix: '' }; + } + return { ready: false, prefix: '', calls: [], suffix: '' }; + } + const parsed = parseToolCalls(captured.slice(start, obj.end), toolNames); + if (parsed.length === 0) { + return { + ready: true, + prefix: captured.slice(0, obj.end), + calls: [], + suffix: captured.slice(obj.end), + }; + } + return { + ready: true, + prefix: captured.slice(0, start), + calls: parsed, + suffix: captured.slice(obj.end), + }; +} + +function extractJSONObjectFrom(text, start) { + if (!text || start < 0 || start >= text.length || text[start] !== '{') { + return { ok: false, end: 0 }; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + quote = ch; + continue; + } + if (ch === 
'{') { + depth += 1; + continue; + } + if (ch === '}') { + depth -= 1; + if (depth === 0) { + return { ok: true, end: i + 1 }; + } + } + } + return { ok: false, end: 0 }; +} + +function parseToolCalls(text, toolNames) { + if (!toStringSafe(text)) { + return []; + } + const candidates = buildToolCallCandidates(text); + let parsed = []; + for (const c of candidates) { + parsed = parseToolCallsPayload(c); + if (parsed.length > 0) { + break; + } + } + if (parsed.length === 0) { + return []; + } + const allowed = new Set((toolNames || []).filter(Boolean)); + const out = []; + for (const tc of parsed) { + if (!tc || !tc.name) { + continue; + } + if (allowed.size > 0 && !allowed.has(tc.name)) { + continue; + } + out.push({ name: tc.name, input: tc.input || {} }); + } + if (out.length === 0 && parsed.length > 0) { + for (const tc of parsed) { + if (!tc || !tc.name) { + continue; + } + out.push({ name: tc.name, input: tc.input || {} }); + } + } + return out; +} + +function buildToolCallCandidates(text) { + const trimmed = toStringSafe(text); + const candidates = [trimmed]; + const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || []; + for (const block of fenced) { + const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); + if (m && m[1]) { + candidates.push(toStringSafe(m[1])); + } + } + const keyIdx = trimmed.toLowerCase().indexOf('tool_calls'); + if (keyIdx >= 0) { + const start = trimmed.slice(0, keyIdx).lastIndexOf('{'); + if (start >= 0) { + const obj = extractJSONObjectFrom(trimmed, start); + if (obj.ok) { + candidates.push(toStringSafe(trimmed.slice(start, obj.end))); + } + } + } + const first = trimmed.indexOf('{'); + const last = trimmed.lastIndexOf('}'); + if (first >= 0 && last > first) { + candidates.push(toStringSafe(trimmed.slice(first, last + 1))); + } + return [...new Set(candidates.filter(Boolean))]; +} + +function parseToolCallsPayload(payload) { + let decoded; + try { + decoded = JSON.parse(payload); + } catch (_err) { + return []; + } + 
if (Array.isArray(decoded)) { + return parseToolCallList(decoded); + } + if (!decoded || typeof decoded !== 'object') { + return []; + } + if (decoded.tool_calls) { + return parseToolCallList(decoded.tool_calls); + } + const one = parseToolCallItem(decoded); + return one ? [one] : []; +} + +function parseToolCallList(v) { + if (!Array.isArray(v)) { + return []; + } + const out = []; + for (const item of v) { + if (!item || typeof item !== 'object') { + continue; + } + const one = parseToolCallItem(item); + if (one) { + out.push(one); + } + } + return out; +} + +function parseToolCallItem(m) { + let name = toStringSafe(m.name); + let inputRaw = m.input; + let hasInput = Object.prototype.hasOwnProperty.call(m, 'input'); + const fn = m.function && typeof m.function === 'object' ? m.function : null; + if (fn) { + if (!name) { + name = toStringSafe(fn.name); + } + if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) { + inputRaw = fn.arguments; + hasInput = true; + } + } + if (!hasInput) { + for (const k of ['arguments', 'args', 'parameters', 'params']) { + if (Object.prototype.hasOwnProperty.call(m, k)) { + inputRaw = m[k]; + hasInput = true; + break; + } + } + } + if (!name) { + return null; + } + return { + name, + input: parseToolCallInput(inputRaw), + }; +} + +function parseToolCallInput(v) { + if (v == null) { + return {}; + } + if (typeof v === 'string') { + const raw = toStringSafe(v); + if (!raw) { + return {}; + } + try { + const parsed = JSON.parse(raw); + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + return parsed; + } + } catch (_err) { + return { _raw: raw }; + } + return {}; + } + if (typeof v === 'object' && !Array.isArray(v)) { + return v; + } + try { + const parsed = JSON.parse(JSON.stringify(v)); + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + return parsed; + } + } catch (_err) { + return {}; + } + return {}; +} + +function formatOpenAIStreamToolCalls(calls) { + if 
(!Array.isArray(calls) || calls.length === 0) { + return []; + } + return calls.map((c, idx) => ({ + index: idx, + id: `call_${newCallID()}`, + type: 'function', + function: { + name: c.name, + arguments: JSON.stringify(c.input || {}), + }, + })); +} + +function newCallID() { + if (typeof crypto.randomUUID === 'function') { + return crypto.randomUUID().replace(/-/g, ''); + } + return `${Date.now()}${Math.floor(Math.random() * 1e9)}`; +} + +function toStringSafe(v) { + if (typeof v === 'string') { + return v.trim(); + } + if (Array.isArray(v)) { + return toStringSafe(v[0]); + } + if (v == null) { + return ''; + } + return String(v).trim(); +} + +module.exports = { + extractToolNames, + createToolSieveState, + processToolSieveChunk, + flushToolSieve, + formatOpenAIStreamToolCalls, +}; diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index e9d4bde..d78bff3 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -3,15 +3,10 @@ package openai import ( "bufio" "context" - "crypto/rand" - "crypto/subtle" - "encoding/hex" "encoding/json" "fmt" "io" "net/http" - "os" - "strconv" "strings" "sync" "time" @@ -39,17 +34,6 @@ type streamLease struct { ExpiresAt time.Time } -type toolStreamSieveState struct { - pending strings.Builder - capture strings.Builder - capturing bool -} - -type toolStreamEvent struct { - Content string - ToolCalls []util.ParsedToolCall -} - func RegisterRoutes(r chi.Router, h *Handler) { r.Get("/v1/models", h.ListModels) r.Post("/v1/chat/completions", h.ChatCompletions) @@ -140,142 +124,6 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) } -func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Request) { - if !config.IsVercel() { - http.NotFound(w, r) - return - } - h.sweepExpiredStreamLeases() - internalSecret := 
vercelInternalSecret() - internalToken := strings.TrimSpace(r.Header.Get("X-Ds2-Internal-Token")) - if internalSecret == "" || subtle.ConstantTimeCompare([]byte(internalToken), []byte(internalSecret)) != 1 { - writeOpenAIError(w, http.StatusUnauthorized, "unauthorized internal request") - return - } - - a, err := h.Auth.Determine(r) - if err != nil { - status := http.StatusUnauthorized - if err == auth.ErrNoAccount { - status = http.StatusTooManyRequests - } - writeOpenAIError(w, status, err.Error()) - return - } - leased := false - defer func() { - if !leased { - h.Auth.Release(a) - } - }() - r = r.WithContext(auth.WithAuth(r.Context(), a)) - - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeOpenAIError(w, http.StatusBadRequest, "invalid json") - return - } - if !toBool(req["stream"]) { - writeOpenAIError(w, http.StatusBadRequest, "stream must be true") - return - } - if tools, ok := req["tools"].([]any); ok && len(tools) > 0 { - writeOpenAIError(w, http.StatusBadRequest, "tools are not supported by vercel stream prepare") - return - } - - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == "" || len(messagesRaw) == 0 { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") - return - } - thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model) - if !ok { - writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model)) - return - } - - messages := normalizeMessages(messagesRaw) - finalPrompt := util.MessagesPrepare(messages) - - sessionID, err := h.DS.CreateSession(r.Context(), a, 3) - if err != nil { - if a.UseConfigToken { - writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") - } else { - writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. 
If this should be a DS2API key, add it to config.keys first.") - } - return - } - powHeader, err := h.DS.GetPow(r.Context(), a, 3) - if err != nil { - writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") - return - } - if strings.TrimSpace(a.DeepSeekToken) == "" { - writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.") - return - } - - payload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } - leaseID := h.holdStreamLease(a) - if leaseID == "" { - writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease") - return - } - leased = true - writeJSON(w, http.StatusOK, map[string]any{ - "session_id": sessionID, - "lease_id": leaseID, - "model": model, - "final_prompt": finalPrompt, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - "deepseek_token": a.DeepSeekToken, - "pow_header": powHeader, - "payload": payload, - }) -} - -func (h *Handler) handleVercelStreamRelease(w http.ResponseWriter, r *http.Request) { - if !config.IsVercel() { - http.NotFound(w, r) - return - } - h.sweepExpiredStreamLeases() - internalSecret := vercelInternalSecret() - internalToken := strings.TrimSpace(r.Header.Get("X-Ds2-Internal-Token")) - if internalSecret == "" || subtle.ConstantTimeCompare([]byte(internalToken), []byte(internalSecret)) != 1 { - writeOpenAIError(w, http.StatusUnauthorized, "unauthorized internal request") - return - } - - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeOpenAIError(w, http.StatusBadRequest, "invalid json") - return - } - leaseID, _ := req["lease_id"].(string) - leaseID = strings.TrimSpace(leaseID) - if leaseID == "" { - writeOpenAIError(w, http.StatusBadRequest, "lease_id is required") - 
return - } - if !h.releaseStreamLease(leaseID) { - writeOpenAIError(w, http.StatusNotFound, "stream lease not found") - return - } - writeJSON(w, http.StatusOK, map[string]any{"success": true}) -} - func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { @@ -700,367 +548,3 @@ func openAIErrorType(status int) string { return "invalid_request_error" } } - -func isVercelStreamPrepareRequest(r *http.Request) bool { - if r == nil { - return false - } - return strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1" -} - -func isVercelStreamReleaseRequest(r *http.Request) bool { - if r == nil { - return false - } - return strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1" -} - -func vercelInternalSecret() string { - if v := strings.TrimSpace(os.Getenv("DS2API_VERCEL_INTERNAL_SECRET")); v != "" { - return v - } - if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" { - return v - } - return "admin" -} - -func shouldEmitBufferedToolProbeContent(buffered string) bool { - trimmed := strings.TrimSpace(buffered) - if trimmed == "" { - return false - } - normalized := normalizeToolProbePrefix(trimmed) - if normalized == "" { - return false - } - first := normalized[0] - switch first { - case '{', '[', '`': - lower := strings.ToLower(normalized) - if strings.Contains(lower, "tool_calls") { - return false - } - // Keep a short hold window for JSON-ish starts to avoid leaking tool JSON. - if len([]rune(normalized)) < 20 { - return false - } - return true - default: - // Natural language starts can be streamed immediately. 
- return true - } -} - -func normalizeToolProbePrefix(s string) string { - t := strings.TrimSpace(s) - if strings.HasPrefix(t, "```") { - t = strings.TrimPrefix(t, "```") - t = strings.TrimSpace(t) - t = strings.TrimPrefix(strings.ToLower(t), "json") - t = strings.TrimSpace(t) - } - return t -} - -func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent { - if state == nil || chunk == "" { - return nil - } - state.pending.WriteString(chunk) - events := make([]toolStreamEvent, 0, 2) - - for { - if state.capturing { - if state.pending.Len() > 0 { - state.capture.WriteString(state.pending.String()) - state.pending.Reset() - } - prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), toolNames) - if !ready { - break - } - state.capture.Reset() - state.capturing = false - if prefix != "" { - events = append(events, toolStreamEvent{Content: prefix}) - } - if len(calls) > 0 { - events = append(events, toolStreamEvent{ToolCalls: calls}) - } - if suffix != "" { - state.pending.WriteString(suffix) - } - continue - } - - pending := state.pending.String() - if pending == "" { - break - } - start := findToolSegmentStart(pending) - if start >= 0 { - prefix := pending[:start] - if prefix != "" { - events = append(events, toolStreamEvent{Content: prefix}) - } - state.pending.Reset() - state.capture.WriteString(pending[start:]) - state.capturing = true - continue - } - - safe, hold := splitSafeContentForToolDetection(pending) - if safe == "" { - break - } - state.pending.Reset() - state.pending.WriteString(hold) - events = append(events, toolStreamEvent{Content: safe}) - } - - return events -} - -func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent { - if state == nil { - return nil - } - events := processToolSieveChunk(state, "", toolNames) - if state.capturing { - raw := state.capture.String() - state.capture.Reset() - state.capturing = false - if raw != "" { - events = 
append(events, toolStreamEvent{Content: raw}) - } - } - if state.pending.Len() > 0 { - events = append(events, toolStreamEvent{Content: state.pending.String()}) - state.pending.Reset() - } - return events -} - -func splitSafeContentForToolDetection(s string) (safe, hold string) { - if s == "" { - return "", "" - } - suspiciousStart := findSuspiciousPrefixStart(s) - if suspiciousStart < 0 { - return s, "" - } - if suspiciousStart > 0 { - return s[:suspiciousStart], s[suspiciousStart:] - } - runes := []rune(s) - const maxHold = 128 - if len(runes) <= maxHold { - return "", s - } - return string(runes[:len(runes)-maxHold]), string(runes[len(runes)-maxHold:]) -} - -func findSuspiciousPrefixStart(s string) int { - start := -1 - indices := []int{ - strings.LastIndex(s, "{"), - strings.LastIndex(s, "["), - strings.LastIndex(s, "```"), - } - for _, idx := range indices { - if idx > start { - start = idx - } - } - return start -} - -func findToolSegmentStart(s string) int { - if s == "" { - return -1 - } - lower := strings.ToLower(s) - keyIdx := strings.Index(lower, "tool_calls") - if keyIdx < 0 { - return -1 - } - if start := strings.LastIndex(s[:keyIdx], "{"); start >= 0 { - return start - } - return keyIdx -} - -func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { - if captured == "" { - return "", nil, "", false - } - lower := strings.ToLower(captured) - keyIdx := strings.Index(lower, "tool_calls") - if keyIdx < 0 { - if len([]rune(captured)) >= 256 { - return captured, nil, "", true - } - return "", nil, "", false - } - start := strings.LastIndex(captured[:keyIdx], "{") - if start < 0 { - if len([]rune(captured)) >= 512 { - return captured, nil, "", true - } - return "", nil, "", false - } - obj, end, ok := extractJSONObjectFrom(captured, start) - if !ok { - if len([]rune(captured)) >= 4096 { - return captured, nil, "", true - } - return "", nil, "", false - } - parsed := 
util.ParseToolCalls(obj, toolNames) - if len(parsed) == 0 { - return captured[:end], nil, captured[end:], true - } - return captured[:start], parsed, captured[end:], true -} - -func extractJSONObjectFrom(text string, start int) (string, int, bool) { - if start < 0 || start >= len(text) || text[start] != '{' { - return "", 0, false - } - depth := 0 - quote := byte(0) - escaped := false - for i := start; i < len(text); i++ { - ch := text[i] - if quote != 0 { - if escaped { - escaped = false - continue - } - if ch == '\\' { - escaped = true - continue - } - if ch == quote { - quote = 0 - } - continue - } - if ch == '"' || ch == '\'' { - quote = ch - continue - } - if ch == '{' { - depth++ - continue - } - if ch == '}' { - depth-- - if depth == 0 { - end := i + 1 - return text[start:end], end, true - } - } - } - return "", 0, false -} - -func (h *Handler) holdStreamLease(a *auth.RequestAuth) string { - if a == nil { - return "" - } - now := time.Now() - ttl := streamLeaseTTL() - if ttl <= 0 { - ttl = 15 * time.Minute - } - - h.leaseMu.Lock() - expired := h.popExpiredLeasesLocked(now) - if h.streamLeases == nil { - h.streamLeases = make(map[string]streamLease) - } - leaseID := newLeaseID() - h.streamLeases[leaseID] = streamLease{ - Auth: a, - ExpiresAt: now.Add(ttl), - } - h.leaseMu.Unlock() - h.releaseExpiredAuths(expired) - return leaseID -} - -func (h *Handler) releaseStreamLease(leaseID string) bool { - leaseID = strings.TrimSpace(leaseID) - if leaseID == "" { - return false - } - - h.leaseMu.Lock() - expired := h.popExpiredLeasesLocked(time.Now()) - lease, ok := h.streamLeases[leaseID] - if ok { - delete(h.streamLeases, leaseID) - } - h.leaseMu.Unlock() - h.releaseExpiredAuths(expired) - - if !ok { - return false - } - if h.Auth != nil { - h.Auth.Release(lease.Auth) - } - return true -} - -func (h *Handler) popExpiredLeasesLocked(now time.Time) []*auth.RequestAuth { - if len(h.streamLeases) == 0 { - return nil - } - expired := make([]*auth.RequestAuth, 0) - for 
leaseID, lease := range h.streamLeases { - if now.After(lease.ExpiresAt) { - delete(h.streamLeases, leaseID) - expired = append(expired, lease.Auth) - } - } - return expired -} - -func (h *Handler) releaseExpiredAuths(expired []*auth.RequestAuth) { - if h.Auth == nil || len(expired) == 0 { - return - } - for _, a := range expired { - h.Auth.Release(a) - } -} - -func (h *Handler) sweepExpiredStreamLeases() { - h.leaseMu.Lock() - expired := h.popExpiredLeasesLocked(time.Now()) - h.leaseMu.Unlock() - h.releaseExpiredAuths(expired) -} - -func streamLeaseTTL() time.Duration { - raw := strings.TrimSpace(os.Getenv("DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS")) - if raw == "" { - return 15 * time.Minute - } - seconds, err := strconv.Atoi(raw) - if err != nil || seconds <= 0 { - return 15 * time.Minute - } - return time.Duration(seconds) * time.Second -} - -func newLeaseID() string { - buf := make([]byte, 16) - if _, err := rand.Read(buf); err == nil { - return hex.EncodeToString(buf) - } - return fmt.Sprintf("lease-%d", time.Now().UnixNano()) -} diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve.go new file mode 100644 index 0000000..3fd7262 --- /dev/null +++ b/internal/adapter/openai/tool_sieve.go @@ -0,0 +1,236 @@ +package openai + +import ( + "strings" + + "ds2api/internal/util" +) + +type toolStreamSieveState struct { + pending strings.Builder + capture strings.Builder + capturing bool +} + +type toolStreamEvent struct { + Content string + ToolCalls []util.ParsedToolCall +} + +func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent { + if state == nil { + return nil + } + if chunk != "" { + state.pending.WriteString(chunk) + } + events := make([]toolStreamEvent, 0, 2) + + for { + if state.capturing { + if state.pending.Len() > 0 { + state.capture.WriteString(state.pending.String()) + state.pending.Reset() + } + prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), 
toolNames) + if !ready { + break + } + state.capture.Reset() + state.capturing = false + if prefix != "" { + events = append(events, toolStreamEvent{Content: prefix}) + } + if len(calls) > 0 { + events = append(events, toolStreamEvent{ToolCalls: calls}) + } + if suffix != "" { + state.pending.WriteString(suffix) + } + continue + } + + pending := state.pending.String() + if pending == "" { + break + } + start := findToolSegmentStart(pending) + if start >= 0 { + prefix := pending[:start] + if prefix != "" { + events = append(events, toolStreamEvent{Content: prefix}) + } + state.pending.Reset() + state.capture.WriteString(pending[start:]) + state.capturing = true + continue + } + + safe, hold := splitSafeContentForToolDetection(pending) + if safe == "" { + break + } + state.pending.Reset() + state.pending.WriteString(hold) + events = append(events, toolStreamEvent{Content: safe}) + } + + return events +} + +func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent { + if state == nil { + return nil + } + events := processToolSieveChunk(state, "", toolNames) + if state.capturing { + consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state.capture.String(), toolNames) + if ready { + if consumedPrefix != "" { + events = append(events, toolStreamEvent{Content: consumedPrefix}) + } + if len(consumedCalls) > 0 { + events = append(events, toolStreamEvent{ToolCalls: consumedCalls}) + } + if consumedSuffix != "" { + events = append(events, toolStreamEvent{Content: consumedSuffix}) + } + } else { + raw := state.capture.String() + if raw != "" { + events = append(events, toolStreamEvent{Content: raw}) + } + } + state.capture.Reset() + state.capturing = false + } + if state.pending.Len() > 0 { + events = append(events, toolStreamEvent{Content: state.pending.String()}) + state.pending.Reset() + } + return events +} + +func splitSafeContentForToolDetection(s string) (safe, hold string) { + if s == "" { + return "", "" + } + 
suspiciousStart := findSuspiciousPrefixStart(s) + if suspiciousStart < 0 { + return s, "" + } + if suspiciousStart > 0 { + return s[:suspiciousStart], s[suspiciousStart:] + } + runes := []rune(s) + const maxHold = 128 + if len(runes) <= maxHold { + return "", s + } + return string(runes[:len(runes)-maxHold]), string(runes[len(runes)-maxHold:]) +} + +func findSuspiciousPrefixStart(s string) int { + start := -1 + indices := []int{ + strings.LastIndex(s, "{"), + strings.LastIndex(s, "["), + strings.LastIndex(s, "```"), + } + for _, idx := range indices { + if idx > start { + start = idx + } + } + return start +} + +func findToolSegmentStart(s string) int { + if s == "" { + return -1 + } + lower := strings.ToLower(s) + keyIdx := strings.Index(lower, "tool_calls") + if keyIdx < 0 { + return -1 + } + if start := strings.LastIndex(s[:keyIdx], "{"); start >= 0 { + return start + } + return keyIdx +} + +func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { + if captured == "" { + return "", nil, "", false + } + lower := strings.ToLower(captured) + keyIdx := strings.Index(lower, "tool_calls") + if keyIdx < 0 { + if len([]rune(captured)) >= 256 { + return captured, nil, "", true + } + return "", nil, "", false + } + start := strings.LastIndex(captured[:keyIdx], "{") + if start < 0 { + if len([]rune(captured)) >= 512 { + return captured, nil, "", true + } + return "", nil, "", false + } + obj, end, ok := extractJSONObjectFrom(captured, start) + if !ok { + if len([]rune(captured)) >= 4096 { + return captured, nil, "", true + } + return "", nil, "", false + } + parsed := util.ParseToolCalls(obj, toolNames) + if len(parsed) == 0 { + return captured[:end], nil, captured[end:], true + } + return captured[:start], parsed, captured[end:], true +} + +func extractJSONObjectFrom(text string, start int) (string, int, bool) { + if start < 0 || start >= len(text) || text[start] != '{' { + return "", 0, false + 
} + depth := 0 + quote := byte(0) + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' { + depth++ + continue + } + if ch == '}' { + depth-- + if depth == 0 { + end := i + 1 + return text[start:end], end, true + } + } + } + return "", 0, false +} diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go new file mode 100644 index 0000000..3e75f47 --- /dev/null +++ b/internal/adapter/openai/vercel_stream.go @@ -0,0 +1,277 @@ +package openai + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "os" + "strconv" + "strings" + "time" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/util" +) + +func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Request) { + if !config.IsVercel() { + http.NotFound(w, r) + return + } + h.sweepExpiredStreamLeases() + internalSecret := vercelInternalSecret() + internalToken := strings.TrimSpace(r.Header.Get("X-Ds2-Internal-Token")) + if internalSecret == "" || subtle.ConstantTimeCompare([]byte(internalToken), []byte(internalSecret)) != 1 { + writeOpenAIError(w, http.StatusUnauthorized, "unauthorized internal request") + return + } + + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeOpenAIError(w, status, err.Error()) + return + } + leased := false + defer func() { + if !leased { + h.Auth.Release(a) + } + }() + r = r.WithContext(auth.WithAuth(r.Context(), a)) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeOpenAIError(w, http.StatusBadRequest, "invalid json") + return + } + if 
!toBool(req["stream"]) { + writeOpenAIError(w, http.StatusBadRequest, "stream must be true") + return + } + if tools, ok := req["tools"].([]any); ok && len(tools) > 0 { + writeOpenAIError(w, http.StatusBadRequest, "tools are not supported by vercel stream prepare") + return + } + + model, _ := req["model"].(string) + messagesRaw, _ := req["messages"].([]any) + if model == "" || len(messagesRaw) == 0 { + writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") + return + } + thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model) + if !ok { + writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model)) + return + } + + messages := normalizeMessages(messagesRaw) + finalPrompt := util.MessagesPrepare(messages) + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + if a.UseConfigToken { + writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") + } else { + writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.") + } + return + } + powHeader, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") + return + } + if strings.TrimSpace(a.DeepSeekToken) == "" { + writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. 
If this should be a DS2API key, add it to config.keys first.") + return + } + + payload := map[string]any{ + "chat_session_id": sessionID, + "parent_message_id": nil, + "prompt": finalPrompt, + "ref_file_ids": []any{}, + "thinking_enabled": thinkingEnabled, + "search_enabled": searchEnabled, + } + leaseID := h.holdStreamLease(a) + if leaseID == "" { + writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease") + return + } + leased = true + writeJSON(w, http.StatusOK, map[string]any{ + "session_id": sessionID, + "lease_id": leaseID, + "model": model, + "final_prompt": finalPrompt, + "thinking_enabled": thinkingEnabled, + "search_enabled": searchEnabled, + "deepseek_token": a.DeepSeekToken, + "pow_header": powHeader, + "payload": payload, + }) +} + +func (h *Handler) handleVercelStreamRelease(w http.ResponseWriter, r *http.Request) { + if !config.IsVercel() { + http.NotFound(w, r) + return + } + h.sweepExpiredStreamLeases() + internalSecret := vercelInternalSecret() + internalToken := strings.TrimSpace(r.Header.Get("X-Ds2-Internal-Token")) + if internalSecret == "" || subtle.ConstantTimeCompare([]byte(internalToken), []byte(internalSecret)) != 1 { + writeOpenAIError(w, http.StatusUnauthorized, "unauthorized internal request") + return + } + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeOpenAIError(w, http.StatusBadRequest, "invalid json") + return + } + leaseID, _ := req["lease_id"].(string) + leaseID = strings.TrimSpace(leaseID) + if leaseID == "" { + writeOpenAIError(w, http.StatusBadRequest, "lease_id is required") + return + } + if !h.releaseStreamLease(leaseID) { + writeOpenAIError(w, http.StatusNotFound, "stream lease not found") + return + } + writeJSON(w, http.StatusOK, map[string]any{"success": true}) +} + +func isVercelStreamPrepareRequest(r *http.Request) bool { + if r == nil { + return false + } + return strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1" +} + +func 
isVercelStreamReleaseRequest(r *http.Request) bool { + if r == nil { + return false + } + return strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1" +} + +func vercelInternalSecret() string { + if v := strings.TrimSpace(os.Getenv("DS2API_VERCEL_INTERNAL_SECRET")); v != "" { + return v + } + if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" { + return v + } + return "admin" +} + +func (h *Handler) holdStreamLease(a *auth.RequestAuth) string { + if a == nil { + return "" + } + now := time.Now() + ttl := streamLeaseTTL() + if ttl <= 0 { + ttl = 15 * time.Minute + } + + h.leaseMu.Lock() + expired := h.popExpiredLeasesLocked(now) + if h.streamLeases == nil { + h.streamLeases = make(map[string]streamLease) + } + leaseID := newLeaseID() + h.streamLeases[leaseID] = streamLease{ + Auth: a, + ExpiresAt: now.Add(ttl), + } + h.leaseMu.Unlock() + h.releaseExpiredAuths(expired) + return leaseID +} + +func (h *Handler) releaseStreamLease(leaseID string) bool { + leaseID = strings.TrimSpace(leaseID) + if leaseID == "" { + return false + } + + h.leaseMu.Lock() + expired := h.popExpiredLeasesLocked(time.Now()) + lease, ok := h.streamLeases[leaseID] + if ok { + delete(h.streamLeases, leaseID) + } + h.leaseMu.Unlock() + h.releaseExpiredAuths(expired) + + if !ok { + return false + } + if h.Auth != nil { + h.Auth.Release(lease.Auth) + } + return true +} + +func (h *Handler) popExpiredLeasesLocked(now time.Time) []*auth.RequestAuth { + if len(h.streamLeases) == 0 { + return nil + } + expired := make([]*auth.RequestAuth, 0) + for leaseID, lease := range h.streamLeases { + if now.After(lease.ExpiresAt) { + delete(h.streamLeases, leaseID) + expired = append(expired, lease.Auth) + } + } + return expired +} + +func (h *Handler) releaseExpiredAuths(expired []*auth.RequestAuth) { + if h.Auth == nil || len(expired) == 0 { + return + } + for _, a := range expired { + h.Auth.Release(a) + } +} + +func (h *Handler) sweepExpiredStreamLeases() { + h.leaseMu.Lock() + expired 
:= h.popExpiredLeasesLocked(time.Now()) + h.leaseMu.Unlock() + h.releaseExpiredAuths(expired) +} + +func streamLeaseTTL() time.Duration { + raw := strings.TrimSpace(os.Getenv("DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS")) + if raw == "" { + return 15 * time.Minute + } + seconds, err := strconv.Atoi(raw) + if err != nil || seconds <= 0 { + return 15 * time.Minute + } + return time.Duration(seconds) * time.Second +} + +func newLeaseID() string { + buf := make([]byte, 16) + if _, err := rand.Read(buf); err == nil { + return hex.EncodeToString(buf) + } + return fmt.Sprintf("lease-%d", time.Now().UnixNano()) +} diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 1ca4378..9d6151e 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -1,29 +1,11 @@ package admin import ( - "bufio" - "bytes" - "context" - "crypto/md5" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "os" - "sort" - "strconv" - "strings" - "time" - "github.com/go-chi/chi/v5" "ds2api/internal/account" - authn "ds2api/internal/auth" "ds2api/internal/config" "ds2api/internal/deepseek" - "ds2api/internal/sse" ) type Handler struct { @@ -33,7 +15,6 @@ type Handler struct { } func RegisterRoutes(r chi.Router, h *Handler) { - r.Post("/login", h.login) r.Get("/verify", h.verify) r.Group(func(pr chi.Router) { @@ -56,848 +37,3 @@ func RegisterRoutes(r chi.Router, h *Handler) { pr.Get("/export", h.exportConfig) }) } - -func (h *Handler) requireAdmin(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := authn.VerifyAdminRequest(r); err != nil { - writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()}) - return - } - next.ServeHTTP(w, r) - }) -} - -func (h *Handler) login(w http.ResponseWriter, r *http.Request) { - var req map[string]any - _ = json.NewDecoder(r.Body).Decode(&req) - adminKey, _ := req["admin_key"].(string) - expireHours := intFrom(req["expire_hours"]) - 
if expireHours <= 0 { - expireHours = 24 - } - if adminKey != authn.AdminKey() { - writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "Invalid admin key"}) - return - } - token, err := authn.CreateJWT(expireHours) - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return - } - writeJSON(w, http.StatusOK, map[string]any{"success": true, "token": token, "expires_in": expireHours * 3600}) -} - -func (h *Handler) verify(w http.ResponseWriter, r *http.Request) { - header := strings.TrimSpace(r.Header.Get("Authorization")) - if !strings.HasPrefix(strings.ToLower(header), "bearer ") { - writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "No credentials provided"}) - return - } - token := strings.TrimSpace(header[7:]) - payload, err := authn.VerifyJWT(token) - if err != nil { - writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()}) - return - } - exp, _ := payload["exp"].(float64) - remaining := int64(exp) - time.Now().Unix() - if remaining < 0 { - remaining = 0 - } - writeJSON(w, http.StatusOK, map[string]any{"valid": true, "expires_at": int64(exp), "remaining_seconds": remaining}) -} - -func (h *Handler) getVercelConfig(w http.ResponseWriter, _ *http.Request) { - writeJSON(w, http.StatusOK, map[string]any{ - "has_token": strings.TrimSpace(os.Getenv("VERCEL_TOKEN")) != "", - "project_id": strings.TrimSpace(os.Getenv("VERCEL_PROJECT_ID")), - "team_id": nilIfEmpty(strings.TrimSpace(os.Getenv("VERCEL_TEAM_ID"))), - }) -} - -func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) { - snap := h.Store.Snapshot() - safe := map[string]any{ - "keys": snap.Keys, - "accounts": []map[string]any{}, - "claude_mapping": func() map[string]string { - if len(snap.ClaudeMapping) > 0 { - return snap.ClaudeMapping - } - return snap.ClaudeModelMap - }(), - } - accounts := make([]map[string]any, 0, len(snap.Accounts)) - for _, acc := range snap.Accounts { - token := 
strings.TrimSpace(acc.Token) - preview := "" - if token != "" { - if len(token) > 20 { - preview = token[:20] + "..." - } else { - preview = token - } - } - accounts = append(accounts, map[string]any{ - "email": acc.Email, - "mobile": acc.Mobile, - "has_password": strings.TrimSpace(acc.Password) != "", - "has_token": token != "", - "token_preview": preview, - }) - } - safe["accounts"] = accounts - writeJSON(w, http.StatusOK, safe) -} - -func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) { - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) - return - } - old := h.Store.Snapshot() - err := h.Store.Update(func(c *config.Config) error { - if keys, ok := toStringSlice(req["keys"]); ok { - c.Keys = keys - } - if accountsRaw, ok := req["accounts"].([]any); ok { - existing := map[string]config.Account{} - for _, a := range old.Accounts { - existing[a.Identifier()] = a - } - accounts := make([]config.Account, 0, len(accountsRaw)) - for _, item := range accountsRaw { - m, ok := item.(map[string]any) - if !ok { - continue - } - acc := toAccount(m) - id := acc.Identifier() - if prev, ok := existing[id]; ok { - if strings.TrimSpace(acc.Password) == "" { - acc.Password = prev.Password - } - if strings.TrimSpace(acc.Token) == "" { - acc.Token = prev.Token - } - } - accounts = append(accounts, acc) - } - c.Accounts = accounts - } - if m, ok := req["claude_mapping"].(map[string]any); ok { - newMap := map[string]string{} - for k, v := range m { - newMap[k] = fmt.Sprintf("%v", v) - } - c.ClaudeMapping = newMap - } - return nil - }) - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return - } - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "配置已更新"}) -} - -func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) { - var req map[string]any - _ = 
json.NewDecoder(r.Body).Decode(&req) - key, _ := req["key"].(string) - key = strings.TrimSpace(key) - if key == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"}) - return - } - err := h.Store.Update(func(c *config.Config) error { - for _, k := range c.Keys { - if k == key { - return fmt.Errorf("Key 已存在") - } - } - c.Keys = append(c.Keys, key) - return nil - }) - if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) - return - } - writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)}) -} - -func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) { - key := chi.URLParam(r, "key") - err := h.Store.Update(func(c *config.Config) error { - idx := -1 - for i, k := range c.Keys { - if k == key { - idx = i - break - } - } - if idx < 0 { - return fmt.Errorf("Key 不存在") - } - c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...) - return nil - }) - if err != nil { - writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) - return - } - writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)}) -} - -func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { - page := intFromQuery(r, "page", 1) - pageSize := intFromQuery(r, "page_size", 10) - if page < 1 { - page = 1 - } - if pageSize < 1 { - pageSize = 1 - } - if pageSize > 100 { - pageSize = 100 - } - accounts := h.Store.Snapshot().Accounts - total := len(accounts) - reverseAccounts(accounts) - totalPages := 1 - if total > 0 { - totalPages = (total + pageSize - 1) / pageSize - } - start := (page - 1) * pageSize - if start > total { - start = total - } - end := start + pageSize - if end > total { - end = total - } - items := make([]map[string]any, 0, end-start) - for _, acc := range accounts[start:end] { - token := strings.TrimSpace(acc.Token) - preview := "" - if token != "" { - if len(token) > 20 { - preview = token[:20] + 
"..." - } else { - preview = token - } - } - items = append(items, map[string]any{"email": acc.Email, "mobile": acc.Mobile, "has_password": acc.Password != "", "has_token": token != "", "token_preview": preview}) - } - writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages}) -} - -func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) { - var req map[string]any - _ = json.NewDecoder(r.Body).Decode(&req) - acc := toAccount(req) - if acc.Identifier() == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"}) - return - } - err := h.Store.Update(func(c *config.Config) error { - for _, a := range c.Accounts { - if acc.Email != "" && a.Email == acc.Email { - return fmt.Errorf("邮箱已存在") - } - if acc.Mobile != "" && a.Mobile == acc.Mobile { - return fmt.Errorf("手机号已存在") - } - } - c.Accounts = append(c.Accounts, acc) - return nil - }) - if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) - return - } - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) -} - -func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) { - identifier := chi.URLParam(r, "identifier") - err := h.Store.Update(func(c *config.Config) error { - idx := -1 - for i, a := range c.Accounts { - if a.Email == identifier || a.Mobile == identifier { - idx = i - break - } - } - if idx < 0 { - return fmt.Errorf("账号不存在") - } - c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...) 
- return nil - }) - if err != nil { - writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) - return - } - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) -} - -func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) { - writeJSON(w, http.StatusOK, h.Pool.Status()) -} - -func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) { - var req map[string]any - _ = json.NewDecoder(r.Body).Decode(&req) - identifier, _ := req["identifier"].(string) - if strings.TrimSpace(identifier) == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(email 或 mobile)"}) - return - } - acc, ok := h.Store.FindAccount(identifier) - if !ok { - writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"}) - return - } - model, _ := req["model"].(string) - if model == "" { - model = "deepseek-chat" - } - message, _ := req["message"].(string) - result := h.testAccount(r.Context(), acc, model, message) - writeJSON(w, http.StatusOK, result) -} - -func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) { - var req map[string]any - _ = json.NewDecoder(r.Body).Decode(&req) - model, _ := req["model"].(string) - if model == "" { - model = "deepseek-chat" - } - accounts := h.Store.Snapshot().Accounts - if len(accounts) == 0 { - writeJSON(w, http.StatusOK, map[string]any{"total": 0, "success": 0, "failed": 0, "results": []any{}}) - return - } - results := make([]map[string]any, 0, len(accounts)) - success := 0 - for _, acc := range accounts { - res := h.testAccount(r.Context(), acc, model, "") - if ok, _ := res["success"].(bool); ok { - success++ - } - results = append(results, res) - time.Sleep(time.Second) - } - writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results}) -} - -func (h *Handler) testAccount(ctx context.Context, acc 
config.Account, model, message string) map[string]any { - start := time.Now() - result := map[string]any{"account": acc.Identifier(), "success": false, "response_time": 0, "message": "", "model": model} - token := strings.TrimSpace(acc.Token) - if token == "" { - newToken, err := h.DS.Login(ctx, acc) - if err != nil { - result["message"] = "登录失败: " + err.Error() - return result - } - token = newToken - _ = h.Store.UpdateAccountToken(acc.Identifier(), token) - } - authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token} - sessionID, err := h.DS.CreateSession(ctx, authCtx, 1) - if err != nil { - newToken, loginErr := h.DS.Login(ctx, acc) - if loginErr != nil { - result["message"] = "创建会话失败: " + err.Error() - return result - } - token = newToken - authCtx.DeepSeekToken = token - _ = h.Store.UpdateAccountToken(acc.Identifier(), token) - sessionID, err = h.DS.CreateSession(ctx, authCtx, 1) - if err != nil { - result["message"] = "创建会话失败: " + err.Error() - return result - } - } - if strings.TrimSpace(message) == "" { - result["success"] = true - result["message"] = "API 测试成功(仅会话创建)" - result["response_time"] = int(time.Since(start).Milliseconds()) - return result - } - thinking, search, ok := config.GetModelConfig(model) - if !ok { - thinking, search = false, false - } - pow, err := h.DS.GetPow(ctx, authCtx, 1) - if err != nil { - result["message"] = "获取 PoW 失败: " + err.Error() - return result - } - payload := map[string]any{"chat_session_id": sessionID, "prompt": "<|User|>" + message, "ref_file_ids": []any{}, "thinking_enabled": thinking, "search_enabled": search} - resp, err := h.DS.CallCompletion(ctx, authCtx, payload, pow, 1) - if err != nil { - result["message"] = "请求失败: " + err.Error() - return result - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - result["message"] = fmt.Sprintf("请求失败: HTTP %d", resp.StatusCode) - return result - } - text := strings.Builder{} - think := strings.Builder{} - currentType := "text" - if 
thinking { - currentType = "thinking" - } - scanner := bufio.NewScanner(resp.Body) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - chunk, done, parsed := sse.ParseDeepSeekSSELine(scanner.Bytes()) - if !parsed { - continue - } - if done { - break - } - parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinking, currentType) - currentType = newType - if finished { - break - } - for _, p := range parts { - if p.Type == "thinking" { - think.WriteString(p.Text) - } else { - text.WriteString(p.Text) - } - } - } - result["success"] = true - result["response_time"] = int(time.Since(start).Milliseconds()) - if text.Len() > 0 { - result["message"] = text.String() - } else { - result["message"] = "(无回复内容)" - } - if think.Len() > 0 { - result["thinking"] = think.String() - } - return result -} - -func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) { - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "无效的 JSON 格式"}) - return - } - importedKeys, importedAccounts := 0, 0 - err := h.Store.Update(func(c *config.Config) error { - if keys, ok := req["keys"].([]any); ok { - existing := map[string]bool{} - for _, k := range c.Keys { - existing[k] = true - } - for _, k := range keys { - key := strings.TrimSpace(fmt.Sprintf("%v", k)) - if key == "" || existing[key] { - continue - } - c.Keys = append(c.Keys, key) - existing[key] = true - importedKeys++ - } - } - if accounts, ok := req["accounts"].([]any); ok { - existing := map[string]bool{} - for _, a := range c.Accounts { - existing[a.Identifier()] = true - } - for _, item := range accounts { - m, ok := item.(map[string]any) - if !ok { - continue - } - acc := toAccount(m) - id := acc.Identifier() - if id == "" || existing[id] { - continue - } - c.Accounts = append(c.Accounts, acc) - existing[id] = true - importedAccounts++ - } - } - return nil - }) - if err != 
nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return - } - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts}) -} - -func (h *Handler) testAPI(w http.ResponseWriter, r *http.Request) { - var req map[string]any - _ = json.NewDecoder(r.Body).Decode(&req) - model, _ := req["model"].(string) - message, _ := req["message"].(string) - apiKey, _ := req["api_key"].(string) - if model == "" { - model = "deepseek-chat" - } - if message == "" { - message = "你好" - } - if apiKey == "" { - keys := h.Store.Snapshot().Keys - if len(keys) == 0 { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "没有可用的 API Key"}) - return - } - apiKey = keys[0] - } - host := r.Host - scheme := "http" - if strings.Contains(strings.ToLower(host), "vercel") || strings.Contains(strings.ToLower(r.Header.Get("X-Forwarded-Proto")), "https") { - scheme = "https" - } - payload := map[string]any{"model": model, "messages": []map[string]any{{"role": "user", "content": message}}, "stream": false} - b, _ := json.Marshal(payload) - request, _ := http.NewRequestWithContext(r.Context(), http.MethodPost, fmt.Sprintf("%s://%s/v1/chat/completions", scheme, host), bytes.NewReader(b)) - request.Header.Set("Authorization", "Bearer "+apiKey) - request.Header.Set("Content-Type", "application/json") - resp, err := (&http.Client{Timeout: 60 * time.Second}).Do(request) - if err != nil { - writeJSON(w, http.StatusOK, map[string]any{"success": false, "error": err.Error()}) - return - } - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode == http.StatusOK { - var parsed any - _ = json.Unmarshal(body, &parsed) - writeJSON(w, http.StatusOK, map[string]any{"success": true, "status_code": resp.StatusCode, "response": parsed}) - return - } - writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": 
string(body)}) -} - -func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) { - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) - return - } - vercelToken, _ := req["vercel_token"].(string) - projectID, _ := req["project_id"].(string) - teamID, _ := req["team_id"].(string) - autoValidate := true - if v, ok := req["auto_validate"].(bool); ok { - autoValidate = v - } - saveCreds := true - if v, ok := req["save_credentials"].(bool); ok { - saveCreds = v - } - usePreconfig := vercelToken == "__USE_PRECONFIG__" || strings.TrimSpace(vercelToken) == "" - if usePreconfig { - vercelToken = strings.TrimSpace(os.Getenv("VERCEL_TOKEN")) - } - if strings.TrimSpace(projectID) == "" { - projectID = strings.TrimSpace(os.Getenv("VERCEL_PROJECT_ID")) - } - if strings.TrimSpace(teamID) == "" { - teamID = strings.TrimSpace(os.Getenv("VERCEL_TEAM_ID")) - } - if vercelToken == "" || projectID == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 Vercel Token 和 Project ID"}) - return - } - validated, failed := 0, []string{} - if autoValidate { - for _, acc := range h.Store.Snapshot().Accounts { - if strings.TrimSpace(acc.Token) != "" { - continue - } - token, err := h.DS.Login(r.Context(), acc) - if err != nil { - failed = append(failed, acc.Identifier()) - } else { - validated++ - _ = h.Store.UpdateAccountToken(acc.Identifier(), token) - } - time.Sleep(500 * time.Millisecond) - } - } - - cfgJSON, _, err := h.Store.ExportJSONAndBase64() - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return - } - cfgB64 := base64.StdEncoding.EncodeToString([]byte(cfgJSON)) - client := &http.Client{Timeout: 30 * time.Second} - params := url.Values{} - if teamID != "" { - params.Set("teamId", teamID) - } - headers := map[string]string{"Authorization": "Bearer " + vercelToken} - envResp, status, err 
:= vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID+"/env", params, headers, nil) - if err != nil || status != http.StatusOK { - writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "获取环境变量失败"}) - return - } - envs, _ := envResp["envs"].([]any) - existingEnvID := findEnvID(envs, "DS2API_CONFIG_JSON") - if existingEnvID != "" { - _, status, err = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+existingEnvID, params, headers, map[string]any{"value": cfgB64}) - } else { - _, status, err = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": "DS2API_CONFIG_JSON", "value": cfgB64, "type": "encrypted", "target": []string{"production", "preview"}}) - } - if err != nil || (status != http.StatusOK && status != http.StatusCreated) { - writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "更新环境变量失败"}) - return - } - savedCreds := []string{} - if saveCreds && !usePreconfig { - creds := [][2]string{{"VERCEL_TOKEN", vercelToken}, {"VERCEL_PROJECT_ID", projectID}} - if teamID != "" { - creds = append(creds, [2]string{"VERCEL_TEAM_ID", teamID}) - } - for _, kv := range creds { - id := findEnvID(envs, kv[0]) - if id != "" { - _, status, _ = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+id, params, headers, map[string]any{"value": kv[1]}) - } else { - _, status, _ = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": kv[0], "value": kv[1], "type": "encrypted", "target": []string{"production", "preview"}}) - } - if status == http.StatusOK || status == http.StatusCreated { - savedCreds = append(savedCreds, kv[0]) - } - } - } - projectResp, status, _ := 
vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID, params, headers, nil) - manual := true - deployURL := "" - if status == http.StatusOK { - if link, ok := projectResp["link"].(map[string]any); ok { - if linkType, _ := link["type"].(string); linkType == "github" { - repoID := intFrom(link["repoId"]) - ref, _ := link["productionBranch"].(string) - if ref == "" { - ref = "main" - } - depResp, depStatus, _ := vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v13/deployments", params, headers, map[string]any{"name": projectID, "project": projectID, "target": "production", "gitSource": map[string]any{"type": "github", "repoId": repoID, "ref": ref}}) - if depStatus == http.StatusOK || depStatus == http.StatusCreated { - deployURL, _ = depResp["url"].(string) - manual = false - } - } - } - } - _ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix()) - result := map[string]any{"success": true, "validated_accounts": validated} - if manual { - result["message"] = "配置已同步到 Vercel,请手动触发重新部署" - result["manual_deploy_required"] = true - } else { - result["message"] = "配置已同步,正在重新部署..." 
- result["deployment_url"] = deployURL - } - if len(failed) > 0 { - result["failed_accounts"] = failed - } - if len(savedCreds) > 0 { - result["saved_credentials"] = savedCreds - } - writeJSON(w, http.StatusOK, result) -} - -func (h *Handler) vercelStatus(w http.ResponseWriter, _ *http.Request) { - snap := h.Store.Snapshot() - current := h.computeSyncHash() - synced := snap.VercelSyncHash != "" && snap.VercelSyncHash == current - writeJSON(w, http.StatusOK, map[string]any{"synced": synced, "last_sync_time": nilIfZero(snap.VercelSyncTime), "has_synced_before": snap.VercelSyncHash != ""}) -} - -func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) { - jsonStr, b64, err := h.Store.ExportJSONAndBase64() - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return - } - writeJSON(w, http.StatusOK, map[string]any{"json": jsonStr, "base64": b64}) -} - -func (h *Handler) computeSyncHash() string { - snap := h.Store.Snapshot() - syncable := map[string]any{"keys": snap.Keys, "accounts": []map[string]any{}} - accounts := make([]map[string]any, 0, len(snap.Accounts)) - for _, a := range snap.Accounts { - m := map[string]any{} - if a.Email != "" { - m["email"] = a.Email - } - if a.Mobile != "" { - m["mobile"] = a.Mobile - } - if a.Password != "" { - m["password"] = a.Password - } - accounts = append(accounts, m) - } - sort.Slice(accounts, func(i, j int) bool { - ai := fmt.Sprintf("%v%v", accounts[i]["email"], accounts[i]["mobile"]) - aj := fmt.Sprintf("%v%v", accounts[j]["email"], accounts[j]["mobile"]) - return ai < aj - }) - syncable["accounts"] = accounts - b, _ := json.Marshal(syncable) - sum := md5.Sum(b) - return fmt.Sprintf("%x", sum) -} - -func vercelRequest(ctx context.Context, client *http.Client, method, endpoint string, params url.Values, headers map[string]string, body any) (map[string]any, int, error) { - if len(params) > 0 { - endpoint += "?" 
+ params.Encode() - } - var reader io.Reader - if body != nil { - b, _ := json.Marshal(body) - reader = bytes.NewReader(b) - } - req, err := http.NewRequestWithContext(ctx, method, endpoint, reader) - if err != nil { - return nil, 0, err - } - for k, v := range headers { - req.Header.Set(k, v) - } - req.Header.Set("Content-Type", "application/json") - resp, err := client.Do(req) - if err != nil { - return nil, 0, err - } - defer resp.Body.Close() - b, _ := io.ReadAll(resp.Body) - parsed := map[string]any{} - _ = json.Unmarshal(b, &parsed) - if len(parsed) == 0 { - parsed["raw"] = string(b) - } - return parsed, resp.StatusCode, nil -} - -func findEnvID(envs []any, key string) string { - for _, item := range envs { - m, ok := item.(map[string]any) - if !ok { - continue - } - if k, _ := m["key"].(string); k == key { - id, _ := m["id"].(string) - return id - } - } - return "" -} - -func reverseAccounts(a []config.Account) { - for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 { - a[i], a[j] = a[j], a[i] - } -} - -func intFromQuery(r *http.Request, key string, d int) int { - v := r.URL.Query().Get(key) - if v == "" { - return d - } - n, err := strconv.Atoi(v) - if err != nil { - return d - } - return n -} - -func intFrom(v any) int { - switch n := v.(type) { - case float64: - return int(n) - case int: - return n - case int64: - return int(n) - default: - return 0 - } -} - -func nilIfEmpty(s string) any { - if s == "" { - return nil - } - return s -} - -func nilIfZero(v int64) any { - if v == 0 { - return nil - } - return v -} - -func toStringSlice(v any) ([]string, bool) { - arr, ok := v.([]any) - if !ok { - return nil, false - } - out := make([]string, 0, len(arr)) - for _, item := range arr { - out = append(out, strings.TrimSpace(fmt.Sprintf("%v", item))) - } - return out, true -} - -func toAccount(m map[string]any) config.Account { - return config.Account{ - Email: fieldString(m, "email"), - Mobile: fieldString(m, "mobile"), - Password: fieldString(m, "password"), - 
Token: fieldString(m, "token"), - } -} - -func fieldString(m map[string]any, key string) string { - v, ok := m[key] - if !ok || v == nil { - return "" - } - return strings.TrimSpace(fmt.Sprintf("%v", v)) -} - -func statusOr(v int, d int) int { - if v == 0 { - return d - } - return v -} - -func writeJSON(w http.ResponseWriter, status int, payload any) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(payload) -} diff --git a/internal/admin/handler_accounts.go b/internal/admin/handler_accounts.go new file mode 100644 index 0000000..3dc43d1 --- /dev/null +++ b/internal/admin/handler_accounts.go @@ -0,0 +1,310 @@ +package admin + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi/v5" + + authn "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/sse" +) + +func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { + page := intFromQuery(r, "page", 1) + pageSize := intFromQuery(r, "page_size", 10) + if page < 1 { + page = 1 + } + if pageSize < 1 { + pageSize = 1 + } + if pageSize > 100 { + pageSize = 100 + } + accounts := h.Store.Snapshot().Accounts + total := len(accounts) + reverseAccounts(accounts) + totalPages := 1 + if total > 0 { + totalPages = (total + pageSize - 1) / pageSize + } + start := (page - 1) * pageSize + if start > total { + start = total + } + end := start + pageSize + if end > total { + end = total + } + items := make([]map[string]any, 0, end-start) + for _, acc := range accounts[start:end] { + token := strings.TrimSpace(acc.Token) + preview := "" + if token != "" { + if len(token) > 20 { + preview = token[:20] + "..." 
+ } else { + preview = token + } + } + items = append(items, map[string]any{"email": acc.Email, "mobile": acc.Mobile, "has_password": acc.Password != "", "has_token": token != "", "token_preview": preview}) + } + writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages}) +} + +func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + acc := toAccount(req) + if acc.Identifier() == "" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"}) + return + } + err := h.Store.Update(func(c *config.Config) error { + for _, a := range c.Accounts { + if acc.Email != "" && a.Email == acc.Email { + return fmt.Errorf("邮箱已存在") + } + if acc.Mobile != "" && a.Mobile == acc.Mobile { + return fmt.Errorf("手机号已存在") + } + } + c.Accounts = append(c.Accounts, acc) + return nil + }) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) +} + +func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) { + identifier := chi.URLParam(r, "identifier") + err := h.Store.Update(func(c *config.Config) error { + idx := -1 + for i, a := range c.Accounts { + if a.Email == identifier || a.Mobile == identifier { + idx = i + break + } + } + if idx < 0 { + return fmt.Errorf("账号不存在") + } + c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...) 
+ return nil + }) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) +} + +func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, h.Pool.Status()) +} + +func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + identifier, _ := req["identifier"].(string) + if strings.TrimSpace(identifier) == "" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(email 或 mobile)"}) + return + } + acc, ok := h.Store.FindAccount(identifier) + if !ok { + writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"}) + return + } + model, _ := req["model"].(string) + if model == "" { + model = "deepseek-chat" + } + message, _ := req["message"].(string) + result := h.testAccount(r.Context(), acc, model, message) + writeJSON(w, http.StatusOK, result) +} + +func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + model, _ := req["model"].(string) + if model == "" { + model = "deepseek-chat" + } + accounts := h.Store.Snapshot().Accounts + if len(accounts) == 0 { + writeJSON(w, http.StatusOK, map[string]any{"total": 0, "success": 0, "failed": 0, "results": []any{}}) + return + } + results := make([]map[string]any, 0, len(accounts)) + success := 0 + for _, acc := range accounts { + res := h.testAccount(r.Context(), acc, model, "") + if ok, _ := res["success"].(bool); ok { + success++ + } + results = append(results, res) + time.Sleep(time.Second) + } + writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results}) +} + +func (h *Handler) testAccount(ctx context.Context, acc 
config.Account, model, message string) map[string]any { + start := time.Now() + result := map[string]any{"account": acc.Identifier(), "success": false, "response_time": 0, "message": "", "model": model} + token := strings.TrimSpace(acc.Token) + if token == "" { + newToken, err := h.DS.Login(ctx, acc) + if err != nil { + result["message"] = "登录失败: " + err.Error() + return result + } + token = newToken + _ = h.Store.UpdateAccountToken(acc.Identifier(), token) + } + authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token} + sessionID, err := h.DS.CreateSession(ctx, authCtx, 1) + if err != nil { + newToken, loginErr := h.DS.Login(ctx, acc) + if loginErr != nil { + result["message"] = "创建会话失败: " + err.Error() + return result + } + token = newToken + authCtx.DeepSeekToken = token + _ = h.Store.UpdateAccountToken(acc.Identifier(), token) + sessionID, err = h.DS.CreateSession(ctx, authCtx, 1) + if err != nil { + result["message"] = "创建会话失败: " + err.Error() + return result + } + } + if strings.TrimSpace(message) == "" { + result["success"] = true + result["message"] = "API 测试成功(仅会话创建)" + result["response_time"] = int(time.Since(start).Milliseconds()) + return result + } + thinking, search, ok := config.GetModelConfig(model) + if !ok { + thinking, search = false, false + } + pow, err := h.DS.GetPow(ctx, authCtx, 1) + if err != nil { + result["message"] = "获取 PoW 失败: " + err.Error() + return result + } + payload := map[string]any{"chat_session_id": sessionID, "prompt": "<|User|>" + message, "ref_file_ids": []any{}, "thinking_enabled": thinking, "search_enabled": search} + resp, err := h.DS.CallCompletion(ctx, authCtx, payload, pow, 1) + if err != nil { + result["message"] = "请求失败: " + err.Error() + return result + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + result["message"] = fmt.Sprintf("请求失败: HTTP %d", resp.StatusCode) + return result + } + text := strings.Builder{} + think := strings.Builder{} + currentType := "text" + if 
thinking { + currentType = "thinking" + } + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 2*1024*1024) + for scanner.Scan() { + chunk, done, parsed := sse.ParseDeepSeekSSELine(scanner.Bytes()) + if !parsed { + continue + } + if done { + break + } + parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinking, currentType) + currentType = newType + if finished { + break + } + for _, p := range parts { + if p.Type == "thinking" { + think.WriteString(p.Text) + } else { + text.WriteString(p.Text) + } + } + } + result["success"] = true + result["response_time"] = int(time.Since(start).Milliseconds()) + if text.Len() > 0 { + result["message"] = text.String() + } else { + result["message"] = "(无回复内容)" + } + if think.Len() > 0 { + result["thinking"] = think.String() + } + return result +} + +func (h *Handler) testAPI(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + model, _ := req["model"].(string) + message, _ := req["message"].(string) + apiKey, _ := req["api_key"].(string) + if model == "" { + model = "deepseek-chat" + } + if message == "" { + message = "你好" + } + if apiKey == "" { + keys := h.Store.Snapshot().Keys + if len(keys) == 0 { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "没有可用的 API Key"}) + return + } + apiKey = keys[0] + } + host := r.Host + scheme := "http" + if strings.Contains(strings.ToLower(host), "vercel") || strings.Contains(strings.ToLower(r.Header.Get("X-Forwarded-Proto")), "https") { + scheme = "https" + } + payload := map[string]any{"model": model, "messages": []map[string]any{{"role": "user", "content": message}}, "stream": false} + b, _ := json.Marshal(payload) + request, _ := http.NewRequestWithContext(r.Context(), http.MethodPost, fmt.Sprintf("%s://%s/v1/chat/completions", scheme, host), bytes.NewReader(b)) + request.Header.Set("Authorization", "Bearer "+apiKey) + request.Header.Set("Content-Type", 
"application/json") + resp, err := (&http.Client{Timeout: 60 * time.Second}).Do(request) + if err != nil { + writeJSON(w, http.StatusOK, map[string]any{"success": false, "error": err.Error()}) + return + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode == http.StatusOK { + var parsed any + _ = json.Unmarshal(body, &parsed) + writeJSON(w, http.StatusOK, map[string]any{"success": true, "status_code": resp.StatusCode, "response": parsed}) + return + } + writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": string(body)}) +} diff --git a/internal/admin/handler_auth.go b/internal/admin/handler_auth.go new file mode 100644 index 0000000..0d3ec1f --- /dev/null +++ b/internal/admin/handler_auth.go @@ -0,0 +1,69 @@ +package admin + +import ( + "encoding/json" + "net/http" + "os" + "strings" + "time" + + authn "ds2api/internal/auth" +) + +func (h *Handler) requireAdmin(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := authn.VerifyAdminRequest(r); err != nil { + writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()}) + return + } + next.ServeHTTP(w, r) + }) +} + +func (h *Handler) login(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + adminKey, _ := req["admin_key"].(string) + expireHours := intFrom(req["expire_hours"]) + if expireHours <= 0 { + expireHours = 24 + } + if adminKey != authn.AdminKey() { + writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "Invalid admin key"}) + return + } + token, err := authn.CreateJWT(expireHours) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]any{"success": true, "token": token, "expires_in": expireHours * 3600}) +} + +func (h *Handler) verify(w http.ResponseWriter, r *http.Request) { 
+ header := strings.TrimSpace(r.Header.Get("Authorization")) + if !strings.HasPrefix(strings.ToLower(header), "bearer ") { + writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "No credentials provided"}) + return + } + token := strings.TrimSpace(header[7:]) + payload, err := authn.VerifyJWT(token) + if err != nil { + writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()}) + return + } + exp, _ := payload["exp"].(float64) + remaining := int64(exp) - time.Now().Unix() + if remaining < 0 { + remaining = 0 + } + writeJSON(w, http.StatusOK, map[string]any{"valid": true, "expires_at": int64(exp), "remaining_seconds": remaining}) +} + +func (h *Handler) getVercelConfig(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, map[string]any{ + "has_token": strings.TrimSpace(os.Getenv("VERCEL_TOKEN")) != "", + "project_id": strings.TrimSpace(os.Getenv("VERCEL_PROJECT_ID")), + "team_id": nilIfEmpty(strings.TrimSpace(os.Getenv("VERCEL_TEAM_ID"))), + }) +} diff --git a/internal/admin/handler_config.go b/internal/admin/handler_config.go new file mode 100644 index 0000000..7627602 --- /dev/null +++ b/internal/admin/handler_config.go @@ -0,0 +1,240 @@ +package admin + +import ( + "crypto/md5" + "encoding/json" + "fmt" + "net/http" + "sort" + "strings" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/config" +) + +func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() + safe := map[string]any{ + "keys": snap.Keys, + "accounts": []map[string]any{}, + "claude_mapping": func() map[string]string { + if len(snap.ClaudeMapping) > 0 { + return snap.ClaudeMapping + } + return snap.ClaudeModelMap + }(), + } + accounts := make([]map[string]any, 0, len(snap.Accounts)) + for _, acc := range snap.Accounts { + token := strings.TrimSpace(acc.Token) + preview := "" + if token != "" { + if len(token) > 20 { + preview = token[:20] + "..." 
+ } else { + preview = token + } + } + accounts = append(accounts, map[string]any{ + "email": acc.Email, + "mobile": acc.Mobile, + "has_password": strings.TrimSpace(acc.Password) != "", + "has_token": token != "", + "token_preview": preview, + }) + } + safe["accounts"] = accounts + writeJSON(w, http.StatusOK, safe) +} + +func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + old := h.Store.Snapshot() + err := h.Store.Update(func(c *config.Config) error { + if keys, ok := toStringSlice(req["keys"]); ok { + c.Keys = keys + } + if accountsRaw, ok := req["accounts"].([]any); ok { + existing := map[string]config.Account{} + for _, a := range old.Accounts { + existing[a.Identifier()] = a + } + accounts := make([]config.Account, 0, len(accountsRaw)) + for _, item := range accountsRaw { + m, ok := item.(map[string]any) + if !ok { + continue + } + acc := toAccount(m) + id := acc.Identifier() + if prev, ok := existing[id]; ok { + if strings.TrimSpace(acc.Password) == "" { + acc.Password = prev.Password + } + if strings.TrimSpace(acc.Token) == "" { + acc.Token = prev.Token + } + } + accounts = append(accounts, acc) + } + c.Accounts = accounts + } + if m, ok := req["claude_mapping"].(map[string]any); ok { + newMap := map[string]string{} + for k, v := range m { + newMap[k] = fmt.Sprintf("%v", v) + } + c.ClaudeMapping = newMap + } + return nil + }) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "配置已更新"}) +} + +func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + key, _ := req["key"].(string) + key = strings.TrimSpace(key) + if key == "" { + 
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"}) + return + } + err := h.Store.Update(func(c *config.Config) error { + for _, k := range c.Keys { + if k == key { + return fmt.Errorf("Key 已存在") + } + } + c.Keys = append(c.Keys, key) + return nil + }) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)}) +} + +func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) { + key := chi.URLParam(r, "key") + err := h.Store.Update(func(c *config.Config) error { + idx := -1 + for i, k := range c.Keys { + if k == key { + idx = i + break + } + } + if idx < 0 { + return fmt.Errorf("Key 不存在") + } + c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...) + return nil + }) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)}) +} + +func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "无效的 JSON 格式"}) + return + } + importedKeys, importedAccounts := 0, 0 + err := h.Store.Update(func(c *config.Config) error { + if keys, ok := req["keys"].([]any); ok { + existing := map[string]bool{} + for _, k := range c.Keys { + existing[k] = true + } + for _, k := range keys { + key := strings.TrimSpace(fmt.Sprintf("%v", k)) + if key == "" || existing[key] { + continue + } + c.Keys = append(c.Keys, key) + existing[key] = true + importedKeys++ + } + } + if accounts, ok := req["accounts"].([]any); ok { + existing := map[string]bool{} + for _, a := range c.Accounts { + existing[a.Identifier()] = true + } + for _, item := range accounts { + m, ok := item.(map[string]any) + if !ok { + continue + 
} + acc := toAccount(m) + id := acc.Identifier() + if id == "" || existing[id] { + continue + } + c.Accounts = append(c.Accounts, acc) + existing[id] = true + importedAccounts++ + } + } + return nil + }) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts}) +} + +func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) { + jsonStr, b64, err := h.Store.ExportJSONAndBase64() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]any{"json": jsonStr, "base64": b64}) +} + +func (h *Handler) computeSyncHash() string { + snap := h.Store.Snapshot() + syncable := map[string]any{"keys": snap.Keys, "accounts": []map[string]any{}} + accounts := make([]map[string]any, 0, len(snap.Accounts)) + for _, a := range snap.Accounts { + m := map[string]any{} + if a.Email != "" { + m["email"] = a.Email + } + if a.Mobile != "" { + m["mobile"] = a.Mobile + } + if a.Password != "" { + m["password"] = a.Password + } + accounts = append(accounts, m) + } + sort.Slice(accounts, func(i, j int) bool { + ai := fmt.Sprintf("%v%v", accounts[i]["email"], accounts[i]["mobile"]) + aj := fmt.Sprintf("%v%v", accounts[j]["email"], accounts[j]["mobile"]) + return ai < aj + }) + syncable["accounts"] = accounts + b, _ := json.Marshal(syncable) + sum := md5.Sum(b) + return fmt.Sprintf("%x", sum) +} diff --git a/internal/admin/handler_vercel.go b/internal/admin/handler_vercel.go new file mode 100644 index 0000000..189d8cc --- /dev/null +++ b/internal/admin/handler_vercel.go @@ -0,0 +1,197 @@ +package admin + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/url" + "os" + "strings" + "time" +) + +func (h *Handler) syncVercel(w 
http.ResponseWriter, r *http.Request) {
+	// NOTE(review): the func line for this handler is above this hunk; from the body it
+	// syncs the current account/config store to a Vercel project's env vars and
+	// optionally triggers a production redeploy.
+	// Parse the loosely-typed JSON request body; reject malformed JSON outright.
+	var req map[string]any
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
+		return
+	}
+	vercelToken, _ := req["vercel_token"].(string)
+	projectID, _ := req["project_id"].(string)
+	teamID, _ := req["team_id"].(string)
+	// auto_validate and save_credentials both default to true when absent or non-bool.
+	autoValidate := true
+	if v, ok := req["auto_validate"].(bool); ok {
+		autoValidate = v
+	}
+	saveCreds := true
+	if v, ok := req["save_credentials"].(bool); ok {
+		saveCreds = v
+	}
+	// "__USE_PRECONFIG__" (or an empty token) is a sentinel meaning "use server-side
+	// env credentials"; project/team IDs also fall back to env vars when blank.
+	usePreconfig := vercelToken == "__USE_PRECONFIG__" || strings.TrimSpace(vercelToken) == ""
+	if usePreconfig {
+		vercelToken = strings.TrimSpace(os.Getenv("VERCEL_TOKEN"))
+	}
+	if strings.TrimSpace(projectID) == "" {
+		projectID = strings.TrimSpace(os.Getenv("VERCEL_PROJECT_ID"))
+	}
+	if strings.TrimSpace(teamID) == "" {
+		teamID = strings.TrimSpace(os.Getenv("VERCEL_TEAM_ID"))
+	}
+	if vercelToken == "" || projectID == "" {
+		writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 Vercel Token 和 Project ID"})
+		return
+	}
+	// Optionally log in every account that has no cached token yet, recording which
+	// identifiers failed; the 500ms sleep between logins is presumably rate-limit
+	// courtesy toward the upstream service — confirm with DS.Login's provider.
+	validated, failed := 0, []string{}
+	if autoValidate {
+		for _, acc := range h.Store.Snapshot().Accounts {
+			if strings.TrimSpace(acc.Token) != "" {
+				continue
+			}
+			token, err := h.DS.Login(r.Context(), acc)
+			if err != nil {
+				failed = append(failed, acc.Identifier())
+			} else {
+				validated++
+				_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
+			}
+			time.Sleep(500 * time.Millisecond)
+		}
+	}
+
+	// Export the store as JSON and base64-encode it for the env var.
+	// NOTE(review): ExportJSONAndBase64's second (base64) return value is discarded
+	// and the JSON is re-encoded here — redundant but harmless; consider using the
+	// returned base64 directly.
+	cfgJSON, _, err := h.Store.ExportJSONAndBase64()
+	if err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
+		return
+	}
+	cfgB64 := base64.StdEncoding.EncodeToString([]byte(cfgJSON))
+	client := &http.Client{Timeout: 30 * time.Second}
+	params := url.Values{}
+	if teamID != "" {
+		params.Set("teamId", teamID)
+	}
+	headers := map[string]string{"Authorization": "Bearer " + vercelToken}
+	// List existing env vars so we can decide between PATCH (update) and POST (create).
+	envResp, status, err := vercelRequest(r.Context(), client,
http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID+"/env", params, headers, nil)
+	if err != nil || status != http.StatusOK {
+		writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "获取环境变量失败"})
+		return
+	}
+	envs, _ := envResp["envs"].([]any)
+	// Upsert DS2API_CONFIG_JSON: PATCH the existing var (v9 API) or create a new
+	// encrypted var targeting production+preview (v10 API).
+	existingEnvID := findEnvID(envs, "DS2API_CONFIG_JSON")
+	if existingEnvID != "" {
+		_, status, err = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+existingEnvID, params, headers, map[string]any{"value": cfgB64})
+	} else {
+		_, status, err = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": "DS2API_CONFIG_JSON", "value": cfgB64, "type": "encrypted", "target": []string{"production", "preview"}})
+	}
+	if err != nil || (status != http.StatusOK && status != http.StatusCreated) {
+		writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "更新环境变量失败"})
+		return
+	}
+	// Optionally persist the caller-supplied credentials back into the project's env
+	// vars (best-effort: per-var errors are ignored, only successes are reported).
+	// Skipped when the server's preconfigured token was used.
+	savedCreds := []string{}
+	if saveCreds && !usePreconfig {
+		creds := [][2]string{{"VERCEL_TOKEN", vercelToken}, {"VERCEL_PROJECT_ID", projectID}}
+		if teamID != "" {
+			creds = append(creds, [2]string{"VERCEL_TEAM_ID", teamID})
+		}
+		for _, kv := range creds {
+			id := findEnvID(envs, kv[0])
+			if id != "" {
+				_, status, _ = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+id, params, headers, map[string]any{"value": kv[1]})
+			} else {
+				_, status, _ = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": kv[0], "value": kv[1], "type": "encrypted", "target": []string{"production", "preview"}})
+			}
+			if status == http.StatusOK || status == http.StatusCreated {
+				savedCreds = append(savedCreds, kv[0])
+			}
+		}
+	}
+	// Attempt an automatic redeploy, which is only possible for GitHub-linked projects.
+	projectResp, status, _ := vercelRequest(r.Context(), client,
http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID, params, headers, nil)
+	manual := true
+	deployURL := ""
+	if status == http.StatusOK {
+		if link, ok := projectResp["link"].(map[string]any); ok {
+			if linkType, _ := link["type"].(string); linkType == "github" {
+				repoID := intFrom(link["repoId"])
+				// Default to the "main" branch when no production branch is configured.
+				ref, _ := link["productionBranch"].(string)
+				if ref == "" {
+					ref = "main"
+				}
+				depResp, depStatus, _ := vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v13/deployments", params, headers, map[string]any{"name": projectID, "project": projectID, "target": "production", "gitSource": map[string]any{"type": "github", "repoId": repoID, "ref": ref}})
+				if depStatus == http.StatusOK || depStatus == http.StatusCreated {
+					deployURL, _ = depResp["url"].(string)
+					manual = false
+				}
+			}
+		}
+	}
+	// Record the sync hash/time so vercelStatus can report "synced" later; errors ignored.
+	_ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix())
+	// Build the success payload: deploy URL when auto-deployed, manual flag otherwise,
+	// plus any failed account identifiers and saved credential names.
+	result := map[string]any{"success": true, "validated_accounts": validated}
+	if manual {
+		result["message"] = "配置已同步到 Vercel,请手动触发重新部署"
+		result["manual_deploy_required"] = true
+	} else {
+		result["message"] = "配置已同步,正在重新部署..."
+		result["deployment_url"] = deployURL
+	}
+	if len(failed) > 0 {
+		result["failed_accounts"] = failed
+	}
+	if len(savedCreds) > 0 {
+		result["saved_credentials"] = savedCreds
+	}
+	writeJSON(w, http.StatusOK, result)
+}
+
+// vercelStatus reports whether the last recorded Vercel sync hash matches the
+// current config hash: "synced" is true only when a non-empty stored hash equals
+// the freshly computed one; last_sync_time is null when never synced.
+func (h *Handler) vercelStatus(w http.ResponseWriter, _ *http.Request) {
+	snap := h.Store.Snapshot()
+	current := h.computeSyncHash()
+	synced := snap.VercelSyncHash != "" && snap.VercelSyncHash == current
+	writeJSON(w, http.StatusOK, map[string]any{"synced": synced, "last_sync_time": nilIfZero(snap.VercelSyncTime), "has_synced_before": snap.VercelSyncHash != ""})
+}
+
+// vercelRequest issues a JSON HTTP request to the Vercel API and returns the
+// decoded response body, the HTTP status code, and any transport error.
+func vercelRequest(ctx context.Context, client *http.Client, method, endpoint string, params url.Values, headers map[string]string, body any) (map[string]any, int, error) {
+	// Append query parameters (e.g. teamId) to the endpoint when present.
+	if len(params) > 0 {
+		endpoint += "?"
+ params.Encode()
+	}
+	// Marshal the optional JSON body; marshal errors are ignored (body is
+	// caller-constructed map data here).
+	var reader io.Reader
+	if body != nil {
+		b, _ := json.Marshal(body)
+		reader = bytes.NewReader(b)
+	}
+	req, err := http.NewRequestWithContext(ctx, method, endpoint, reader)
+	if err != nil {
+		return nil, 0, err
+	}
+	for k, v := range headers {
+		req.Header.Set(k, v)
+	}
+	req.Header.Set("Content-Type", "application/json")
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, 0, err
+	}
+	defer resp.Body.Close()
+	// Best-effort decode: a non-JSON (or empty) body is surfaced under the "raw" key
+	// instead of failing, so callers always get a map plus the status code.
+	b, _ := io.ReadAll(resp.Body)
+	parsed := map[string]any{}
+	_ = json.Unmarshal(b, &parsed)
+	if len(parsed) == 0 {
+		parsed["raw"] = string(b)
+	}
+	return parsed, resp.StatusCode, nil
+}
+
+// findEnvID scans a Vercel env-var listing (slice of loosely-typed maps) and
+// returns the "id" of the entry whose "key" matches, or "" when absent.
+func findEnvID(envs []any, key string) string {
+	for _, item := range envs {
+		m, ok := item.(map[string]any)
+		if !ok {
+			continue
+		}
+		if k, _ := m["key"].(string); k == key {
+			id, _ := m["id"].(string)
+			return id
+		}
+	}
+	return ""
+}
diff --git a/internal/admin/helpers.go b/internal/admin/helpers.go
new file mode 100644
index 0000000..a1d21b8
--- /dev/null
+++ b/internal/admin/helpers.go
@@ -0,0 +1,98 @@
+// Package-internal helpers shared by the admin HTTP handlers: small JSON/query
+// coercion utilities and the common JSON response writer.
+package admin
+
+import (
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"strconv"
+	"strings"
+
+	"ds2api/internal/config"
+)
+
+// reverseAccounts reverses the slice in place.
+func reverseAccounts(a []config.Account) {
+	for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 {
+		a[i], a[j] = a[j], a[i]
+	}
+}
+
+// intFromQuery returns the query parameter as an int, or the default d when
+// the parameter is missing or not a valid integer.
+func intFromQuery(r *http.Request, key string, d int) int {
+	v := r.URL.Query().Get(key)
+	if v == "" {
+		return d
+	}
+	n, err := strconv.Atoi(v)
+	if err != nil {
+		return d
+	}
+	return n
+}
+
+// intFrom coerces a decoded-JSON numeric value (float64/int/int64) to int,
+// returning 0 for any other type. Note float64 is the usual case for
+// encoding/json-decoded numbers.
+func intFrom(v any) int {
+	switch n := v.(type) {
+	case float64:
+		return int(n)
+	case int:
+		return n
+	case int64:
+		return int(n)
+	default:
+		return 0
+	}
+}
+
+// nilIfEmpty maps "" to nil so it serializes as JSON null.
+func nilIfEmpty(s string) any {
+	if s == "" {
+		return nil
+	}
+	return s
+}
+
+// nilIfZero maps 0 to nil so it serializes as JSON null.
+func nilIfZero(v int64) any {
+	if v == 0 {
+		return nil
+	}
+	return v
+}
+
+// toStringSlice converts a decoded-JSON []any into a []string (each element
+// formatted with %v and trimmed); the bool reports whether v was a []any.
+func toStringSlice(v any) ([]string, bool) {
+	arr, ok := v.([]any)
+	if !ok {
+		return nil, false
+	}
+	out := make([]string, 0, len(arr))
+	for _, item :=
range arr {
+		out = append(out, strings.TrimSpace(fmt.Sprintf("%v", item)))
+	}
+	return out, true
+}
+
+// toAccount builds a config.Account from a loosely-typed JSON map, trimming
+// each of the four known string fields; unknown keys are ignored.
+func toAccount(m map[string]any) config.Account {
+	return config.Account{
+		Email:    fieldString(m, "email"),
+		Mobile:   fieldString(m, "mobile"),
+		Password: fieldString(m, "password"),
+		Token:    fieldString(m, "token"),
+	}
+}
+
+// fieldString returns m[key] rendered via %v and trimmed, or "" when the key
+// is absent or nil.
+func fieldString(m map[string]any, key string) string {
+	v, ok := m[key]
+	if !ok || v == nil {
+		return ""
+	}
+	return strings.TrimSpace(fmt.Sprintf("%v", v))
+}
+
+// statusOr returns v unless it is 0 (no HTTP status recorded), in which case
+// it returns the default d.
+func statusOr(v int, d int) int {
+	if v == 0 {
+		return d
+	}
+	return v
+}
+
+// writeJSON writes payload as a JSON response with the given status code.
+// Encoding errors are ignored: the status line is already committed by then.
+func writeJSON(w http.ResponseWriter, status int, payload any) {
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(status)
+	_ = json.NewEncoder(w).Encode(payload)
+}